changeset 954:2a50da66ffd8

Basic tail recursion introduction seems to be working
author Adam Chlipala <adamc@hcoop.net>
date Thu, 17 Sep 2009 16:35:11 -0400
parents 301530da2062
children 01a4d936395a
files demo/more/dlist.ur demo/more/dlist.urs demo/more/grid.ur src/core.sml src/core_print.sml src/core_untangle.sml src/core_util.sml src/monoize.sml src/reduce.sml src/reduce_local.sml src/rpcify.sml src/shake.sml tests/tail.ur tests/tail.urp tests/tail.urs
diffstat 15 files changed, 293 insertions(+), 23 deletions(-) [+]
line wrap: on
line diff
--- a/demo/more/dlist.ur	Thu Sep 17 14:57:38 2009 -0400
+++ b/demo/more/dlist.ur	Thu Sep 17 16:35:11 2009 -0400
@@ -48,6 +48,24 @@
         set tl new;
         return (tailPos cur new tl)
 
+fun replace [t] dl ls =
+    case ls of
+        [] => set dl Empty
+      | x :: ls =>
+        tl <- source Nil;
+        let
+            fun build ls acc =
+                case ls of
+                    [] => return acc
+                  | x :: ls =>
+                    this <- source (Cons (x, tl));
+                    build ls this
+        in
+            hd <- build (List.rev ls) tl;
+            tlS <- source tl;
+            set dl (Nonempty {Head = Cons (x, hd), Tail = tlS})
+        end
+
 fun renderDyn [ctx] [ctx ~ body] [t] (f : t -> position -> xml (ctx ++ body) [] []) filter dl = <xml>
   <dyn signal={dl' <- signal dl;
                return (case dl' of
--- a/demo/more/dlist.urs	Thu Sep 17 14:57:38 2009 -0400
+++ b/demo/more/dlist.urs	Thu Sep 17 16:35:11 2009 -0400
@@ -4,6 +4,8 @@
 val create : t ::: Type -> transaction (dlist t)
 val clear : t ::: Type -> dlist t -> transaction unit
 val append : t ::: Type -> dlist t -> t -> transaction position
+val replace : t ::: Type -> dlist t -> list t -> transaction unit
+
 val delete : position -> transaction unit
 val elements : t ::: Type -> dlist t -> signal (list t)
 val foldl : t ::: Type -> acc ::: Type -> (t -> acc -> signal acc) -> acc -> dlist t -> signal acc
--- a/demo/more/grid.ur	Thu Sep 17 14:57:38 2009 -0400
+++ b/demo/more/grid.ur	Thu Sep 17 16:35:11 2009 -0400
@@ -59,16 +59,20 @@
                  Selection : source bool,
                  Filters : $(map thd3 M.cols)}
 
-    fun addRow cols rows row =
+    fun newRow cols row =
         rowS <- source row;
         cols <- makeAll cols row;
         colsS <- source cols;
         ud <- source False;
         sd <- source False;
-        Monad.ignore (Dlist.append rows {Row = rowS,
-                                         Cols = colsS,
-                                         Updating = ud,
-                                         Selected = sd})
+        return {Row = rowS,
+                Cols = colsS,
+                Updating = ud,
+                Selected = sd}
+
+    fun addRow cols rows row =
+        r <- newRow cols row;
+        Monad.ignore (Dlist.append rows r)
 
     val grid =
         cols <- Monad.mapR [colMeta M.row] [fst3]
@@ -91,7 +95,8 @@
     fun sync {Cols = cols, Rows = rows, ...} =
         Dlist.clear rows;
         init <- rpc M.list;
-        List.app (addRow cols rows) init
+        rs <- List.mapM (newRow cols) init;
+        Dlist.replace rows rs
 
     fun render grid = <xml>
       <table class={tabl}>
--- a/src/core.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/core.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -116,6 +116,7 @@
        | ELet of string * con * exp * exp
 
        | EServerCall of int * exp list * exp * con * con
+       | ETailCall of int * exp list * exp * con * con
 
 withtype exp = exp' located
 
--- a/src/core_print.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/core_print.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -446,6 +446,14 @@
                                              string ")[",
                                              p_exp env e,
                                              string "]"]
