changeset 469:b393c2fc80f8

About to begin optimization of recursive transaction functions
author Adam Chlipala <adamc@hcoop.net>
date Thu, 06 Nov 2008 17:09:53 -0500
parents 4efab85405be
children 7cb418e9714f
files demo/ref.ur demo/tree.ur demo/tree.urp demo/tree.urs demo/treeFun.ur demo/treeFun.urs lib/basis.urs lib/top.ur lib/top.urs src/elaborate.sml src/especialize.sml
diffstat 11 files changed, 228 insertions(+), 79 deletions(-) [+]
line wrap: on
line diff
--- a/demo/ref.ur	Thu Nov 06 15:52:13 2008 -0500
+++ b/demo/ref.ur	Thu Nov 06 17:09:53 2008 -0500
@@ -1,11 +1,9 @@
 structure IR = RefFun.Make(struct
                                type t = int
-                               val inj = _
                            end)
 
 structure SR = RefFun.Make(struct
                                type t = string
-                               val inj = _
                            end)
 
 fun main () =
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/demo/tree.ur	Thu Nov 06 17:09:53 2008 -0500
@@ -0,0 +1,15 @@
+table t : { Id : int, Parent : option int, Nam : string }
+
+open TreeFun.Make(struct
+                      val tab = t
+                  end)
+
+fun row r = <xml>
+  #{[r.Id]}: {[r.Nam]}
+</xml>
+
+fun main () =
+    xml <- tree row None;
+    return <xml><body>
+      {xml}
+    </body></xml>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/demo/tree.urp	Thu Nov 06 17:09:53 2008 -0500
@@ -0,0 +1,6 @@
+debug
+database dbname=tree
+sql tree.sql
+
+treeFun
+tree
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/demo/tree.urs	Thu Nov 06 17:09:53 2008 -0500
@@ -0,0 +1,1 @@
+val main : unit -> transaction page
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/demo/treeFun.ur	Thu Nov 06 17:09:53 2008 -0500
@@ -0,0 +1,35 @@
+functor Make(M : sig
+                 type key
+                 con id :: Name
+                 con parent :: Name
+                 con cols :: {Type}
+                 constraint [id] ~ [parent]
+                 constraint [id, parent] ~ cols
+
+                 val key_inj : sql_injectable key
+                 val option_key_inj : sql_injectable (option key)
+
+                 table tab : [id = key, parent = option key] ++ cols
+             end) = struct
+
+    open M
+
+    fun tree (f : $([id = key, parent = option key] ++ cols) -> xbody)
+             (root : option M.key) =
+        let
+            fun recurse (root : option key) =
+                queryX' (SELECT * FROM tab WHERE tab.{parent} = {root})
+                        (fn r =>
+                            children <- recurse (Some r.Tab.id);
+                            return <xml>
+                              <li> {f r.Tab}</li>
+                              
+                              <ul>
+                                {children}
+                              </ul>
+                            </xml>)
+        in
+            recurse root
+        end
+
+end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/demo/treeFun.urs	Thu Nov 06 17:09:53 2008 -0500
@@ -0,0 +1,22 @@
+functor Make(M : sig
+                 type key
+                 con id :: Name
+                 con parent :: Name
+                 con cols :: {Type}
+                 constraint [id] ~ [parent]
+                 constraint [id, parent] ~ cols
+
+                 val key_inj : sql_injectable key
+                 val option_key_inj : sql_injectable (option key)
+
+                 table tab : [id = key, parent = option key] ++ cols
+             end) : sig
+
+    con id = M.id
+    con parent = M.parent
+
+    val tree : ($([id = M.key, parent = option M.key] ++ M.cols) -> xbody)
+               -> option M.key
+               -> transaction xbody
+
+end
--- a/lib/basis.urs	Thu Nov 06 15:52:13 2008 -0500
+++ b/lib/basis.urs	Thu Nov 06 17:09:53 2008 -0500
@@ -374,7 +374,10 @@
 val h2 : bodyTag []
 val h3 : bodyTag []
 val h4 : bodyTag []
