diff src/rpcify.sml @ 954:2a50da66ffd8

Basic tail recursion introduction seems to be working
author Adam Chlipala <adamc@hcoop.net>
date Thu, 17 Sep 2009 16:35:11 -0400
parents ed06e25c70ef
children 01a4d936395a
line wrap: on
line diff
--- a/src/rpcify.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/rpcify.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -32,6 +32,12 @@
 structure U = CoreUtil
 structure E = CoreEnv
 
+fun multiLiftExpInExp n e =
+    if n = 0 then
+        e
+    else
+        multiLiftExpInExp (n - 1) (E.liftExpInExp 0 e)
+
 structure IS = IntBinarySet
 structure IM = IntBinaryMap
 
@@ -42,7 +48,10 @@
 
 type state = {
      exported : IS.set,
-     export_decls : decl list
+     export_decls : decl list,
+
+     cpsed : exp' IM.map,
+     rpc : IS.set
 }
 
 fun frob file =
@@ -115,7 +124,9 @@
                                          (DExport (Rpc ReadWrite, n), loc) :: #export_decls st)
 
                                 val st = {exported = exported,
-                                          export_decls = export_decls}
+                                          export_decls = export_decls,
+                                          cpsed = #cpsed st,
+                                          rpc = #rpc st}
 
                                 val k = (ECApp ((EFfi ("Basis", "return"), loc),
                                                 (CFfi ("Basis", "transaction"), loc)), loc)
@@ -134,6 +145,11 @@
                     else
                         (e, st)
 
