changeset 205:cb8f69556975

Elaborating 'SELECT *' queries
author Adam Chlipala <adamc@hcoop.net>
date Thu, 14 Aug 2008 15:24:59 -0400
parents 241c9a0e3397
children cb8493759a7b
files src/elab.sml src/elab_env.sml src/elab_print.sml src/elab_util.sml src/elaborate.sml tests/table.lac
diffstat 6 files changed, 178 insertions(+), 70 deletions(-) [+]
line wrap: on
line diff
--- a/src/elab.sml	Thu Aug 14 13:59:11 2008 -0400
+++ b/src/elab.sml	Thu Aug 14 15:24:59 2008 -0400
@@ -119,7 +119,7 @@
        | SgiStr of string * int * sgn
        | SgiSgn of string * int * sgn
        | SgiConstraint of con * con
-       | SgiTable of string * int * con
+       | SgiTable of int * string * int * con
 
 and sgn' =
     SgnConst of sgn_item list
@@ -143,7 +143,7 @@
        | DFfiStr of string * int * sgn
        | DConstraint of con * con
        | DExport of int * sgn * str
-       | DTable of string * int * con
+       | DTable of int * string * int * con
 
      and str' =
          StrConst of decl list
--- a/src/elab_env.sml	Thu Aug 14 13:59:11 2008 -0400
+++ b/src/elab_env.sml	Thu Aug 14 15:24:59 2008 -0400
@@ -404,15 +404,12 @@
       | SgiSgn (x, n, sgn) => pushSgnNamedAs env x n sgn
       | SgiConstraint _ => env
 
-      | SgiTable (x, n, c) =>
-        (case lookupStr env "Basis" of
-             NONE => raise Fail "ElabEnv.sgiBinds: Can't find Basis"
-           | SOME (n, _) =>
-             let
-                 val t = (CApp ((CModProj (n, [], "table"), loc), c), loc)
-             in
-                 pushENamedAs env x n t
-             end)
+      | SgiTable (tn, x, n, c) =>
+        let
+            val t = (CApp ((CModProj (tn, [], "table"), loc), c), loc)
+        in
+            pushENamedAs env x n t
+        end
 
 fun sgnSeek f sgis =
     let
@@ -737,14 +734,11 @@
       | DFfiStr (x, n, sgn) => pushStrNamedAs env x n sgn
       | DConstraint _ => env
       | DExport _ => env
-      | DTable (x, n, c) =>
-        (case lookupStr env "Basis" of
-             NONE => raise Fail "ElabEnv.declBinds: Can't find Basis"
-           | SOME (n, _) =>
-             let
-                 val t = (CApp ((CModProj (n, [], "table"), loc), c), loc)
-             in
-                 pushENamedAs env x n t
-             end)
+      | DTable (tn, x, n, c) =>
+        let
+            val t = (CApp ((CModProj (tn, [], "table"), loc), c), loc)
+        in
+            pushENamedAs env x n t
+        end
 
 end
--- a/src/elab_print.sml	Thu Aug 14 13:59:11 2008 -0400
+++ b/src/elab_print.sml	Thu Aug 14 15:24:59 2008 -0400
@@ -447,13 +447,13 @@
                                        string "~",
                                        space,
                                        p_con env c2]
-      | SgiTable (x, n, c) => box [string "table",
-                                   space,
-                                   p_named x n,
-                                   space,
-                                   string ":",
-                                   space,
-                                   p_con env c]
+      | SgiTable (_, x, n, c) => box [string "table",
+                                      space,
+                                      p_named x n,
+                                      space,
+                                      string ":",
+                                      space,
+                                      p_con env c]
 
 and p_sgn env (sgn, _) =
     case sgn of
@@ -603,13 +603,13 @@
                                       string ":",
                                       space,
                                       p_sgn env sgn]
-      | DTable (x, n, c) => box [string "table",
-                                 space,
-                                 p_named x n,
-                                 space,
-                                 string ":",
-                                 space,
-                                 p_con env c]
+      | DTable (_, x, n, c) => box [string "table",
+                                    space,
+                                    p_named x n,
+                                    space,
+                                    string ":",
+                                    space,
+                                    p_con env c]
 
 and p_str env (str, _) =
     case str of