+      | ETailCall (n, es, e, _, _) => box [string "Tail(",
+                                           p_enamed env n,
+                                           string ",",
+                                           space,
+                                           p_list (p_exp env) es,
+                                           string ")[",
+                                           p_exp env e,
+                                           string "]"]
 
       | EKAbs (x, e) => box [string x,
                              space,
--- a/src/core_untangle.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/core_untangle.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -38,19 +38,20 @@
 fun default (k, s) = s
 
 fun exp thisGroup (e, s) =
-    case e of
-        ENamed n =>
-        if IS.member (thisGroup, n) then
-            IS.add (s, n)
-        else
-            s
-      | EClosure (n, _) =>
-        if IS.member (thisGroup, n) then
-            IS.add (s, n)
-        else
-            s
-
-      | _ => s
+    let
+        fun try n =
+            if IS.member (thisGroup, n) then
+                IS.add (s, n)
+            else
+                s
+    in
+        case e of
+            ENamed n => try n
+          | EClosure (n, _) => try n
+          | EServerCall (n, _, _, _, _) => try n
+          | ETailCall (n, _, _, _, _) => try n
+          | _ => s
+    end
 
 fun untangle file =
     let
--- a/src/core_util.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/core_util.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -539,6 +539,13 @@
       | (EServerCall _, _) => LESS
       | (_, EServerCall _) => GREATER
 
+      | (ETailCall (n1, es1, e1, _, _), ETailCall (n2, es2, e2, _, _)) =>
+        join (Int.compare (n1, n2),
+              fn () => join (joinL compare (es1, es2),
+                             fn () => compare (e1, e2)))
+      | (ETailCall _, _) => LESS
+      | (_, ETailCall _) => GREATER
+
       | (EKAbs (_, e1), EKAbs (_, e2)) => compare (e1, e2)
       | (EKAbs _, _) => LESS
       | (_, EKAbs _) => GREATER
@@ -729,6 +736,17 @@
                                                   fn t2' =>
                                                      (EServerCall (n, es', e', t1', t2'), loc)))))
 
+              | ETailCall (n, es, e, t1, t2) =>
+                S.bind2 (ListUtil.mapfold (mfe ctx) es,
+                      fn es' =>
+                         S.bind2 (mfe ctx e,
+                                 fn e' =>
+                                    S.bind2 (mfc ctx t1,
+                                          fn t1' =>
+                                             S.map2 (mfc ctx t2,
+                                                  fn t2' =>
+                                                     (ETailCall (n, es', e', t1', t2'), loc)))))
+
               | EKAbs (x, e) =>
                 S.map2 (mfe (bind (ctx, RelK x)) e,
                         fn e' =>
--- a/src/monoize.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/monoize.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -3137,6 +3137,21 @@
                 ((L'.ELet (x, t', e1, e2), loc), fm)
             end
 
