diff src/elaborate.sml @ 469:b393c2fc80f8

About to begin optimization of recursive transaction functions
author Adam Chlipala <adamc@hcoop.net>
date Thu, 06 Nov 2008 17:09:53 -0500
parents 3f1b9231a37b
children 20fab0e96217
line wrap: on
line diff
--- a/src/elaborate.sml	Thu Nov 06 15:52:13 2008 -0500
+++ b/src/elaborate.sml	Thu Nov 06 17:09:53 2008 -0500
@@ -1777,6 +1777,38 @@
 fun sequenceOf () = (L'.CModProj (!basis_r, [], "sql_sequence"), ErrorMsg.dummySpan)
 fun cookieOf () = (L'.CModProj (!basis_r, [], "http_cookie"), ErrorMsg.dummySpan)
 
+fun dopenConstraints (loc, env, denv) {str, strs} =
+    case E.lookupStr env str of
+        NONE => (strError env (UnboundStr (loc, str));
+                 denv)
+      | SOME (n, sgn) =>
+        let
+            val (st, sgn) = foldl (fn (m, (str, sgn)) =>
+                                      case E.projectStr env {str = str, sgn = sgn, field = m} of
+                                          NONE => (strError env (UnboundStr (loc, m));
+                                                   (strerror, sgnerror))
+                                        | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
+                                  ((L'.StrVar n, loc), sgn) strs
+                            
+            val cso = E.projectConstraints env {sgn = sgn, str = st}
+
+            val denv = case cso of
+                           NONE => (strError env (UnboundStr (loc, str));
+                                    denv)
+                         | SOME cs => foldl (fn ((c1, c2), denv) =>
+                                                let
+                                                    val (denv, gs) = D.assert env denv (c1, c2)
+                                                in
+                                                    case gs of
+                                                        [] => ()
+                                                      | _ => raise Fail "dopenConstraints: Sub-constraints remain";
+
+                                                    denv
+                                                end) denv cs
+        in
+            denv
+        end
+
 fun elabSgn_item ((sgi, loc), (env, denv, gs)) =
     case sgi of
         L.SgiConAbs (x, k) =>
@@ -2054,7 +2086,8 @@
         let
             val (dom', gs1) = elabSgn (env, denv) dom
             val (env', n) = E.pushStrNamed env m dom'
-            val (ran', gs2) = elabSgn (env', denv) ran
+            val denv' = dopenConstraints (loc, env', denv) {str = m, strs = []}
+            val (ran', gs2) = elabSgn (env', denv') ran
         in
             ((L'.SgnFun (m, n, dom', ran'), loc), gs1 @ gs2)
         end
@@ -2193,38 +2226,6 @@
                   ([], (env, denv)))
     end
 
-fun dopenConstraints (loc, env, denv) {str, strs} =
-    case E.lookupStr env str of
-        NONE => (strError env (UnboundStr (loc, str));
-                 denv)
-      | SOME (n, sgn) =>
-        let
-            val (st, sgn) = foldl (fn (m, (str, sgn)) =>
-                                      case E.projectStr env {str = str, sgn = sgn, field = m} of
-                                          NONE => (strError env (UnboundStr (loc, m));
-                                                   (strerror, sgnerror))
-                                        | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
-                                  ((L'.StrVar n, loc), sgn) strs
-                            
-            val cso = E.projectConstraints env {sgn = sgn, str = st}
-
-            val denv = case cso of
-                           NONE => (strError env (UnboundStr (loc, str));
-                                    denv)
-                         | SOME cs => foldl (fn ((c1, c2), denv) =>
-                                                let
-                                                    val (denv, gs) = D.assert env denv (c1, c2)
-                                                in
-                                                    case gs of
-                                                        [] => ()
-                                                      | _ => raise Fail "dopenConstraints: Sub-constraints remain";
-
-                                                    denv
-                                                end) denv cs
-        in
-            denv
-        end
-
 fun sgiOfDecl (d, loc) =
     case d of
         L'.DCon (x, n, k, c) => [(L'.SgiCon (x, n, k, c), loc)]
@@ -2252,6 +2253,8 @@
       | _ => denv
 
 fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) =
+    ((*prefaces "subSgn" [("sgn1", p_sgn env sgn1),
+                        ("sgn2", p_sgn env sgn2)];*)
     case (#1 (hnormSgn env sgn1), #1 (hnormSgn env sgn2)) of
         (L'.SgnError, _) => ()
       | (_, L'.SgnError) => ()
@@ -2274,8 +2277,18 @@
                                     [] => (sgnError env (UnmatchedSgi sgi2All);
                                            (env, denv))
                                   | h :: t =>
-                                    case p h of
-                                        NONE => seek (E.sgiBinds env h, sgiBindsD (env, denv) h) t
+                                    case p (env, h) of
+                                        NONE =>
+                                        let
+                                            val env = case #1 h of
+                                                          L'.SgiCon (x, n, k, c) =>
+                                                          E.pushCNamedAs env x n k (SOME c)
+                                                        | L'.SgiConAbs (x, n, k) =>
+                                                          E.pushCNamedAs env x n k NONE
+                                                        | _ => env
+                                        in
+                                            seek (E.sgiBinds env h, sgiBindsD (env, denv) h) t
+                                        end
                                       | SOME envs => envs
                         in
                             seek (env, denv) sgis1
@@ -2283,7 +2296,7 @@
                 in
                     case sgi of
                         L'.SgiConAbs (x, n2, k2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      fun found (x', n1, k1, co1) =
                                          if x = x' then
@@ -2331,7 +2344,7 @@
                                  end)
 
                       | L'.SgiCon (x, n2, k2, c2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      fun found (x', n1, k1, c1) =
                                          if x = x' then
@@ -2365,7 +2378,7 @@
                                  end)
 
                       | L'.SgiDatatype (x, n2, xs2, xncs2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      fun found (n1, xs1, xncs1) =
                                          let
@@ -2426,7 +2439,7 @@
                                  end)
 
                       | L'.SgiDatatypeImp (x, n2, m12, ms2, s2, xs, _) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiDatatypeImp (x', n1, m11, ms1, s1, _, _) =>
                                      if x = x' then
@@ -2457,7 +2470,7 @@
                                    | _ => NONE)
 
                       | L'.SgiVal (x, n2, c2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiVal (x', n1, c1) =>
                                      if x = x' then
@@ -2474,7 +2487,7 @@
                                    | _ => NONE)
 
                       | L'.SgiStr (x, n2, sgn2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiStr (x', n1, sgn1) =>
                                      if x = x' then
@@ -2495,7 +2508,7 @@
                                    | _ => NONE)
 
                       | L'.SgiSgn (x, n2, sgn2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiSgn (x', n1, sgn1) =>
                                      if x = x' then