--- a/src/elab_util.sml	Thu Aug 14 13:59:11 2008 -0400
+++ b/src/elab_util.sml	Thu Aug 14 15:24:59 2008 -0400
@@ -436,10 +436,10 @@
                             S.map2 (con ctx c2,
                                     fn c2' =>
                                        (SgiConstraint (c1', c2'), loc)))
-              | SgiTable (x, n, c) =>
+              | SgiTable (tn, x, n, c) =>
                 S.map2 (con ctx c,
                         fn c' =>
-                           (SgiTable (x, n, c'), loc))
+                           (SgiTable (tn, x, n, c'), loc))
 
         and sg ctx s acc =
             S.bindP (sg' ctx s acc, sgn ctx)
@@ -600,7 +600,9 @@
                                                    bind (ctx, Str (x, sgn))
                                                  | DConstraint _ => ctx
                                                  | DExport _ => ctx
-                                                 | DTable _ => ctx,
+                                                 | DTable (tn, x, n, c) =>
+                                                   bind (ctx, NamedE (x, (CApp ((CModProj (n, [], "table"), loc),
+                                                                                c), loc))),
                                                mfd ctx d)) ctx ds,
                      fn ds' => (StrConst ds', loc))
               | StrVar _ => S.return2 strAll
@@ -688,10 +690,10 @@
                                     fn str' =>
                                        (DExport (en, sgn', str'), loc)))
 