+
 val li : bodyTag []
+val ol : bodyTag []
+val ul : bodyTag []
 
 val hr : bodyTag []
 
--- a/lib/top.ur	Thu Nov 06 15:52:13 2008 -0500
+++ b/lib/top.ur	Thu Nov 06 17:09:53 2008 -0500
@@ -202,6 +202,17 @@
           (fn fs acc => return <xml>{acc}{f fs}</xml>)
           <xml/>
 
+fun queryX' (tables ::: {{Type}}) (exps ::: {Type}) (ctx ::: {Unit})
+            (q : sql_query tables exps) [tables ~ exps]
+            (f : $(exps ++ fold (fn nm (fields :: {Type}) acc [[nm] ~ acc] =>
+                                    [nm = $fields] ++ acc) [] tables)
+                 -> transaction (xml ctx [] [])) =
+    query q
+          (fn fs acc =>
+              r <- f fs;
+              return <xml>{acc}{r}</xml>)
+          <xml/>
+
 fun oneOrNoRows (tables ::: {{Type}}) (exps ::: {Type})
                 (q : sql_query tables exps) [tables ~ exps] =
     query q
--- a/lib/top.urs	Thu Nov 06 15:52:13 2008 -0500
+++ b/lib/top.urs	Thu Nov 06 17:09:53 2008 -0500
@@ -141,6 +141,14 @@
                     -> xml ctx [] [])
                    -> transaction (xml ctx [] [])
 
+val queryX' : tables ::: {{Type}} -> exps ::: {Type} -> ctx ::: {Unit}
+              -> sql_query tables exps
+              -> fn [tables ~ exps] =>
+                    ($(exps ++ fold (fn nm (fields :: {Type}) acc [[nm] ~ acc] =>
+                                        [nm = $fields] ++ acc) [] tables)
+                     -> transaction (xml ctx [] []))
+                    -> transaction (xml ctx [] [])
+                       
 val oneOrNoRows : tables ::: {{Type}} -> exps ::: {Type}
                   -> sql_query tables exps
                   -> fn [tables ~ exps] =>