@@ -2516,7 +2529,7 @@
                                    | _ => NONE)
 
                       | L'.SgiConstraint (c2, d2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  case sgi1 of
                                      L'.SgiConstraint (c1, d1) =>
                                      if consEq (env, denv) (c1, c2) andalso consEq (env, denv) (d1, d2) then
@@ -2534,7 +2547,7 @@
                                    | _ => NONE)
 
                       | L'.SgiClassAbs (x, n2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      fun found (x', n1, co) =
                                          if x = x' then
@@ -2557,7 +2570,7 @@
                                        | _ => NONE
                                  end)
                       | L'.SgiClass (x, n2, c2) =>
-                        seek (fn sgi1All as (sgi1, _) =>
+                        seek (fn (env, sgi1All as (sgi1, _)) =>
                                  let
                                      val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
 
@@ -2606,7 +2619,7 @@
             subSgn (E.pushStrNamedAs env m2 n2 dom2, denv) ran1 ran2
         end
 
-      | _ => sgnError env (SgnWrongForm (sgn1, sgn2))
+      | _ => sgnError env (SgnWrongForm (sgn1, sgn2)))
 
 
 fun positive self =
@@ -2717,46 +2730,79 @@
 
                        | _ => NONE
 
-                 val (needed, constraints, _) =
-                     foldl (fn ((sgi, loc), (needed, constraints, env')) =>
+                 val (neededC, constraints, neededV, _) =
+                     foldl (fn ((sgi, loc), (neededC, constraints, neededV, env')) =>
                                let
-                                   val (needed, constraints) =
+                                   val (needed, constraints, neededV) =
                                        case sgi of
-                                           L'.SgiConAbs (x, _, _) => (SS.add (needed, x), constraints)
-                                         | L'.SgiConstraint cs => (needed, (env', cs, loc) :: constraints)
-                                         | _ => (needed, constraints)
+                                           L'.SgiConAbs (x, _, _) => (SS.add (neededC, x), constraints, neededV)
+                                         | L'.SgiConstraint cs => (neededC, (env', cs, loc) :: constraints, neededV)
+
+                                         | L'.SgiVal (x, _, t) =>
+                                           let
+                                               fun default () = (neededC, constraints, neededV)
+
+                                               val t = normClassConstraint env' t
+                                           in
+                                               case #1 t of
+                                                   L'.CApp (f, _) =>
+                                                   if E.isClass env' f then
+                                                       (neededC, constraints, SS.add (neededV, x))
+                                                   else
+                                                       default ()
+                                                       
+                                                 | _ => default ()
+                                           end
+
+                                         | _ => (neededC, constraints, neededV)
                                in
-                                   (needed, constraints, E.sgiBinds env' (sgi, loc))
+                                   (needed, constraints, neededV, E.sgiBinds env' (sgi, loc))
                                end)
-                           (SS.empty, [], env) sgis
+                           (SS.empty, [], SS.empty, env) sgis
                                                               
-                 val needed = foldl (fn ((d, _), needed) =>
-                                        case d of
-                                            L.DCon (x, _, _) => (SS.delete (needed, x)
-                                                                 handle NotFound =>
-                                                                        needed)
-                                          | L.DClass (x, _) => (SS.delete (needed, x)
-                                                                handle NotFound => needed)
-                                          | L.DOpen _ => SS.empty
-                                          | _ => needed)
-                                    needed ds
-
-                 val cds = List.mapPartial (fn (env', (c1, c2), loc) =>
+                 val (neededC, neededV) = foldl (fn ((d, _), needed as (neededC, neededV)) =>
+                                                    case d of
+                                                        L.DCon (x, _, _) => ((SS.delete (neededC, x), neededV)
+                                                                             handle NotFound =>
+                                                                                    needed)
+                                                      | L.DClass (x, _) => ((SS.delete (neededC, x), neededV)
+                                                                            handle NotFound => needed)
+                                                      | L.DVal (x, _, _) => ((neededC, SS.delete (neededV, x))
+                                                                             handle NotFound => needed)
+                                                      | L.DOpen _ => (SS.empty, SS.empty)
+                                                      | _ => needed)
+                                                (neededC, neededV) ds
+
+                 val ds' = List.mapPartial (fn (env', (c1, c2), loc) =>
                                                case (decompileCon env' c1, decompileCon env' c2) of
                                                    (SOME c1, SOME c2) =>
                                                    SOME (L.DConstraint (c1, c2), loc)
                                                  | _ => NONE) constraints
+
+                 val ds' =
+                     case SS.listItems neededV of
+                         [] => ds'
+                       | xs =>
+                         let
+                             val ewild = (L.EWild, #2 str)
+                             val ds'' = map (fn x => (L.DVal (x, NONE, ewild), #2 str)) xs
+                         in
+                             ds'' @ ds'
+                         end
+
+                 val ds' =
+                     case SS.listItems neededC of
+                         [] => ds'
+                       | xs =>
+                         let
+                             val kwild = (L.KWild, #2 str)
+                             val cwild = (L.CWild kwild, #2 str)
+                             val ds'' = map (fn x => (L.DCon (x, NONE, cwild), #2 str)) xs
+                         in
+                             ds'' @ ds'
+                         end
              in
-                 case SS.listItems needed of
-                     [] => (L.StrConst (ds @ cds), #2 str)
-                   | xs =>
-                     let
-                         val kwild = (L.KWild, #2 str)
-                         val cwild = (L.CWild kwild, #2 str)
-                         val ds' = map (fn x => (L.DCon (x, NONE, cwild), #2 str)) xs
-                     in
-                         (L.StrConst (ds @ ds' @ cds), #2 str)
-                     end
+                 (L.StrConst (ds @ ds'), #2 str)
              end
            | _ => str)
       | _ => str