-              | DTable (x, n, c) =>
+              | DTable (tn, x, n, c) =>
                 S.map2 (mfc ctx c,
                         fn c' =>
-                           (DTable (x, n, c'), loc))
+                           (DTable (tn, x, n, c'), loc))
 
         and mfvi ctx (x, n, c, e) =
             S.bind2 (mfc ctx c,
--- a/src/elaborate.sml	Thu Aug 14 13:59:11 2008 -0400
+++ b/src/elaborate.sml	Thu Aug 14 15:24:59 2008 -0400
@@ -558,6 +558,40 @@
 
 val hnormCon = D.hnormCon
 
+datatype con_summary =
+         Nil
+       | Cons
+       | Unknown
+
+fun compatible cs =
+    case cs of
+        (Unknown, _) => false
+      | (_, Unknown) => false
+      | (s1, s2) => s1 = s2
+
+fun summarizeCon (env, denv) c =
+    let
+        val (c, gs) = hnormCon (env, denv) c
+    in
+        case #1 c of
+            L'.CRecord (_, []) => (Nil, gs)
+          | L'.CRecord (_, _ :: _) => (Cons, gs)
+          | L'.CConcat ((L'.CRecord (_, _ :: _), _), _) => (Cons, gs)
+          | L'.CDisjoint (_, _, c) =>
+            let
+                val (s, gs') = summarizeCon (env, denv) c
+            in
+                (s, gs @ gs')
+            end
+          | _ => (Unknown, gs)
+    end
+
+fun p_con_summary s =
+    Print.PD.string (case s of
+                         Nil => "Nil"
+                       | Cons => "Cons"
+                       | Unknown => "Unknown")
+
 fun unifyRecordCons (env, denv) (c1, c2) =
     let
         fun rkindof c =
@@ -705,12 +739,77 @@
     let
         val (c1, gs1) = hnormCon (env, denv) c1
         val (c2, gs2) = hnormCon (env, denv) c2
-        val gs3 = unifyCons'' (env, denv) c1 c2
     in
-        gs1 @ gs2 @ gs3
+        let
+            val gs3 = unifyCons'' (env, denv) c1 c2
+        in
+            gs1 @ gs2 @ gs3
+        end
+        handle ex =>
+               let
+                   val loc = #2 c1
+
+                   fun unfold (dom, f, i, r, c) =
+                       let
+                           val nm = cunif (loc, (L'.KName, loc))
+                           val v = cunif (loc, dom)
+                           val rest = cunif (loc, (L'.KRecord dom, loc))
+
+                           val (iS, gs3) = summarizeCon (env, denv) i
+
+                           val app = (L'.CApp (f, nm), loc)
+                           val app = (L'.CApp (app, v), loc)
+                           val app = (L'.CApp (app, rest), loc)
+                           val (appS, gs4) = summarizeCon (env, denv) app
+
+                           val (cS, gs5) = summarizeCon (env, denv) c
+                       in
+                           (*prefaces "Summaries" [("iS", p_con_summary iS),
+                                                 ("appS", p_con_summary appS),
+                                                 ("cS", p_con_summary cS)];*)
+
+                           if compatible (iS, appS) then
+                               raise ex
+                           else if compatible (cS, iS) then
+                               let
+                                   (*val () = prefaces "Same?" [("i", p_con env i),
+                                                              ("c", p_con env c)]*)
+                                   val gs6 = unifyCons (env, denv) i c
+                                   (*val () = TextIO.print "Yes!\n"*)
+
+                                   val gs7 = unifyCons (env, denv) r (L'.CRecord (dom, []), loc)
+                               in
+                                   gs1 @ gs2 @ gs3 @ gs4 @ gs5 @ gs6 @ gs7
+                               end
+                           else if compatible (cS, appS) then
+                               let
+                                  (*val () = prefaces "Same?" [("app", p_con env app),
+                                                             ("c", p_con env c),
+                                                             ("app'", p_con env (#1 (hnormCon (env, denv) app)))]*)
+                                  val gs6 = unifyCons (env, denv) app c
+                                  (*val () = TextIO.print "Yes!\n"*)
+
+                                  val singleton = (L'.CRecord (dom, [(nm, v)]), loc)
+                                  val concat = (L'.CConcat (singleton, rest), loc)
+                                  val gs7 = unifyCons (env, denv) r concat
+                              in
+                                  (loc, env, denv, singleton, rest) :: gs1 @ gs2 @ gs3 @ gs4 @ gs5 @ gs6 @ gs7
+                              end
+                           else
+                             raise ex
+                       end
+                       handle _ => raise ex
+               in
+                   case (#1 c1, #1 c2) of
+                       (L'.CApp ((L'.CApp ((L'.CApp ((L'.CFold (dom, _), _), f), _), i), _), r), _) =>
+                       unfold (dom, f, i, r, c2)
+                     | (_, L'.CApp ((L'.CApp ((L'.CApp ((L'.CFold (dom, _), _), f), _), i), _), r)) =>
+                       unfold (dom, f, i, r, c1)
+                     | _ => raise ex
+               end
     end
     
-and unifyCons'' (env, denv) (c1All as (c1, _)) (c2All as (c2, _)) =
+and unifyCons'' (env, denv) (c1All as (c1, loc)) (c2All as (c2, _)) =
     let
         fun err f = raise CUnify' (f (c1All, c2All))
 
@@ -794,12 +893,12 @@
                 (r := SOME c1All;
                  [])
 
-
           | (L'.CFold (dom1, ran1), L'.CFold (dom2, ran2)) =>
             (unifyKinds dom1 dom2;
              unifyKinds ran1 ran2;
              [])
 
+
           | _ => err CIncompatible
     end
 
@@ -1264,8 +1363,10 @@
 
                 val gs4 = checkCon (env, denv) e1' t1 t
                 val gs5 = checkCon (env, denv) e2' t2 dom
+
+                val gs = gs1 @ gs2 @ gs3 @ gs4 @ gs5
             in
-                ((L'.EApp (e1', e2'), loc), ran, gs1 @ gs2 @ gs3 @ gs4 @ gs5)
+                ((L'.EApp (e1', e2'), loc), ran, gs)
             end
           | L.EAbs (x, to, e) =>
             let
@@ -1571,6 +1672,13 @@
 
 val hnormSgn = E.hnormSgn
 
+fun tableOf' env =
+    case E.lookupStr env "Basis" of
+        NONE => raise Fail "Elaborate.tableOf: Can't find Basis"
+      | SOME (n, _) => n
+
+fun tableOf env = (L'.CModProj (tableOf' env, [], "sql_table"), ErrorMsg.dummySpan)
+
 fun elabSgn_item ((sgi, loc), (env, denv, gs)) =
     case sgi of
         L.SgiConAbs (x, k) =>
@@ -1739,10 +1847,10 @@
       | L.SgiTable (x, c) =>
         let
             val (c', k, gs) = elabCon (env, denv) c
-            val (env, n) = E.pushENamed env x c'
+            val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc)
         in
             checkKind env c' k (L'.KRecord (L'.KType, loc), loc);
-            ([(L'.SgiTable (x, n, c'), loc)], (env, denv, gs))
+            ([(L'.SgiTable (tableOf' env, x, n, c'), loc)], (env, denv, gs))
         end
 
 and elabSgn (env, denv) (sgn, loc) =
@@ -1806,7 +1914,7 @@
                                        ();
                                    (cons, vals, sgns, SS.add (strs, x)))
                                 | L'.SgiConstraint _ => (cons, vals, sgns, strs)
-                                | L'.SgiTable (x, _, _) =>
+                                | L'.SgiTable (_, x, _, _) =>
                                   (if SS.member (vals, x) then
                                        sgnError env (DuplicateVal (loc, x))
                                    else
@@ -1910,11 +2018,6 @@
           | SOME (str, strs) => selfify env {sgn = sgn, str = str, strs = strs}
     end
 
-fun tableOf env =
-    case E.lookupStr env "Basis" of
-        NONE => raise Fail "Elaborate.tableOf: Can't find Basis"
-      | SOME (n, _) => (L'.CModProj (n, [], "sql_table"), ErrorMsg.dummySpan)
-
 fun dopen (env, denv) {str, strs, sgn} =
     let
         val m = foldl (fn (m, str) => (L'.StrProj (str, m), #2 sgn))
@@ -1946,7 +2049,7 @@
                                               (L'.DSgn (x, n, (L'.SgnProj (str, strs, x), loc)), loc)
                                             | L'.SgiConstraint (c1, c2) =>
                                               (L'.DConstraint (c1, c2), loc)
-                                            | L'.SgiTable (x, n, c) =>
+                                            | L'.SgiTable (_, x, n, c) =>
                                               (L'.DVal (x, n, (L'.CApp (tableOf env, c), loc),
                                                         (L'.EModProj (str, strs, x), loc)), loc)
                                   in
@@ -2001,7 +2104,7 @@
       | L'.DFfiStr (x, n, sgn) => [(L'.SgiStr (x, n, sgn), loc)]
       | L'.DConstraint cs => [(L'.SgiConstraint cs, loc)]
       | L'.DExport _ => []
-      | L'.DTable (x, n, c) => [(L'.SgiTable (x, n, c), loc)]
+      | L'.DTable (tn, x, n, c) => [(L'.SgiTable (tn, x, n, c), loc)]
 
 fun sgiBindsD (env, denv) (sgi, _) =
     case sgi of
@@ -2194,7 +2297,7 @@
                                                  SOME (env, denv))
                                      else
                                          NONE
-                                   | L'.SgiTable (x', n1, c1) =>
+                                   | L'.SgiTable (_, x', n1, c1) =>
                                      if x = x' then
                                          (case unifyCons (env, denv) (L'.CApp (tableOf env, c1), loc) c2 of
                                               [] => SOME (env, denv)
@@ -2266,10 +2369,10 @@
                                          NONE
                                    | _ => NONE)
 
-                      | L'.SgiTable (x, n2, c2) =>
+                      | L'.SgiTable (_, x, n2, c2) =>
                         seek (fn sgi1All as (sgi1, _) =>
                                  case sgi1 of
-                                     L'.SgiTable (x', n1, c1) =>
+                                     L'.SgiTable (_, x', n1, c1) =>
                                      if x = x' then
                                          (case unifyCons (env, denv) c1 c2 of
                                               [] => SOME (env, denv)
@@ -2541,7 +2644,7 @@
       | L.DOpen (m, ms) =>
         (case E.lookupStr env m of
              NONE => (strError env (UnboundStr (loc, m));
-                      ([], (env, denv, [])))
+                      ([], (env, denv, gs)))
            | SOME (n, sgn) =>
              let
                  val (_, sgn) = foldl (fn (m, (str, sgn)) =>
@@ -2554,7 +2657,7 @@
                  val (ds, (env', denv')) = dopen (env, denv) {str = n, strs = ms, sgn = sgn}
                  val denv' = dopenConstraints (loc, env', denv') {str = m, strs = ms}
              in
-                 (ds, (env', denv', []))
+                 (ds, (env', denv', gs))
              end)
 
       | L.DConstraint (c1, c2) =>
@@ -2568,19 +2671,19 @@
             checkKind env c1' k1 (L'.KRecord (kunif loc), loc);
             checkKind env c2' k2 (L'.KRecord (kunif loc), loc);
 
-            ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3 @ gs4))
+            ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3 @ gs4 @ gs))
         end
 
       | L.DOpenConstraints (m, ms) =>
         let
             val denv = dopenConstraints (loc, env, denv) {str = m, strs = ms}
         in
-            ([], (env, denv, []))
+            ([], (env, denv, gs))
         end
 
       | L.DExport str =>
         let
-            val (str', sgn, gs) = elabStr (env, denv) str
+            val (str', sgn, gs') = elabStr (env, denv) str
 
             val sgn =
                 case #1 (hnormSgn env sgn) of
@@ -2626,16 +2729,16 @@
                     end
                   | _ => sgn
         in
-            ([(L'.DExport (E.newNamed (), sgn, str'), loc)], (env, denv, gs))
+            ([(L'.DExport (E.newNamed (), sgn, str'), loc)], (env, denv, gs' @ gs))
         end
 
       | L.DTable (x, c) =>
         let
-            val (c', k, gs) = elabCon (env, denv) c
-            val (env, n) = E.pushENamed env x c'
+            val (c', k, gs') = elabCon (env, denv) c
+            val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc)
         in
             checkKind env c' k (L'.KRecord (L'.KType, loc), loc);
-            ([(L'.DTable (x, n, c'), loc)], (env, denv, gs))
+            ([(L'.DTable (tableOf' env, x, n, c'), loc)], (env, denv, gs' @ gs))
         end
 
 and elabStr (env, denv) (str, loc) =
@@ -2729,7 +2832,7 @@
                                   ((L'.SgiStr (x, n, sgn), loc) :: sgis, cons, vals, sgns, strs)
                               end
                             | L'.SgiConstraint _ => ((sgi, loc) :: sgis, cons, vals, sgns, strs)
-                            | L'.SgiTable (x, n, c) =>
+                            | L'.SgiTable (tn, x, n, c) =>
                               let
                                   val (vals, x) =
                                       if SS.member (vals, x) then
@@ -2737,7 +2840,7 @@
                                       else
                                           (SS.add (vals, x), x)
                               in
-                                  ((L'.SgiTable (x, n, c), loc) :: sgis, cons, vals, sgns, strs)
+                                  ((L'.SgiTable (tn, x, n, c), loc) :: sgis, cons, vals, sgns, strs)
                               end)
 
                 ([], SS.empty, SS.empty, SS.empty, SS.empty) sgis
--- a/tests/table.lac	Thu Aug 14 13:59:11 2008 -0400
+++ b/tests/table.lac	Thu Aug 14 15:24:59 2008 -0400
@@ -1,3 +1,12 @@
-table t : {A : int, B : string, C : float}
+table t1 : {A : int, B : string, C : float}
 
-val my_query = (SELECT * FROM t)
+val q1 = (SELECT * FROM t1)
+
+table t2 : {A : float, D : int}
+
+val q2 = (SELECT * FROM t1, t2)
+
+(*val q3 = (SELECT * FROM t1, t1)*)
+val q3 = (SELECT * FROM t1, t1 AS T2)
+
+val q4 = (SELECT * FROM {t1} AS T, t1 AS T2)