+          | L.ETailCall (n, es, ek, _, (L.TRecord (L.CRecord (_, []), _), _)) =>
+            let
+                val (es, fm) = ListUtil.foldlMap (fn (e, fm) => monoExp (env, st, fm) e) fm es
+                val (ek, fm) = monoExp (env, st, fm) ek
+
+                val e = (L'.ENamed n, loc)
+                val e = foldl (fn (e, arg) => (L'.EApp (e, arg), loc)) e es
+                val e = (L'.EApp (e, ek), loc)
+            in
+                (e, fm)
+            end
+          | L.ETailCall _ => (E.errorAt loc "Full scope of tail call continuation isn't known";
+                              Print.eprefaces' [("Expression", CorePrint.p_exp env all)];
+                              (dummyExp, fm))
+
           | L.EServerCall (n, es, ek, t, (L.TRecord (L.CRecord (_, []), _), _)) =>
             let
                 val t = monoType env t
--- a/src/reduce.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/reduce.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -745,6 +745,8 @@
 
                           | EServerCall (n, es, e, t1, t2) => (EServerCall (n, map (exp env) es, exp env e,
                                                                             con env t1, con env t2), loc)
+                          | ETailCall (n, es, e, t1, t2) => (ETailCall (n, map (exp env) es, exp env e,
+                                                                        con env t1, con env t2), loc)
             in
                 (*if dangling (edepth' (deKnown env)) r then
                     (Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty all),
--- a/src/reduce_local.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/reduce_local.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -140,6 +140,7 @@
       | ELet (x, t, e1, e2) => (ELet (x, t, exp env e1, exp (Unknown :: env) e2), loc)
 
       | EServerCall (n, es, e, t1, t2) => (EServerCall (n, map (exp env) es, exp env e, t1, t2), loc)
+      | ETailCall (n, es, e, t1, t2) => (ETailCall (n, map (exp env) es, exp env e, t1, t2), loc)
 
 fun reduce file =
     let
--- a/src/rpcify.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/rpcify.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -32,6 +32,12 @@
 structure U = CoreUtil
 structure E = CoreEnv
 
+fun multiLiftExpInExp n e =
+    if n = 0 then
+        e
+    else
+        multiLiftExpInExp (n - 1) (E.liftExpInExp 0 e)
+
 structure IS = IntBinarySet
 structure IM = IntBinaryMap
 
@@ -42,7 +48,10 @@
 
 type state = {
      exported : IS.set,
-     export_decls : decl list
+     export_decls : decl list,
+
+     cpsed : exp' IM.map,
+     rpc : IS.set
 }
 
 fun frob file =
@@ -115,7 +124,9 @@
                                          (DExport (Rpc ReadWrite, n), loc) :: #export_decls st)
 
                                 val st = {exported = exported,
-                                          export_decls = export_decls}
+                                          export_decls = export_decls,
+                                          cpsed = #cpsed st,
+                                          rpc = #rpc st}
 
                                 val k = (ECApp ((EFfi ("Basis", "return"), loc),
                                                 (CFfi ("Basis", "transaction"), loc)), loc)
@@ -134,6 +145,11 @@
                     else
                         (e, st)
 
+                  | ENamed n =>
+                    (case IM.find (#cpsed st, n) of
+                         NONE => (e, st)
+                       | SOME re => (re, st))
+
                   | _ => (e, st)
             end
 
@@ -143,6 +159,165 @@
 
         fun decl (d, st : state) =
             let
+                val makesServerCall = U.Exp.exists {kind = fn _ => false,
+                                                    con = fn _ => false,
+                                                    exp = fn EFfi ("Basis", "rpc") => true
+                                                           | ENamed n => IS.member (#rpc st, n)
+                                                           | _ => false}
+
+                val (d, st) =
+                    case #1 d of
+                        DValRec vis =>
+                        if List.exists (fn (_, _, _, e, _) => makesServerCall e) vis then
+                            let
+                                val all = foldl (fn ((_, n, _, _, _), all) => IS.add (all, n)) IS.empty vis
+
+                                val usesRec = U.Exp.exists {kind = fn _ => false,
+                                                            con = fn _ => false,
+                                                            exp = fn ENamed n => IS.member (all, n)
+                                                                   | _ => false}
+
+                                val noRec = not o usesRec
+
+                                fun tailOnly (e, _) =
+                                    case e of
+                                        EPrim _ => true
+                                      | ERel _ => true
+                                      | ENamed _ => true
+                                      | ECon (_, _, _, SOME e) => noRec e
+                                      | ECon _ => true
+                                      | EFfi _ => true
+                                      | EFfiApp (_, _, es) => List.all noRec es
+                                      | EApp (e1, e2) => noRec e2 andalso tailOnly e1
+                                      | EAbs (_, _, _, e) => noRec e
+                                      | ECApp (e1, _) => tailOnly e1
+                                      | ECAbs (_, _, e) => noRec e
+
+                                      | EKAbs (_, e) => noRec e
+                                      | EKApp (e1, _) => tailOnly e1
+
+                                      | ERecord xes => List.all (noRec o #2) xes
+                                      | EField (e1, _, _) => noRec e1
+                                      | EConcat (e1, _, e2, _) => noRec e1 andalso noRec e2
+                                      | ECut (e1, _, _) => noRec e1
+                                      | ECutMulti (e1, _, _) => noRec e1
+
+                                      | ECase (e1, pes, _) => noRec e1 andalso List.all (tailOnly o #2) pes
+
+                                      | EWrite e1 => noRec e1
+
+                                      | EClosure (_, es) => List.all noRec es
+
+                                      | ELet (_, _, e1, e2) => noRec e1 andalso tailOnly e2
+
+                                      | EServerCall (_, es, (EAbs (_, _, _, e), _), _, _) =>
+                                        List.all noRec es andalso tailOnly e
+                                      | EServerCall (_, es, e, _, _) => List.all noRec es andalso noRec e
+
+                                      | ETailCall _ => raise Fail "Rpcify: ETailCall too early"
+
+                                fun tailOnlyF e =
+                                    case #1 e of
+                                        EAbs (_, _, _, e) => tailOnlyF e
+                                      | ECAbs (_, _, e) => tailOnlyF e
+                                      | EKAbs (_, e) => tailOnlyF e
+                                      | _ => tailOnly e
+
+                                val nonTail = foldl (fn ((_, n, _, e, _), nonTail) =>
+                                                        if tailOnlyF e then
+                                                            nonTail
+                                                        else
+                                                            IS.add (nonTail, n)) IS.empty vis
+                            in
+                                if IS.isEmpty nonTail then
+                                    (d, {exported = #exported st,
+                                         export_decls = #export_decls st,
+                                         cpsed = #cpsed st,
+                                         rpc = IS.union (#rpc st, all)})
+                                else
+                                    let
+                                        val rpc = foldl (fn ((_, n, _, _, _), rpc) =>
+                                                            IS.add (rpc, n)) (#rpc st) vis
+
+                                        val (cpsed, vis') =
+                                            foldl (fn (vi as (x, n, t, e, s), (cpsed, vis')) =>
+                                                      if IS.member (nonTail, n) then
+                                                          let
+                                                              fun getArgs (t, acc) =
+                                                                  case #1 t of
+                                                                      TFun (dom, ran) =>
+                                                                      getArgs (ran, dom :: acc)
+                                                                    | _ => (rev acc, t)
+                                                              val (ts, ran) = getArgs (t, [])
+                                                              val ran = case #1 ran of
+                                                                            CApp (_, ran) => ran
+                                                                          | _ => raise Fail "Rpcify: Tail function not transactional"
+                                                              val len = length ts
+
+                                                              val loc = #2 e
+                                                              val args = ListUtil.mapi
+                                                                             (fn (i, _) =>
+                                                                                 (ERel (len - i - 1), loc))
+                                                                             ts
+                                                              val k = (EAbs ("x", ran, ran, (ERel 0, loc)), loc)
+                                                              val re = (ETailCall (n, args, k, ran, ran), loc)
+                                                              val (re, _) = foldr (fn (dom, (re, ran)) =>
+                                                                                      ((EAbs ("x", dom, ran, re),
+                                                                                        loc),
+                                                                                       (TFun (dom, ran), loc)))
+                                                                                  (re, ran) ts
+
+                                                              val be = multiLiftExpInExp (len + 1) e
+                                                              val be = ListUtil.foldli
+                                                                           (fn (i, _, be) =>
+                                                                               (EApp (be, (ERel (len - i), loc)), loc))
+                                                                           be ts
+                                                              val ne = (EFfi ("Basis", "bind"), loc)
+                                                              val trans = (CFfi ("Basis", "transaction"), loc)
+                                                              val ne = (ECApp (ne, trans), loc)
+                                                              val ne = (ECApp (ne, ran), loc)
+                                                              val unit = (TRecord (CRecord ((KType, loc), []),
+                                                                                   loc), loc)
+                                                              val ne = (ECApp (ne, unit), loc)
+                                                              val ne = (EApp (ne, (EFfi ("Basis", "transaction_monad"),
+                                                                                   loc)), loc)
+                                                              val ne = (EApp (ne, be), loc)
+                                                              val ne = (EApp (ne, (ERel 0, loc)), loc)
+                                                              val tunit = (CApp (trans, unit), loc)
+                                                              val kt = (TFun (ran, tunit), loc)
+                                                              val ne = (EAbs ("k", kt, tunit, ne), loc)
+                                                              val (ne, res) = foldr (fn (dom, (ne, ran)) =>
+                                                                                        ((EAbs ("x", dom, ran, ne), loc),
+                                                                                         (TFun (dom, ran), loc)))
+                                                                                    (ne, (TFun (kt, tunit), loc)) ts
+                                                          in
+                                                              (IM.insert (cpsed, n, #1 re),
+                                                               (x, n, res, ne, s) :: vis')
+                                                          end
+                                                      else
+                                                          (cpsed, vi :: vis'))
+                                                  (#cpsed st, []) vis
+                                    in
+                                        ((DValRec (rev vis'), ErrorMsg.dummySpan),
+                                         {exported = #exported st,
+                                          export_decls = #export_decls st,
+                                          cpsed = cpsed,
+                                          rpc = rpc})
+                                    end
+                            end
+                        else
+                            (d, st)
+                      | DVal (x, n, t, e, s) =>
+                        (d,
+                         {exported = #exported st,
+                          export_decls = #export_decls st,
+                          cpsed = #cpsed st,
+                          rpc = if makesServerCall e then
+                                    IS.add (#rpc st, n)
+                                else
+                                    #rpc st})
+                      | _ => (d, st)
+
                 val (d, st) = U.Decl.foldMap {kind = fn x => x,
                                               con = fn x => x,
                                               exp = exp,
@@ -151,12 +326,16 @@
             in
                 (#export_decls st @ [d],
                  {exported = #exported st,
-                  export_decls = []})
+                  export_decls = [],
+                  cpsed = #cpsed st,
+                  rpc = #rpc st})
             end
 
         val (file, _) = ListUtil.foldlMapConcat decl
                         {exported = IS.empty,
-                         export_decls = []}
+                         export_decls = [],
+                         cpsed = IM.empty,
+                         rpc = rpcBaseIds}
                         file
     in
         file
--- a/src/shake.sml	Thu Sep 17 14:57:38 2009 -0400
+++ b/src/shake.sml	Thu Sep 17 16:35:11 2009 -0400
@@ -138,6 +138,7 @@
                 case e of
                     ENamed n => check n
                   | EServerCall (n, _, _, _, _) => check n
+                  | ETailCall (n, _, _, _, _) => check n
                   | _ => s
             end
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/tail.ur	Thu Sep 17 16:35:11 2009 -0400
@@ -0,0 +1,15 @@
+fun one () = return 1
+
+fun addEm n =
+    if n = 0 then
+        return 0
+    else
+        n1 <- rpc (one ());
+        n2 <- addEm (n - 1);
+        return (n1 + n2)
+
+fun main () =
+    s <- source 0;
+    return <xml><body onload={n <- addEm 3; set s n}>
+      <dyn signal={n <- signal s; return (txt n)}/>
+    </body></xml>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/tail.urp	Thu Sep 17 16:35:11 2009 -0400
@@ -0,0 +1,3 @@
+debug
+
+tail
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/tail.urs	Thu Sep 17 16:35:11 2009 -0400
@@ -0,0 +1,1 @@
+val main : unit -> transaction page