diff src/elaborate.sml @ 205:cb8f69556975

Elaborating 'SELECT *' queries
author Adam Chlipala <adamc@hcoop.net>
date Thu, 14 Aug 2008 15:24:59 -0400
parents dd82457fda82
children cb8493759a7b
line wrap: on
line diff
--- a/src/elaborate.sml	Thu Aug 14 13:59:11 2008 -0400
+++ b/src/elaborate.sml	Thu Aug 14 15:24:59 2008 -0400
@@ -558,6 +558,40 @@
 
 val hnormCon = D.hnormCon
 
+datatype con_summary =
+         Nil
+       | Cons
+       | Unknown
+
+fun compatible cs =
+    case cs of
+        (Unknown, _) => false
+      | (_, Unknown) => false
+      | (s1, s2) => s1 = s2
+
+fun summarizeCon (env, denv) c =
+    let
+        val (c, gs) = hnormCon (env, denv) c
+    in
+        case #1 c of
+            L'.CRecord (_, []) => (Nil, gs)
+          | L'.CRecord (_, _ :: _) => (Cons, gs)
+          | L'.CConcat ((L'.CRecord (_, _ :: _), _), _) => (Cons, gs)
+          | L'.CDisjoint (_, _, c) =>
+            let
+                val (s, gs') = summarizeCon (env, denv) c
+            in
+                (s, gs @ gs')
+            end
+          | _ => (Unknown, gs)
+    end
+
+fun p_con_summary s =
+    Print.PD.string (case s of
+                         Nil => "Nil"
+                       | Cons => "Cons"
+                       | Unknown => "Unknown")
+
 fun unifyRecordCons (env, denv) (c1, c2) =
     let
         fun rkindof c =
@@ -705,12 +739,77 @@
     let
         val (c1, gs1) = hnormCon (env, denv) c1
         val (c2, gs2) = hnormCon (env, denv) c2
-        val gs3 = unifyCons'' (env, denv) c1 c2
     in
-        gs1 @ gs2 @ gs3
+        let
+            val gs3 = unifyCons'' (env, denv) c1 c2
+        in
+            gs1 @ gs2 @ gs3
+        end
+        handle ex =>
+               let
+                   val loc = #2 c1
+
+                   fun unfold (dom, f, i, r, c) =
+                       let
+                           val nm = cunif (loc, (L'.KName, loc))
+                           val v = cunif (loc, dom)
+                           val rest = cunif (loc, (L'.KRecord dom, loc))
+
+                           val (iS, gs3) = summarizeCon (env, denv) i
+
+                           val app = (L'.CApp (f, nm), loc)
+                           val app = (L'.CApp (app, v), loc)
+                           val app = (L'.CApp (app, rest), loc)
+                           val (appS, gs4) = summarizeCon (env, denv) app
+
+                           val (cS, gs5) = summarizeCon (env, denv) c
+                       in
+                           (*prefaces "Summaries" [("iS", p_con_summary iS),
+                                                 ("appS", p_con_summary appS),
+                                                 ("cS", p_con_summary cS)];*)
+
+                           if compatible (iS, appS) then
+                               raise ex
+                           else if compatible (cS, iS) then
+                               let
+                                   (*val () = prefaces "Same?" [("i", p_con env i),
+                                                              ("c", p_con env c)]*)
+                                   val gs6 = unifyCons (env, denv) i c
+                                   (*val () = TextIO.print "Yes!\n"*)
+
+                                   val gs7 = unifyCons (env, denv) r (L'.CRecord (dom, []), loc)
+                               in
+                                   gs1 @ gs2 @ gs3 @ gs4 @ gs5 @ gs6 @ gs7
+                               end
+                           else if compatible (cS, appS) then
+                               let
+                                  (*val () = prefaces "Same?" [("app", p_con env app),
+                                                             ("c", p_con env c),
+                                                             ("app'", p_con env (#1 (hnormCon (env, denv) app)))]*)
+                                  val gs6 = unifyCons (env, denv) app c
+                                  (*val () = TextIO.print "Yes!\n"*)
+
+                                  val singleton = (L'.CRecord (dom, [(nm, v)]), loc)
+                                  val concat = (L'.CConcat (singleton, rest), loc)
+                                  val gs7 = unifyCons (env, denv) r concat
+                              in
+                                  (loc, env, denv, singleton, rest) :: gs1 @ gs2 @ gs3 @ gs4 @ gs5 @ gs6 @ gs7
+                              end
+                           else
+                             raise ex
+                       end
+                       handle _ => raise ex
+               in
+                   case (#1 c1, #1 c2) of
+                       (L'.CApp ((L'.CApp ((L'.CApp ((L'.CFold (dom, _), _), f), _), i), _), r), _) =>
+                       unfold (dom, f, i, r, c2)
+                     | (_, L'.CApp ((L'.CApp ((L'.CApp ((L'.CFold (dom, _), _), f), _), i), _), r)) =>
+                       unfold (dom, f, i, r, c1)
+                     | _ => raise ex
+               end
     end
     
-and unifyCons'' (env, denv) (c1All as (c1, _)) (c2All as (c2, _)) =
+and unifyCons'' (env, denv) (c1All as (c1, loc)) (c2All as (c2, _)) =
     let
         fun err f = raise CUnify' (f (c1All, c2All))
 
@@ -794,12 +893,12 @@
                 (r := SOME c1All;
                  [])
 
-
           | (L'.CFold (dom1, ran1), L'.CFold (dom2, ran2)) =>
             (unifyKinds dom1 dom2;
              unifyKinds ran1 ran2;
              [])
 
+
           | _ => err CIncompatible
     end
 
@@ -1264,8 +1363,10 @@
 
                 val gs4 = checkCon (env, denv) e1' t1 t
                 val gs5 = checkCon (env, denv) e2' t2 dom
+
+                val gs = gs1 @ gs2 @ gs3 @ gs4 @ gs5
             in
-                ((L'.EApp (e1', e2'), loc), ran, gs1 @ gs2 @ gs3 @ gs4 @ gs5)
+                ((L'.EApp (e1', e2'), loc), ran, gs)
             end
           | L.EAbs (x, to, e) =>
             let
@@ -1571,6 +1672,13 @@
 
 val hnormSgn = E.hnormSgn
 
+fun tableOf' env =
+    case E.lookupStr env "Basis" of
+        NONE => raise Fail "Elaborate.tableOf: Can't find Basis"
+      | SOME (n, _) => n
+
+fun tableOf env = (L'.CModProj (tableOf' env, [], "sql_table"), ErrorMsg.dummySpan)
+
 fun elabSgn_item ((sgi, loc), (env, denv, gs)) =
     case sgi of
         L.SgiConAbs (x, k) =>
@@ -1739,10 +1847,10 @@
       | L.SgiTable (x, c) =>
         let
             val (c', k, gs) = elabCon (env, denv) c
-            val (env, n) = E.pushENamed env x c'
+            val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc)
         in
             checkKind env c' k (L'.KRecord (L'.KType, loc), loc);
-            ([(L'.SgiTable (x, n, c'), loc)], (env, denv, gs))
+            ([(L'.SgiTable (tableOf' env, x, n, c'), loc)], (env, denv, gs))
         end
 
 and elabSgn (env, denv) (sgn, loc) =
@@ -1806,7 +1914,7 @@
                                        ();
                                    (cons, vals, sgns, SS.add (strs, x)))
                                 | L'.SgiConstraint _ => (cons, vals, sgns, strs)
-                                | L'.SgiTable (x, _, _) =>
+                                | L'.SgiTable (_, x, _, _) =>
                                   (if SS.member (vals, x) then
                                        sgnError env (DuplicateVal (loc, x))
                                    else
@@ -1910,11 +2018,6 @@
           | SOME (str, strs) => selfify env {sgn = sgn, str = str, strs = strs}
     end
 
-fun tableOf env =
-    case E.lookupStr env "Basis" of
-        NONE => raise Fail "Elaborate.tableOf: Can't find Basis"
-      | SOME (n, _) => (L'.CModProj (n, [], "sql_table"), ErrorMsg.dummySpan)
-
 fun dopen (env, denv) {str, strs, sgn} =
     let
         val m = foldl (fn (m, str) => (L'.StrProj (str, m), #2 sgn))
@@ -1946,7 +2049,7 @@
                                               (L'.DSgn (x, n, (L'.SgnProj (str, strs, x), loc)), loc)
                                             | L'.SgiConstraint (c1, c2) =>
                                               (L'.DConstraint (c1, c2), loc)
-                                            | L'.SgiTable (x, n, c) =>
+                                            | L'.SgiTable (_, x, n, c) =>
                                               (L'.DVal (x, n, (L'.CApp (tableOf env, c), loc),
                                                         (L'.EModProj (str, strs, x), loc)), loc)
                                   in
@@ -2001,7 +2104,7 @@
       | L'.DFfiStr (x, n, sgn) => [(L'.SgiStr (x, n, sgn), loc)]
       | L'.DConstraint cs => [(L'.SgiConstraint cs, loc)]
       | L'.DExport _ => []
-      | L'.DTable (x, n, c) => [(L'.SgiTable (x, n, c), loc)]
+      | L'.DTable (tn, x, n, c) => [(L'.SgiTable (tn, x, n, c), loc)]
 
 fun sgiBindsD (env, denv) (sgi, _) =
     case sgi of
@@ -2194,7 +2297,7 @@
                                                  SOME (env, denv))
                                      else
                                          NONE
-                                   | L'.SgiTable (x', n1, c1) =>
+                                   | L'.SgiTable (_, x', n1, c1) =>
                                      if x = x' then
                                          (case unifyCons (env, denv) (L'.CApp (tableOf env, c1), loc) c2 of
                                               [] => SOME (env, denv)
@@ -2266,10 +2369,10 @@
                                          NONE
                                    | _ => NONE)
 
-                      | L'.SgiTable (x, n2, c2) =>
+                      | L'.SgiTable (_, x, n2, c2) =>
                         seek (fn sgi1All as (sgi1, _) =>
                                  case sgi1 of
-                                     L'.SgiTable (x', n1, c1) =>
+                                     L'.SgiTable (_, x', n1, c1) =>
                                      if x = x' then
                                          (case unifyCons (env, denv) c1 c2 of
                                               [] => SOME (env, denv)
@@ -2541,7 +2644,7 @@
       | L.DOpen (m, ms) =>
         (case E.lookupStr env m of
              NONE => (strError env (UnboundStr (loc, m));
-                      ([], (env, denv, [])))
+                      ([], (env, denv, gs)))
            | SOME (n, sgn) =>
              let
                  val (_, sgn) = foldl (fn (m, (str, sgn)) =>
@@ -2554,7 +2657,7 @@
                  val (ds, (env', denv')) = dopen (env, denv) {str = n, strs = ms, sgn = sgn}
                  val denv' = dopenConstraints (loc, env', denv') {str = m, strs = ms}
              in
-                 (ds, (env', denv', []))
+                 (ds, (env', denv', gs))
              end)
 
       | L.DConstraint (c1, c2) =>
@@ -2568,19 +2671,19 @@
             checkKind env c1' k1 (L'.KRecord (kunif loc), loc);
             checkKind env c2' k2 (L'.KRecord (kunif loc), loc);
 
-            ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3 @ gs4))
+            ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3 @ gs4 @ gs))
         end
 
       | L.DOpenConstraints (m, ms) =>
         let
             val denv = dopenConstraints (loc, env, denv) {str = m, strs = ms}
         in
-            ([], (env, denv, []))
+            ([], (env, denv, gs))
         end
 
       | L.DExport str =>
         let
-            val (str', sgn, gs) = elabStr (env, denv) str
+            val (str', sgn, gs') = elabStr (env, denv) str
 
             val sgn =
                 case #1 (hnormSgn env sgn) of
@@ -2626,16 +2729,16 @@
                     end
                   | _ => sgn
         in
-            ([(L'.DExport (E.newNamed (), sgn, str'), loc)], (env, denv, gs))
+            ([(L'.DExport (E.newNamed (), sgn, str'), loc)], (env, denv, gs' @ gs))
         end
 
       | L.DTable (x, c) =>
         let
-            val (c', k, gs) = elabCon (env, denv) c
-            val (env, n) = E.pushENamed env x c'
+            val (c', k, gs') = elabCon (env, denv) c
+            val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc)
         in
             checkKind env c' k (L'.KRecord (L'.KType, loc), loc);
-            ([(L'.DTable (x, n, c'), loc)], (env, denv, gs))
+            ([(L'.DTable (tableOf' env, x, n, c'), loc)], (env, denv, gs' @ gs))
         end
 
 and elabStr (env, denv) (str, loc) =
@@ -2729,7 +2832,7 @@
                                   ((L'.SgiStr (x, n, sgn), loc) :: sgis, cons, vals, sgns, strs)
                               end
                             | L'.SgiConstraint _ => ((sgi, loc) :: sgis, cons, vals, sgns, strs)
-                            | L'.SgiTable (x, n, c) =>
+                            | L'.SgiTable (tn, x, n, c) =>
                               let
                                   val (vals, x) =
                                       if SS.member (vals, x) then
@@ -2737,7 +2840,7 @@
                                       else
                                           (SS.add (vals, x), x)
                               in
-                                  ((L'.SgiTable (x, n, c), loc) :: sgis, cons, vals, sgns, strs)
+                                  ((L'.SgiTable (tn, x, n, c), loc) :: sgis, cons, vals, sgns, strs)
                               end)
 
                 ([], SS.empty, SS.empty, SS.empty, SS.empty) sgis