diff src/elaborate.sml @ 211:e86411f647c6

Initial type class support
author Adam Chlipala <adamc@hcoop.net>
date Sat, 16 Aug 2008 14:32:18 -0400
parents f4033abd6ab1
children 0343557355fc
line wrap: on
line diff
--- a/src/elaborate.sml	Sat Aug 16 12:35:46 2008 -0400
+++ b/src/elaborate.sml	Sat Aug 16 14:32:18 2008 -0400
@@ -985,7 +985,8 @@
      | PatHasNoArg of ErrorMsg.span
      | Inexhaustive of ErrorMsg.span
      | DuplicatePatField of ErrorMsg.span * string
-     | SqlInfer of ErrorMsg.span * L'.con
+     | Unresolvable of ErrorMsg.span * L'.con
+     | OutOfContext of ErrorMsg.span
 
 fun expError env err =
     case err of
@@ -1028,9 +1029,11 @@
         ErrorMsg.errorAt loc "Inexhaustive 'case'"
       | DuplicatePatField (loc, s) =>
         ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern")
-      | SqlInfer (loc, c) =>
-        (ErrorMsg.errorAt loc "Can't infer SQL-ness of type";
-         eprefaces' [("Type", p_con env c)])
+      | OutOfContext loc =>
+        ErrorMsg.errorAt loc "Type class wildcard occurs out of context"
+      | Unresolvable (loc, c) =>
+        (ErrorMsg.errorAt loc "Can't resolve type class instance";
+         eprefaces' [("Class constraint", p_con env c)])
          
 fun checkCon (env, denv) e c1 c2 =
     unifyCons (env, denv) c1 c2
@@ -1419,50 +1422,23 @@
                      ((L'.EModProj (n, ms, s), loc), t, [])
                  end)
 
-          | L.EApp (e1, (L.ESqlInfer, _)) =>
+          | L.EApp (e1, (L.EWild, _)) =>
             let
                 val (e1', t1, gs1) = elabExp (env, denv) e1
                 val (e1', t1, gs2) = elabHead (env, denv) e1' t1
                 val (t1, gs3) = hnormCon (env, denv) t1
             in
                 case t1 of
-                    (L'.TFun ((L'.CApp ((L'.CModProj (basis, [], "sql_type"), _),
-                                        t), _), ran), _) =>
-                    if basis <> !basis_r then
-                        raise Fail "Bad use of ESqlInfer [1]"
-                    else
-                        let
-                            val (t, gs4) = hnormCon (env, denv) t
-
-                            fun error () = expError env (SqlInfer (loc, t))
-                        in
-                            case t of
-                                (L'.CModProj (basis, [], x), _) =>
-                                (if basis <> !basis_r then
-                                     error ()
-                                 else
-                                     case x of
-                                         "bool" => ()
-                                       | "int" => ()
-                                       | "float" => ()
-                                       | "string" => ()
-                                       | _ => error ();
-                                 ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_" ^ x), loc)), loc),
-                                  ran, gs1 @ gs2 @ gs3 @ gs4))
-                              | (L'.CUnif (_, (L'.KType, _), _, r), _) =>
-                                let
-                                    val t = (L'.CModProj (basis, [], "int"), loc)
-                                in
-                                    r := SOME t;
-                                    ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_int"), loc)), loc),
-                                     ran, gs1 @ gs2 @ gs3 @ gs4)
-                                end
-                              | _ => (error ();
-                                      (eerror, cerror, []))
-                        end
-                  | _ => raise Fail "Bad use of ESqlInfer [2]"
+                    (L'.TFun (dom, ran), _) =>
+                    (case E.resolveClass env dom of
+                         NONE => (expError env (Unresolvable (loc, dom));
+                                  (eerror, cerror, []))
+                       | SOME pf => ((L'.EApp (e1', pf), loc), ran, gs1 @ gs2 @ gs3))
+                  | _ => (expError env (OutOfContext loc);
+                          (eerror, cerror, []))
             end
-          | L.ESqlInfer => raise Fail "Bad use of ESqlInfer [3]"
+          | L.EWild => (expError env (OutOfContext loc);
+                        (eerror, cerror, []))
 
           | L.EApp (e1, e2) =>
             let
@@ -1961,6 +1937,26 @@
             ([(L'.SgiTable (!basis_r, x, n, c'), loc)], (env, denv, gs))
         end
 
+      | L.SgiClassAbs x =>
+        let
+            val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
+            val (env, n) = E.pushCNamed env x k NONE
+            val env = E.pushClass env n
+        in
+            ([(L'.SgiClassAbs (x, n), loc)], (env, denv, []))
+        end
+
+      | L.SgiClass (x, c) =>
+        let
+            val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
+            val (c', ck, gs) = elabCon (env, denv) c
+            val (env, n) = E.pushCNamed env x k (SOME c')
+            val env = E.pushClass env n
+        in
+            checkKind env c' ck k;
+            ([(L'.SgiClass (x, n, c'), loc)], (env, denv, []))
+        end
+
 and elabSgn (env, denv) (sgn, loc) =
     case sgn of
         L.SgnConst sgis =>
@@ -2027,7 +2023,19 @@
                                        sgnError env (DuplicateVal (loc, x))
                                    else
                                        ();
-                                   (cons, SS.add (vals, x), sgns, strs)))
+                                   (cons, SS.add (vals, x), sgns, strs))
+                                | L'.SgiClassAbs (x, _) =>
+                                  (if SS.member (cons, x) then
+                                       sgnError env (DuplicateCon (loc, x))
+                                   else
+                                       ();
+                                   (SS.add (cons, x), vals, sgns, strs))
+                                | L'.SgiClass (x, _, _) =>
+                                  (if SS.member (cons, x) then
+                                       sgnError env (DuplicateCon (loc, x))
+                                   else
+                                       ();
+                                   (SS.add (cons, x), vals, sgns, strs)))
                     (SS.empty, SS.empty, SS.empty, SS.empty) sgis'
         in
             ((L'.SgnConst sgis', loc), gs)
@@ -2160,6 +2168,20 @@
                                             | L'.SgiTable (_, x, n, c) =>
                                               (L'.DVal (x, n, (L'.CApp (tableOf (), c), loc),
                                                         (L'.EModProj (str, strs, x), loc)), loc)
+                                            | L'.SgiClassAbs (x, n) =>
+                                              let
+                                                  val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
+                                                  val c = (L'.CModProj (str, strs, x), loc)
+                                              in
+                                                  (L'.DCon (x, n, k, c), loc)
+                                              end
+                                            | L'.SgiClass (x, n, _) =>
+                                              let
+                                                  val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
+                                                  val c = (L'.CModProj (str, strs, x), loc)
+                                              in
+                                                  (L'.DCon (x, n, k, c), loc)
+                                              end
                                   in
                                       (d, (E.declBinds env' d, denv'))
                                   end)
@@ -2283,27 +2305,41 @@
                                          in
                                              found (x', n1, k', SOME (L'.CModProj (m1, ms, s), loc))
                                          end
+                                       | L'.SgiClassAbs (x', n1) => found (x', n1,
+                                                                        (L'.KArrow ((L'.KType, loc),
+                                                                                    (L'.KType, loc)), loc),
+                                                                        NONE)
+                                       | L'.SgiClass (x', n1, c) => found (x', n1,
+                                                                           (L'.KArrow ((L'.KType, loc),
+                                                                                       (L'.KType, loc)), loc),
+                                                                           SOME c)
                                        | _ => NONE
                                  end)
 
                       | L'.SgiCon (x, n2, k2, c2) =>
                         seek (fn sgi1All as (sgi1, _) =>
-                                 case sgi1 of
-                                     L'.SgiCon (x', n1, k1, c1) =>
-                                     if x = x' then
-                                         let
-                                             fun good () = SOME (E.pushCNamedAs env x n2 k2 (SOME c2), denv)
-                                         in
-                                             (case unifyCons (env, denv) c1 c2 of
-                                                  [] => good ()
-                                                | _ => NONE)
-                                             handle CUnify (c1, c2, err) =>
-                                                    (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err));
-                                                     good ())
-                                         end
-                                     else
-                                         NONE
-                                   | _ => NONE)
+                                 let
+                                     fun found (x', n1, k1, c1) =
+                                         if x = x' then
+                                             let
+                                                 fun good () = SOME (E.pushCNamedAs env x n2 k2 (SOME c2), denv)
+                                             in
+                                                 (case unifyCons (env, denv) c1 c2 of
+                                                      [] => good ()
+                                                    | _ => NONE)
+                                                 handle CUnify (c1, c2, err) =>
+                                                        (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err));
+                                                         good ())
+                                             end
+                                         else
+                                             NONE
+                                 in
+                                     case sgi1 of
+                                         L'.SgiCon (x', n1, k1, c1) => found (x', n1, k1, c1)
+                                       | L'.SgiClass (x', n1, c1) =>
+                                         found (x', n1, (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc), c1)
+                                       | _ => NONE
+                                 end)
 
                       | L'.SgiDatatype (x, n2, xs2, xncs2) =>
                         seek (fn sgi1All as (sgi1, _) =>
@@ -2491,6 +2527,54 @@
                                      else
                                          NONE
                                    | _ => NONE)
+
+                      | L'.SgiClassAbs (x, n2) =>
+                        seek (fn sgi1All as (sgi1, _) =>
+                                 let
+                                     fun found (x', n1, co) =
+                                         if x = x' then
+                                             let
+                                                 val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
+                                                 val env = E.pushCNamedAs env x n1 k co
+                                             in
+                                                 SOME (if n1 = n2 then
+                                                           env
+                                                       else
+                                                           E.pushCNamedAs env x n2 k (SOME (L'.CNamed n1, loc2)),
+                                                       denv)
+                                             end
+                                         else
+                                             NONE
+                                 in
+                                     case sgi1 of
+                                         L'.SgiClassAbs (x', n1) => found (x', n1, NONE)
+                                       | L'.SgiClass (x', n1, c) => found (x', n1, SOME c)
+                                       | _ => NONE
+                                 end)
+                      | L'.SgiClass (x, n2, c2) =>
+                        seek (fn sgi1All as (sgi1, _) =>
+                                 let
+                                     val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
+
+                                     fun found (x', n1, c1) =
+                                         if x = x' then
+                                             let
+                                                 fun good () = SOME (E.pushCNamedAs env x n2 k (SOME c2), denv)
+                                             in
+                                                 (case unifyCons (env, denv) c1 c2 of
+                                                      [] => good ()
+                                                    | _ => NONE)
+                                                 handle CUnify (c1, c2, err) =>
+                                                        (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err));
+                                                         good ())
+                                             end
+                                         else
+                                             NONE
+                                 in
+                                     case sgi1 of
+                                         L'.SgiClass (x', n1, c1) => found (x', n1, c1)
+                                       | _ => NONE
+                                 end)
                 end
         in
             ignore (foldl folder (env, denv) sgis2)
@@ -2849,6 +2933,17 @@
             ([(L'.DTable (!basis_r, x, n, c'), loc)], (env, denv, gs' @ gs))
         end
 
+      | L.DClass (x, c) =>
+        let
+            val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
+            val (c', ck, gs) = elabCon (env, denv) c
+            val (env, n) = E.pushCNamed env x k (SOME c')
+            val env = E.pushClass env n
+        in
+            checkKind env c' ck k;
+            ([(L'.DCon (x, n, k, c'), loc)], (env, denv, []))
+        end
+
 and elabStr (env, denv) (str, loc) =
     case str of
         L.StrConst ds =>
@@ -2949,6 +3044,30 @@
                                           (SS.add (vals, x), x)
                               in
                                   ((L'.SgiTable (tn, x, n, c), loc) :: sgis, cons, vals, sgns, strs)
+                              end
+                            | L'.SgiClassAbs (x, n) =>
+                              let
+                                  val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
+
+                                  val (cons, x) =
+                                      if SS.member (cons, x) then
+                                          (cons, "?" ^ x)
+                                      else
+                                          (SS.add (cons, x), x)
+                              in
+                                  ((L'.SgiClassAbs (x, n), loc) :: sgis, cons, vals, sgns, strs)
+                              end
+                            | L'.SgiClass (x, n, c) =>
+                              let
+                                  val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc)
+
+                                  val (cons, x) =
+                                      if SS.member (cons, x) then
+                                          (cons, "?" ^ x)
+                                      else
+                                          (SS.add (cons, x), x)
+                              in
+                                  ((L'.SgiClass (x, n, c), loc) :: sgis, cons, vals, sgns, strs)
                               end)
 
                 ([], SS.empty, SS.empty, SS.empty, SS.empty) sgis