+                  | ENamed n =>
+                    (case IM.find (#cpsed st, n) of
+                         NONE => (e, st)
+                       | SOME re => (re, st))
+
                   | _ => (e, st)
             end
 
@@ -143,6 +159,165 @@
 
         fun decl (d, st : state) =
             let
+                val makesServerCall = U.Exp.exists {kind = fn _ => false,
+                                                    con = fn _ => false,
+                                                    exp = fn EFfi ("Basis", "rpc") => true
+                                                           | ENamed n => IS.member (#rpc st, n)
+                                                           | _ => false}
+
+                val (d, st) =
+                    case #1 d of
+                        DValRec vis =>
+                        if List.exists (fn (_, _, _, e, _) => makesServerCall e) vis then
+                            let
+                                val all = foldl (fn ((_, n, _, _, _), all) => IS.add (all, n)) IS.empty vis
+
+                                val usesRec = U.Exp.exists {kind = fn _ => false,
+                                                            con = fn _ => false,
+                                                            exp = fn ENamed n => IS.member (all, n)
+                                                                   | _ => false}
+
+                                val noRec = not o usesRec
+
+                                fun tailOnly (e, _) =
+                                    case e of
+                                        EPrim _ => true
+                                      | ERel _ => true
+                                      | ENamed _ => true
+                                      | ECon (_, _, _, SOME e) => noRec e
+                                      | ECon _ => true
+                                      | EFfi _ => true
+                                      | EFfiApp (_, _, es) => List.all noRec es
+                                      | EApp (e1, e2) => noRec e2 andalso tailOnly e1
+                                      | EAbs (_, _, _, e) => noRec e
+                                      | ECApp (e1, _) => tailOnly e1
+                                      | ECAbs (_, _, e) => noRec e
+
+                                      | EKAbs (_, e) => noRec e
+                                      | EKApp (e1, _) => tailOnly e1
+
+                                      | ERecord xes => List.all (noRec o #2) xes
+                                      | EField (e1, _, _) => noRec e1
+                                      | EConcat (e1, _, e2, _) => noRec e1 andalso noRec e2
+                                      | ECut (e1, _, _) => noRec e1
+                                      | ECutMulti (e1, _, _) => noRec e1
+
+                                      | ECase (e1, pes, _) => noRec e1 andalso List.all (tailOnly o #2) pes
+
+                                      | EWrite e1 => noRec e1
+
+                                      | EClosure (_, es) => List.all noRec es
+
+                                      | ELet (_, _, e1, e2) => noRec e1 andalso tailOnly e2
+
+                                      | EServerCall (_, es, (EAbs (_, _, _, e), _), _, _) =>
+                                        List.all noRec es andalso tailOnly e
+                                      | EServerCall (_, es, e, _, _) => List.all noRec es andalso noRec e
+
+                                      | ETailCall _ => raise Fail "Rpcify: ETailCall too early"
+
+                                fun tailOnlyF e =
+                                    case #1 e of
+                                        EAbs (_, _, _, e) => tailOnlyF e
+                                      | ECAbs (_, _, e) => tailOnlyF e
+                                      | EKAbs (_, e) => tailOnlyF e
+                                      | _ => tailOnly e
+
+                                val nonTail = foldl (fn ((_, n, _, e, _), nonTail) =>
+                                                        if tailOnlyF e then
+                                                            nonTail
+                                                        else
+                                                            IS.add (nonTail, n)) IS.empty vis
+                            in
+                                if IS.isEmpty nonTail then
+                                    (d, {exported = #exported st,
+                                         export_decls = #export_decls st,
+                                         cpsed = #cpsed st,
+                                         rpc = IS.union (#rpc st, all)})
+                                else
+                                    let
+                                        val rpc = foldl (fn ((_, n, _, _, _), rpc) =>
+                                                            IS.add (rpc, n)) (#rpc st) vis
+
+                                        val (cpsed, vis') =
+                                            foldl (fn (vi as (x, n, t, e, s), (cpsed, vis')) =>
+                                                      if IS.member (nonTail, n) then
+                                                          let
+                                                              fun getArgs (t, acc) =
+                                                                  case #1 t of
+                                                                      TFun (dom, ran) =>
+                                                                      getArgs (ran, dom :: acc)
+                                                                    | _ => (rev acc, t)
+                                                              val (ts, ran) = getArgs (t, [])
+                                                              val ran = case #1 ran of
+                                                                            CApp (_, ran) => ran
+                                                                          | _ => raise Fail "Rpcify: Tail function not transactional"
+                                                              val len = length ts
+
+                                                              val loc = #2 e
+                                                              val args = ListUtil.mapi
+                                                                             (fn (i, _) =>
+                                                                                 (ERel (len - i - 1), loc))
+                                                                             ts
+                                                              val k = (EAbs ("x", ran, ran, (ERel 0, loc)), loc)
+                                                              val re = (ETailCall (n, args, k, ran, ran), loc)
+                                                              val (re, _) = foldr (fn (dom, (re, ran)) =>
+                                                                                      ((EAbs ("x", dom, ran, re),
+                                                                                        loc),
+                                                                                       (TFun (dom, ran), loc)))
+                                                                                  (re, ran) ts
+
+                                                              val be = multiLiftExpInExp (len + 1) e
+                                                              val be = ListUtil.foldli
+                                                                           (fn (i, _, be) =>
+                                                                               (EApp (be, (ERel (len - i), loc)), loc))
+                                                                           be ts
+                                                              val ne = (EFfi ("Basis", "bind"), loc)
+                                                              val trans = (CFfi ("Basis", "transaction"), loc)
+                                                              val ne = (ECApp (ne, trans), loc)
+                                                              val ne = (ECApp (ne, ran), loc)
+                                                              val unit = (TRecord (CRecord ((KType, loc), []),
+                                                                                   loc), loc)
+                                                              val ne = (ECApp (ne, unit), loc)
+                                                              val ne = (EApp (ne, (EFfi ("Basis", "transaction_monad"),
+                                                                                   loc)), loc)
+                                                              val ne = (EApp (ne, be), loc)
+                                                              val ne = (EApp (ne, (ERel 0, loc)), loc)
+                                                              val tunit = (CApp (trans, unit), loc)
+                                                              val kt = (TFun (ran, tunit), loc)
+                                                              val ne = (EAbs ("k", kt, tunit, ne), loc)
+                                                              val (ne, res) = foldr (fn (dom, (ne, ran)) =>
+                                                                                        ((EAbs ("x", dom, ran, ne), loc),
+                                                                                         (TFun (dom, ran), loc)))
+                                                                                    (ne, (TFun (kt, tunit), loc)) ts
+                                                          in
+                                                              (IM.insert (cpsed, n, #1 re),
+                                                               (x, n, res, ne, s) :: vis')
+                                                          end
+                                                      else
+                                                          (cpsed, vi :: vis'))
+                                                  (#cpsed st, []) vis
+                                    in
+                                        ((DValRec (rev vis'), ErrorMsg.dummySpan),
+                                         {exported = #exported st,
+                                          export_decls = #export_decls st,
+                                          cpsed = cpsed,
+                                          rpc = rpc})
+                                    end
+                            end
+                        else
+                            (d, st)
+                      | DVal (x, n, t, e, s) =>
+                        (d,
+                         {exported = #exported st,
+                          export_decls = #export_decls st,
+                          cpsed = #cpsed st,
+                          rpc = if makesServerCall e then
+                                    IS.add (#rpc st, n)
+                                else
+                                    #rpc st})
+                      | _ => (d, st)
+
                 val (d, st) = U.Decl.foldMap {kind = fn x => x,
                                               con = fn x => x,
                                               exp = exp,
@@ -151,12 +326,16 @@
             in
                 (#export_decls st @ [d],
                  {exported = #exported st,
-                  export_decls = []})
+                  export_decls = [],
+                  cpsed = #cpsed st,
+                  rpc = #rpc st})
             end
 
         val (file, _) = ListUtil.foldlMapConcat decl
                         {exported = IS.empty,
-                         export_decls = []}
+                         export_decls = [],
+                         cpsed = IM.empty,
+                         rpc = rpcBaseIds}
                         file
     in
         file