--- a/src/elaborate.sml	Thu Nov 06 15:52:13 2008 -0500
+++ b/src/elaborate.sml	Thu Nov 06 17:09:53 2008 -0500
@@ -1777,6 +1777,38 @@
 fun sequenceOf () = (L'.CModProj (!basis_r, [], "sql_sequence"), ErrorMsg.dummySpan)
 fun cookieOf () = (L'.CModProj (!basis_r, [], "http_cookie"), ErrorMsg.dummySpan)
 
+fun dopenConstraints (loc, env, denv) {str, strs} =
+    case E.lookupStr env str of
+        NONE => (strError env (UnboundStr (loc, str));
+                 denv)
+      | SOME (n, sgn) =>
+        let
+            val (st, sgn) = foldl (fn (m, (str, sgn)) =>
+                                      case E.projectStr env {str = str, sgn = sgn, field = m} of
+                                          NONE => (strError env (UnboundStr (loc, m));
+                                                   (strerror, sgnerror))
+                                        | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
+                                  ((L'.StrVar n, loc), sgn) strs
+                            
+            val cso = E.projectConstraints env {sgn = sgn, str = st}
+
+            val denv = case cso of
+                           NONE => (strError env (UnboundStr (loc, str));
+                                    denv)
+                         | SOME cs => foldl (fn ((c1, c2), denv) =>
+                                                let
+                                                    val (denv, gs) = D.assert env denv (c1, c2)
+                                                in
+                                                    case gs of
+                                                        [] => ()
+                                                      | _ => raise Fail "dopenConstraints: Sub-constraints remain";
+
+                                                    denv
+                                                end) denv cs
+        in
+            denv
+        end
+
 fun elabSgn_item ((sgi, loc), (env, denv, gs)) =
     case sgi of
         L.SgiConAbs (x, k) =>
@@ -2054,7 +2086,8 @@
         let
             val (dom', gs1) = elabSgn (env, denv) dom
             val (env', n) = E.pushStrNamed env m dom'
-            val (ran', gs2) = elabSgn (env', denv) ran
+            val denv' = dopenConstraints (loc, env', denv) {str = m, strs = []}
+            val (ran', gs2) = elabSgn (env', denv') ran
         in
             ((L'.SgnFun (m, n, dom', ran'), loc), gs1 @ gs2)
         end
@@ -2193,38 +2226,6 @@
                   ([], (env, denv)))
     end
 
-fun dopenConstraints (loc, env, denv) {str, strs} =
-    case E.lookupStr env str of
-        NONE => (strError env (UnboundStr (loc, str));
-                 denv)
-      | SOME (n, sgn) =>
-        let
-            val (st, sgn) = foldl (fn (m, (str, sgn)) =>
-                                      case E.projectStr env {str = str, sgn = sgn, field = m} of
-                                          NONE => (strError env (UnboundStr (loc, m));
-                                                   (strerror, sgnerror))
-                                        | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
-                                  ((L'.StrVar n, loc), sgn) strs
-                            
-            val cso = E.projectConstraints env {sgn = sgn, str = st}
-
-            val denv = case cso of
-                           NONE => (strError env (UnboundStr (loc, str));
-                                    denv)
-                         | SOME cs => foldl (fn ((c1, c2), denv) =>
-                                                let
-                                                    val (denv, gs) = D.assert env denv (c1, c2)
-                                                in
-                                                    case gs of
-                                                        [] => ()
-                                                      | _ => raise Fail "dopenConstraints: Sub-constraints remain";
-
-                                                    denv
-                                                end) denv cs
-        in
-            denv
-        end
-
 fun sgiOfDecl (d, loc) =
     case d of
         L'.DCon (x, n, k, c) => [(L'.SgiCon (x, n, k, c), loc)]
@@ -2252,6 +2253,8 @@
       | _ => denv
 
 fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) =
+    ((*prefaces "subSgn" [("sgn1", p_sgn env sgn1),
+                        ("sgn2", p_sgn env sgn2)];*)
     case (#1 (hnormSgn env sgn1), #1 (hnormSgn env sgn2)) of
         (L'.SgnError, _) => ()
       | (_, L'.SgnError) => ()
@@ -2274,8 +2277,18 @@
                                     [] => (sgnError env (UnmatchedSgi sgi2All);
                                            (env, denv))
                                   | h :: t =>
-                                    case p h of
-                                        NONE => seek (E.sgiBinds env h, sgiBindsD (env, denv) h) t
+                                    case p (env, h) of
+                                        NONE =>
+                                        let
+                                            val env = case #1 h of
+                                                          L'.SgiCon (x, n, k, c) =>
+                                                          E.pushCNamedAs env x n k (SOME c)
+                                                        | L'.SgiConAbs (x, n, k) =>
+                                                          E.pushCNamedAs env x n k NONE
+                                                        | _ => env
+                                        in
+                                            seek (E.sgiBinds env h, sgiBindsD (env, denv) h) t
+                                        end
                                       | SOME envs => envs
                         in
                             seek (env, denv) sgis1
@@ -2283,7 +2296,7 @@
                 in
                     case sgi of
                         L'.SgiConAbs (x, n2, k2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      fun found (x', n1, k1, co1) =
                                          if x = x' then
@@ -2331,7 +2344,7 @@
                                  end)
 
                       | L'.SgiCon (x, n2, k2, c2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      fun found (x', n1, k1, c1) =
                                          if x = x' then
@@ -2365,7 +2378,7 @@
                                  end)
 
                       | L'.SgiDatatype (x, n2, xs2, xncs2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      fun found (n1, xs1, xncs1) =
                                          let
@@ -2426,7 +2439,7 @@
                                  end)
 
                       | L'.SgiDatatypeImp (x, n2, m12, ms2, s2, xs, _) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiDatatypeImp (x', n1, m11, ms1, s1, _, _) =>
                                      if x = x' then
@@ -2457,7 +2470,7 @@
                                    | _ => NONE)
 
                       | L'.SgiVal (x, n2, c2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiVal (x', n1, c1) =>
                                      if x = x' then
@@ -2474,7 +2487,7 @@
                                    | _ => NONE)
 
                       | L'.SgiStr (x, n2, sgn2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiStr (x', n1, sgn1) =>
                                      if x = x' then
@@ -2495,7 +2508,7 @@
                                    | _ => NONE)
 
                       | L'.SgiSgn (x, n2, sgn2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiSgn (x', n1, sgn1) =>
                                      if x = x' then
