diff src/elaborate.sml @ 210:f4033abd6ab1

Inferring sql_type's
author Adam Chlipala <adamc@hcoop.net>
date Sat, 16 Aug 2008 12:35:46 -0400
parents 1487c712eb12
children e86411f647c6
line wrap: on
line diff
--- a/src/elaborate.sml	Sat Aug 16 12:15:38 2008 -0400
+++ b/src/elaborate.sml	Sat Aug 16 12:35:46 2008 -0400
@@ -47,6 +47,8 @@
 structure SS = BinarySetFn(SK)
 structure SM = BinaryMapFn(SK)
 
+val basis_r = ref 0
+
 fun elabExplicitness e =
     case e of
         L.Explicit => L'.Explicit
@@ -862,9 +864,7 @@
     
 and unifyCons'' (env, denv) (c1All as (c1, loc)) (c2All as (c2, _)) =
     let
-        fun err f = (prefaces "unifyCons'' fails" [("c1All", p_con env c1All),
-                                                   ("c2All", p_con env c2All)];
-                     raise CUnify' (f (c1All, c2All)))
+        fun err f = raise CUnify' (f (c1All, c2All))
 
         fun isRecord () = unifyRecordCons (env, denv) (c1All, c2All)
     in
@@ -985,6 +985,7 @@
      | PatHasNoArg of ErrorMsg.span
      | Inexhaustive of ErrorMsg.span
      | DuplicatePatField of ErrorMsg.span * string
+     | SqlInfer of ErrorMsg.span * L'.con
 
 fun expError env err =
     case err of
@@ -1027,7 +1028,10 @@
         ErrorMsg.errorAt loc "Inexhaustive 'case'"
       | DuplicatePatField (loc, s) =>
         ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern")
-
+      | SqlInfer (loc, c) =>
+        (ErrorMsg.errorAt loc "Can't infer SQL-ness of type";
+         eprefaces' [("Type", p_con env c)])
+         
 fun checkCon (env, denv) e c1 c2 =
     unifyCons (env, denv) c1 c2
     handle CUnify (c1, c2, err) =>
@@ -1415,6 +1419,51 @@
                      ((L'.EModProj (n, ms, s), loc), t, [])
                  end)
 
+          | L.EApp (e1, (L.ESqlInfer, _)) =>
+            let
+                val (e1', t1, gs1) = elabExp (env, denv) e1
+                val (e1', t1, gs2) = elabHead (env, denv) e1' t1
+                val (t1, gs3) = hnormCon (env, denv) t1
+            in
+                case t1 of
+                    (L'.TFun ((L'.CApp ((L'.CModProj (basis, [], "sql_type"), _),
+                                        t), _), ran), _) =>
+                    if basis <> !basis_r then
+                        raise Fail "Bad use of ESqlInfer [1]"
+                    else
+                        let
+                            val (t, gs4) = hnormCon (env, denv) t
+
+                            fun error () = expError env (SqlInfer (loc, t))
+                        in
+                            case t of
+                                (L'.CModProj (basis, [], x), _) =>
+                                (if basis <> !basis_r then
+                                     error ()
+                                 else
+                                     case x of
+                                         "bool" => ()
+                                       | "int" => ()
+                                       | "float" => ()
+                                       | "string" => ()
+                                       | _ => error ();
+                                 ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_" ^ x), loc)), loc),
+                                  ran, gs1 @ gs2 @ gs3 @ gs4))
+                              | (L'.CUnif (_, (L'.KType, _), _, r), _) =>
+                                let
+                                    val t = (L'.CModProj (basis, [], "int"), loc)
+                                in
+                                    r := SOME t;
+                                    ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_int"), loc)), loc),
+                                     ran, gs1 @ gs2 @ gs3 @ gs4)
+                                end
+                              | _ => (error ();
+                                      (eerror, cerror, []))
+                        end
+                  | _ => raise Fail "Bad use of ESqlInfer [2]"
+            end
+          | L.ESqlInfer => raise Fail "Bad use of ESqlInfer [3]"
+
           | L.EApp (e1, e2) =>
             let
                 val (e1', t1, gs1) = elabExp (env, denv) e1
@@ -1736,12 +1785,7 @@
 
 val hnormSgn = E.hnormSgn
 
-fun tableOf' env =
-    case E.lookupStr env "Basis" of
-        NONE => raise Fail "Elaborate.tableOf: Can't find Basis"
-      | SOME (n, _) => n
-
-fun tableOf env = (L'.CModProj (tableOf' env, [], "sql_table"), ErrorMsg.dummySpan)
+fun tableOf () = (L'.CModProj (!basis_r, [], "sql_table"), ErrorMsg.dummySpan)
 
 fun elabSgn_item ((sgi, loc), (env, denv, gs)) =
     case sgi of
@@ -1911,10 +1955,10 @@
       | L.SgiTable (x, c) =>
         let
             val (c', k, gs) = elabCon (env, denv) c
-            val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc)
+            val (env, n) = E.pushENamed env x (L'.CApp (tableOf (), c'), loc)
         in
             checkKind env c' k (L'.KRecord (L'.KType, loc), loc);
-            ([(L'.SgiTable (tableOf' env, x, n, c'), loc)], (env, denv, gs))
+            ([(L'.SgiTable (!basis_r, x, n, c'), loc)], (env, denv, gs))
         end
 
 and elabSgn (env, denv) (sgn, loc) =
@@ -2114,7 +2158,7 @@
                                             | L'.SgiConstraint (c1, c2) =>
                                               (L'.DConstraint (c1, c2), loc)
                                             | L'.SgiTable (_, x, n, c) =>
-                                              (L'.DVal (x, n, (L'.CApp (tableOf env, c), loc),
+                                              (L'.DVal (x, n, (L'.CApp (tableOf (), c), loc),
                                                         (L'.EModProj (str, strs, x), loc)), loc)
                                   in
                                       (d, (E.declBinds env' d, denv'))
@@ -2363,7 +2407,7 @@
                                          NONE
                                    | L'.SgiTable (_, x', n1, c1) =>
                                      if x = x' then
-                                         (case unifyCons (env, denv) (L'.CApp (tableOf env, c1), loc) c2 of
+                                         (case unifyCons (env, denv) (L'.CApp (tableOf (), c1), loc) c2 of
                                               [] => SOME (env, denv)
                                             | _ => NONE)
                                          handle CUnify (c1, c2, err) =>
@@ -2799,10 +2843,10 @@
       | L.DTable (x, c) =>
         let
             val (c', k, gs') = elabCon (env, denv) c
-            val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc)
+            val (env, n) = E.pushENamed env x (L'.CApp (tableOf (), c'), loc)
         in
             checkKind env c' k (L'.KRecord (L'.KType, loc), loc);
-            ([(L'.DTable (tableOf' env, x, n, c'), loc)], (env, denv, gs' @ gs))
+            ([(L'.DTable (!basis_r, x, n, c'), loc)], (env, denv, gs' @ gs))
         end
 
 and elabStr (env, denv) (str, loc) =
@@ -2979,6 +3023,7 @@
                            raise Fail "Unresolved disjointness constraints in Basis")
 
         val (env', basis_n) = E.pushStrNamed env "Basis" sgn
+        val () = basis_r := basis_n
 
         val (ds, (env', _)) = dopen (env', D.empty) {str = basis_n, strs = [], sgn = sgn}