@@ -2516,7 +2529,7 @@
                                    | _ => NONE)
 
                       | L'.SgiConstraint (c2, d2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiConstraint (c1, d1) =>
                                      if consEq (env, denv) (c1, c2) andalso consEq (env, denv) (d1, d2) then
@@ -2534,7 +2547,7 @@
                                    | _ => NONE)
 
                       | L'.SgiClassAbs (x, n2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      fun found (x', n1, co) =
                                          if x = x' then
@@ -2557,7 +2570,7 @@
                                        | _ => NONE
                                  end)
                       | L'.SgiClass (x, n2, c2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
 
@@ -2606,7 +2619,7 @@
             subSgn (E.pushStrNamedAs env m2 n2 dom2, denv) ran1 ran2
         end
 
-      | _ => sgnError env (SgnWrongForm (sgn1, sgn2))
+      | _ => sgnError env (SgnWrongForm (sgn1, sgn2)))
 
 
 fun positive self =
@@ -2717,46 +2730,79 @@
 
                        | _ => NONE
 
-                 val (needed, constraints, _) =
-                     foldl (fn ((sgi, loc), (needed, constraints, env')) =>
+                 val (neededC, constraints, neededV, _) =
+                     foldl (fn ((sgi, loc), (neededC, constraints, neededV, env')) =>
                                let
-                                   val (needed, constraints) =
+                                   val (needed, constraints, neededV) =
                                        case sgi of
-                                           L'.SgiConAbs (x, _, _) => (SS.add (needed, x), constraints)
-                                         | L'.SgiConstraint cs => (needed, (env', cs, loc) :: constraints)
-                                         | _ => (needed, constraints)
+                                           L'.SgiConAbs (x, _, _) => (SS.add (neededC, x), constraints, neededV)
+                                         | L'.SgiConstraint cs => (neededC, (env', cs, loc) :: constraints, neededV)
+
+                                         | L'.SgiVal (x, _, t) =>
+                                           let
+                                               fun default () = (neededC, constraints, neededV)
+
+                                               val t = normClassConstraint env' t
+                                           in
+                                               case #1 t of
+                                                   L'.CApp (f, _) =>
+                                                   if E.isClass env' f then
+                                                       (neededC, constraints, SS.add (neededV, x))
+                                                   else
+                                                       default ()
+                                                       
+                                                 | _ => default ()
+                                           end
+
+                                         | _ => (neededC, constraints, neededV)
                                in
-                                   (needed, constraints, E.sgiBinds env' (sgi, loc))
+                                   (needed, constraints, neededV, E.sgiBinds env' (sgi, loc))
                                end)
-                           (SS.empty, [], env) sgis
+                           (SS.empty, [], SS.empty, env) sgis
                                                               
-                 val needed = foldl (fn ((d, _), needed) =>
-                                        case d of
-                                            L.DCon (x, _, _) => (SS.delete (needed, x)
-                                                                 handle NotFound =>
-                                                                        needed)
-                                          | L.DClass (x, _) => (SS.delete (needed, x)
-                                                                handle NotFound => needed)
-                                          | L.DOpen _ => SS.empty
-                                          | _ => needed)
-                                    needed ds
-
-                 val cds = List.mapPartial (fn (env', (c1, c2), loc) =>
+                 val (neededC, neededV) = foldl (fn ((d, _), needed as (neededC, neededV)) =>
+                                                    case d of
+                                                        L.DCon (x, _, _) => ((SS.delete (neededC, x), neededV)
+                                                                             handle NotFound =>
+                                                                                    needed)
+                                                      | L.DClass (x, _) => ((SS.delete (neededC, x), neededV)
+                                                                            handle NotFound => needed)
+                                                      | L.DVal (x, _, _) => ((neededC, SS.delete (neededV, x))
+                                                                             handle NotFound => needed)
+                                                      | L.DOpen _ => (SS.empty, SS.empty)
+                                                      | _ => needed)
+                                                (neededC, neededV) ds
+
+                 val ds' = List.mapPartial (fn (env', (c1, c2), loc) =>
                                                case (decompileCon env' c1, decompileCon env' c2) of
                                                    (SOME c1, SOME c2) =>
                                                    SOME (L.DConstraint (c1, c2), loc)
                                                  | _ => NONE) constraints
+
+                 val ds' =
+                     case SS.listItems neededV of
+                         [] => ds'
+                       | xs =>
+                         let
+                             val ewild = (L.EWild, #2 str)
+                             val ds'' = map (fn x => (L.DVal (x, NONE, ewild), #2 str)) xs
+                         in
+                             ds'' @ ds'
+                         end
+
+                 val ds' =
+                     case SS.listItems neededC of
+                         [] => ds'
+                       | xs =>
+                         let
+                             val kwild = (L.KWild, #2 str)
+                             val cwild = (L.CWild kwild, #2 str)
+                             val ds'' = map (fn x => (L.DCon (x, NONE, cwild), #2 str)) xs
+                         in
+                             ds'' @ ds'
+                         end
              in
-                 case SS.listItems needed of
-                     [] => (L.StrConst (ds @ cds), #2 str)
-                   | xs =>
-                     let
-                         val kwild = (L.KWild, #2 str)
-                         val cwild = (L.CWild kwild, #2 str)
-                         val ds' = map (fn x => (L.DCon (x, NONE, cwild), #2 str)) xs
-                     in
-                         (L.StrConst (ds @ ds' @ cds), #2 str)
-                     end
+                 (L.StrConst (ds @ ds'), #2 str)
              end
            | _ => str)
       | _ => str
--- a/src/especialize.sml	Thu Nov 06 15:52:13 2008 -0500
+++ b/src/especialize.sml	Thu Nov 06 17:09:53 2008 -0500
@@ -110,7 +110,7 @@
           | SOME (_, [], _) => (e, st)
           | SOME (f, xs, xs') =>
             case IM.find (#funcs st, f) of
-                NONE => ((*print "SHOT DOWN!\n";*) (e, st))
+                NONE => ((*print ("SHOT DOWN! " ^ Int.toString f ^ "\n");*) (e, st))
               | SOME {name, args, body, typ, tag} =>
                 case KM.find (args, xs) of
                     SOME f' => ((*Print.prefaces "Pre-existing" [("e", CorePrint.p_exp CoreEnv.empty (e, ErrorMsg.dummySpan))];*)
@@ -203,6 +203,10 @@
                                               body = e,
                                               typ = c,
                                               tag = tag})
+                      | DVal (_, n, _, (ENamed n', _), _) =>
+                        (case IM.find (funcs, n') of
+                             NONE => funcs
+                           | SOME v => IM.insert (funcs, n, v))
                       | _ => funcs
 
                 val (changed, ds) =