changeset 228:19e5791923d0

Resolving lingering type class constraints
author Adam Chlipala <adamc@hcoop.net>
date Thu, 21 Aug 2008 14:45:31 -0400
parents 524e10c91478
children 016d71e878c1
files src/elab.sml src/elab_env.sml src/elab_print.sml src/elab_util.sml src/elaborate.sml src/explify.sml tests/group_by.lac
diffstat 7 files changed, 67 insertions(+), 40 deletions(-) [+]
line wrap: on
line diff
--- a/src/elab.sml	Thu Aug 21 14:09:08 2008 -0400
+++ b/src/elab.sml	Thu Aug 21 14:45:31 2008 -0400
@@ -111,6 +111,7 @@
        | ECase of exp * (pat * exp) list * { disc : con, result : con }
 
        | EError
+       | EUnif of exp option ref
 
 withtype exp = exp' located
 
--- a/src/elab_env.sml	Thu Aug 21 14:09:08 2008 -0400
+++ b/src/elab_env.sml	Thu Aug 21 14:45:31 2008 -0400
@@ -363,6 +363,7 @@
     case c of
         CNamed n => SOME (ClNamed n)
       | CModProj x => SOME (ClProj x)
+      | CUnif (_, _, _, ref (SOME c)) => class_name_in c
       | _ => NONE
 
 fun class_key_in (c, _) =
@@ -370,6 +371,7 @@
         CRel n => SOME (CkRel n)
       | CNamed n => SOME (CkNamed n)
       | CModProj x => SOME (CkProj x)
+      | CUnif (_, _, _, ref (SOME c)) => class_key_in c
       | _ => NONE
 
 fun class_pair_in (c, _) =
--- a/src/elab_print.sml	Thu Aug 21 14:09:08 2008 -0400
+++ b/src/elab_print.sml	Thu Aug 21 14:45:31 2008 -0400
@@ -363,6 +363,8 @@
                                                                              p_exp env e]) pes])
 
       | EError => string "<ERROR>"
+      | EUnif (ref (SOME e)) => p_exp env e
+      | EUnif _ => string "_"
 
 and p_exp env = p_exp' false env
 
--- a/src/elab_util.sml	Thu Aug 21 14:09:08 2008 -0400
+++ b/src/elab_util.sml	Thu Aug 21 14:45:31 2008 -0400
@@ -347,6 +347,8 @@
                                                         (ECase (e', pes', {disc = disc', result = result'}), loc)))))
 
               | EError => S.return2 eAll
+              | EUnif (ref (SOME e)) => mfe ctx e
+              | EUnif _ => S.return2 eAll
     in
         mfe
     end
--- a/src/elaborate.sml	Thu Aug 21 14:09:08 2008 -0400
+++ b/src/elaborate.sml	Thu Aug 21 14:45:31 2008 -0400
@@ -1123,6 +1123,12 @@
                                               (L'.CApp ((L'.CRel 1, loc), (L'.CRel 0, loc)), loc)), loc)),
                           loc)), loc)), loc)
 
+datatype constraint =
+         Disjoint of D.goal
+       | TypeClass of E.env * L'.con * L'.exp option ref * ErrorMsg.span
+
+val enD = map Disjoint
+
 fun elabHead (env, denv) (e as (_, loc)) t =
     let
         fun unravel (t, e) =
@@ -1137,9 +1143,9 @@
                         val (e, t, gs') = unravel (subConInCon (0, u) t',
                                                    (L'.ECApp (e, u), loc))
                     in
-                        (e, t, gs @ gs')
+                        (e, t, enD gs @ gs')
                     end
-                  | _ => (e, t, gs)
+                  | _ => (e, t, enD gs)
             end
     in
         unravel (t, e)
@@ -1462,7 +1468,7 @@
                 val (t', _, gs2) = elabCon (env, denv) t
                 val gs3 = checkCon (env, denv) e' et t'
             in
-                (e', t', gs1 @ gs2 @ gs3)
+                (e', t', gs1 @ enD gs2 @ enD gs3)
             end
 
           | L.EPrim p => ((L'.EPrim p, loc), primType env p, [])
@@ -1510,9 +1516,13 @@
                         val (dom, gs4) = normClassConstraint (env, denv) dom
                     in
                         case E.resolveClass env dom of
-                            NONE => (expError env (Unresolvable (loc, dom));
-                                     (eerror, cerror, []))
-                          | SOME pf => ((L'.EApp (e1', pf), loc), ran, gs1 @ gs2 @ gs3 @ gs4)
+                            NONE =>
+                            let
+                                val r = ref NONE
+                            in
+                                ((L'.EUnif r, loc), ran, [TypeClass (env, dom, r, loc)])
+                            end
+                          | SOME pf => ((L'.EApp (e1', pf), loc), ran, gs1 @ gs2 @ enD gs3 @ enD gs4)
                     end
                   | _ => (expError env (OutOfContext (loc, SOME (e1', t1)));
                           (eerror, cerror, []))
@@ -1533,7 +1543,7 @@
                 val gs4 = checkCon (env, denv) e1' t1 t
                 val gs5 = checkCon (env, denv) e2' t2 dom
 
-                val gs = gs1 @ gs2 @ gs3 @ gs4 @ gs5
+                val gs = gs1 @ gs2 @ gs3 @ enD gs4 @ enD gs5
             in
                 ((L'.EApp (e1', e2'), loc), ran, gs)
             end
@@ -1552,7 +1562,7 @@
             in
                 ((L'.EAbs (x, t', et, e'), loc),
                  (L'.TFun (t', et), loc),
-                 gs1 @ gs2)
+                 enD gs1 @ gs2)
             end
           | L.ECApp (e, c) =>
             let
@@ -1570,7 +1580,7 @@
                             handle SynUnif => (expError env (Unif ("substitution", eb));
                                                cerror)
                     in
-                        ((L'.ECApp (e', c'), loc), eb', gs1 @ gs2 @ gs3 @ gs4)
+                        ((L'.ECApp (e', c'), loc), eb', gs1 @ gs2 @ enD gs3 @ enD gs4)
                     end
 
                   | L'.CUnif _ =>
@@ -1606,7 +1616,7 @@
                 checkKind env c1' k1 (L'.KRecord ku1, loc);
                 checkKind env c2' k2 (L'.KRecord ku2, loc);
 
-                (e', (L'.TDisjoint (c1', c2', t), loc), gs1 @ gs2 @ gs3 @ gs4)
+                (e', (L'.TDisjoint (c1', c2', t), loc), enD gs1 @ enD gs2 @ enD gs3 @ gs4)
             end
 
           | L.ERecord xes =>
@@ -1617,7 +1627,7 @@
                                                            val (e', et, gs2) = elabExp (env, denv) e
                                                        in
                                                            checkKind env x' xk kname;
-                                                           ((x', e', et), gs1 @ gs2 @ gs)
+                                                           ((x', e', et), enD gs1 @ gs2 @ gs)
                                                        end)
                                                    [] xes
 
@@ -1641,10 +1651,13 @@
                         in
                             prove (rest, gs)
                         end
+
+                val gsD = List.mapPartial (fn Disjoint d => SOME d | _ => NONE) gs
+                val gsO = List.filter (fn Disjoint _ => false | _ => true) gs
             in
                 ((L'.ERecord xes', loc),
                  (L'.TRecord (L'.CRecord (ktype, map (fn (x', _, et) => (x', et)) xes'), loc), loc),
-                 prove (xes', gs))
+                 enD (prove (xes', gsD)) @ gsO)
             end
 
           | L.EField (e, c) =>
@@ -1661,7 +1674,7 @@
                              (L'.TRecord (L'.CConcat (first, rest), loc), loc)
                 val gs4 = D.prove env denv (first, rest, loc)
             in
-                ((L'.EField (e', c', {field = ft, rest = rest}), loc), ft, gs1 @ gs2 @ gs3 @ gs4)
+                ((L'.EField (e', c', {field = ft, rest = rest}), loc), ft, gs1 @ enD gs2 @ enD gs3 @ enD gs4)
             end
 
           | L.ECut (e, c) =>
@@ -1678,7 +1691,8 @@
                              (L'.TRecord (L'.CConcat (first, rest), loc), loc)
                 val gs4 = D.prove env denv (first, rest, loc)
             in
-                ((L'.ECut (e', c', {field = ft, rest = rest}), loc), (L'.TRecord rest, loc), gs1 @ gs2 @ gs3 @ gs4)
+                ((L'.ECut (e', c', {field = ft, rest = rest}), loc), (L'.TRecord rest, loc),
+                 gs1 @ enD gs2 @ enD gs3 @ enD gs4)
             end
 
           | L.EFold =>
@@ -1701,7 +1715,7 @@
                                          val (e', et, gs2) = elabExp (env, denv) e
                                          val gs3 = checkCon (env, denv) e' et result
                                      in
-                                         ((p', e'), gs1 @ gs2 @ gs3 @ gs)
+                                         ((p', e'), enD gs1 @ gs2 @ enD gs3 @ gs)
                                      end)
                                  gs1 pes
 
@@ -1712,7 +1726,7 @@
                 else
                     expError env (Inexhaustive loc);
 
-                ((L'.ECase (e', pes', {disc = et, result = result}), loc), result, gs' @ gs)
+                ((L'.ECase (e', pes', {disc = et, result = result}), loc), result, enD gs' @ gs)
             end
     end
             
@@ -2688,7 +2702,7 @@
       | _ => sgnError env (SgnWrongForm (sgn1, sgn2))
 
 
-fun elabDecl ((d, loc), (env, denv, gs)) =
+fun elabDecl ((d, loc), (env, denv, gs : constraint list)) =
     case d of
         L.DCon (x, ko, c) =>
         let
@@ -2701,7 +2715,7 @@
         in
             checkKind env c' ck k';
 
-            ([(L'.DCon (x, n, k', c'), loc)], (env', denv, gs' @ gs))
+            ([(L'.DCon (x, n, k', c'), loc)], (env', denv, enD gs' @ gs))
         end
       | L.DDatatype (x, xs, xcs) =>
         let
@@ -2727,7 +2741,7 @@
                                                val (t', tk, gs') = elabCon (env', denv') t'
                                            in
                                                checkKind env' t' tk k;
-                                               (SOME t', (L'.TFun (t', t), loc), gs' @ gs)
+                                               (SOME t', (L'.TFun (t', t), loc), enD gs' @ gs)
                                            end
                         val t = foldr (fn (x, t) => (L'.TCFun (L'.Implicit, x, k, t), loc)) t xs
 
@@ -2762,7 +2776,7 @@
                                   ((L'.StrVar n, loc), sgn) ms
              in
                  case hnormCon (env, denv) (L'.CModProj (n, ms, s), loc) of
-                     ((L'.CModProj (n, ms, s), _), gs) =>
+                     ((L'.CModProj (n, ms, s), _), gs') =>
                      (case E.projectDatatype env {sgn = sgn, str = str, field = s} of
                           NONE => (conError env (UnboundDatatype (loc, s));
                                    ([], (env, denv, gs)))
@@ -2788,7 +2802,7 @@
                                                       E.pushENamedAs env x n t
                                                   end) env xncs
                           in
-                              ([(L'.DDatatypeImp (x, n', n, ms, s, xs, xncs), loc)], (env, denv, gs))
+                              ([(L'.DDatatypeImp (x, n', n, ms, s, xs, xncs), loc)], (env, denv, enD gs' @ gs))
                           end)
                    | _ => (strError env (NotDatatype loc);
                            ([], (env, denv, [])))
@@ -2807,7 +2821,7 @@
         in
             (*prefaces "DVal" [("x", Print.PD.string x),
                              ("c'", p_con env c')];*)
-            ([(L'.DVal (x, n, c', e'), loc)], (env', denv, gs1 @ gs2 @ gs3 @ gs4 @ gs))
+            ([(L'.DVal (x, n, c', e'), loc)], (env', denv, enD gs1 @ gs2 @ enD gs3 @ enD gs4 @ gs))
         end
       | L.DValRec vis =>
         let
@@ -2818,7 +2832,7 @@
                                                                NONE => (cunif (loc, ktype), ktype, [])
                                                              | SOME c => elabCon (env, denv) c
                                     in
-                                        ((x, c', e), gs1 @ gs)
+                                        ((x, c', e), enD gs1 @ gs)
                                     end) [] vis
 
             val (vis, env) = ListUtil.foldlMap (fn ((x, c', e), env) =>
@@ -2834,7 +2848,7 @@
                                                                           
                                                       val gs2 = checkCon (env, denv) e' et c'
                                                   in
-                                                      ((x, n, c', e'), gs1 @ gs2 @ gs)
+                                                      ((x, n, c', e'), gs1 @ enD gs2 @ gs)
                                                   end) gs vis
         in
             ([(L'.DValRec vis, loc)], (env, denv, gs))
@@ -2845,7 +2859,7 @@
             val (sgn', gs') = elabSgn (env, denv) sgn
             val (env', n) = E.pushSgnNamed env x sgn'
         in
-            ([(L'.DSgn (x, n, sgn'), loc)], (env', denv, gs' @ gs))
+            ([(L'.DSgn (x, n, sgn'), loc)], (env', denv, enD gs' @ gs))
         end
 
       | L.DStr (x, sgno, str) =>
@@ -2906,7 +2920,7 @@
                         val (str', actual, gs2) = elabStr (env, denv) str
                     in
                         subSgn (env, denv) (selfifyAt env {str = str', sgn = actual}) formal;
-                        (str', formal, gs1 @ gs2)
+                        (str', formal, enD gs1 @ gs2)
                     end
 
             val (env', n) = E.pushStrNamed env x sgn'
@@ -2927,7 +2941,7 @@
 
             val (env', n) = E.pushStrNamed env x sgn'
         in
-            ([(L'.DFfiStr (x, n, sgn'), loc)], (env', denv, gs' @ gs))
+            ([(L'.DFfiStr (x, n, sgn'), loc)], (env', denv, enD gs' @ gs))
         end
 
       | L.DOpen (m, ms) =>
@@ -2960,7 +2974,7 @@
             checkKind env c1' k1 (L'.KRecord (kunif loc), loc);
             checkKind env c2' k2 (L'.KRecord (kunif loc), loc);
 
-            ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3 @ gs4 @ gs))
+            ([(L'.DConstraint (c1', c2'), loc)], (env, denv', enD gs1 @ enD gs2 @ enD gs3 @ enD gs4 @ gs))
         end
 
       | L.DOpenConstraints (m, ms) =>
@@ -3027,7 +3041,7 @@
             val (env, n) = E.pushENamed env x (L'.CApp (tableOf (), c'), loc)
         in
             checkKind env c' k (L'.KRecord (L'.KType, loc), loc);
-            ([(L'.DTable (!basis_r, x, n, c'), loc)], (env, denv, gs' @ gs))
+            ([(L'.DTable (!basis_r, x, n, c'), loc)], (env, denv, enD gs' @ gs))
         end
 
       | L.DClass (x, c) =>
@@ -3205,7 +3219,7 @@
         in
             ((L'.StrFun (m, n, dom', formal, str'), loc),
              (L'.SgnFun (m, n, dom', formal), loc),
-             gs1 @ gs2 @ gs3)
+             enD gs1 @ gs2 @ enD gs3)
         end
       | L.StrApp (str1, str2) =>
         let
@@ -3282,15 +3296,19 @@
         if ErrorMsg.anyErrors () then
             ()
         else
-            app (fn (loc, env, denv, c1, c2) =>
-                    case D.prove env denv (c1, c2, loc) of
-                        [] => ()
-                      | _ =>
-                        (ErrorMsg.errorAt loc "Couldn't prove field name disjointness";
-                         eprefaces' [("Con 1", p_con env c1),
-                                     ("Con 2", p_con env c2),
-                                     ("Hnormed 1", p_con env (ElabOps.hnormCon env c1)),
-                                     ("Hnormed 2", p_con env (ElabOps.hnormCon env c2))])) gs;
+            app (fn Disjoint (loc, env, denv, c1, c2) =>
+                    (case D.prove env denv (c1, c2, loc) of
+                         [] => ()
+                       | _ =>
+                         (ErrorMsg.errorAt loc "Couldn't prove field name disjointness";
+                          eprefaces' [("Con 1", p_con env c1),
+                                      ("Con 2", p_con env c2),
+                                      ("Hnormed 1", p_con env (ElabOps.hnormCon env c1)),
+                                      ("Hnormed 2", p_con env (ElabOps.hnormCon env c2))]))
+                  | TypeClass (env, c, r, loc) =>
+                    case E.resolveClass env c of
+                        SOME e => r := SOME e
+                      | NONE => expError env (Unresolvable (loc, c))) gs;
 
         (L'.DFfiStr ("Basis", basis_n, sgn), ErrorMsg.dummySpan) :: ds @ file
     end
--- a/src/explify.sml	Thu Aug 21 14:09:08 2008 -0400
+++ b/src/explify.sml	Thu Aug 21 14:45:31 2008 -0400
@@ -112,6 +112,8 @@
                    {disc = explifyCon disc, result = explifyCon result}), loc)
 
       | L.EError => raise Fail ("explifyExp: EError at " ^ EM.spanToString loc)
+      | L.EUnif (ref (SOME e)) => explifyExp e
+      | L.EUnif _ => raise Fail ("explifyExp: Undetermined EUnif at " ^ EM.spanToString loc)
 
 fun explifySgi (sgi, loc) =
     case sgi of
--- a/tests/group_by.lac	Thu Aug 21 14:09:08 2008 -0400
+++ b/tests/group_by.lac	Thu Aug 21 14:45:31 2008 -0400
@@ -8,4 +8,4 @@
 val q4 = (SELECT * FROM t1 WHERE t1.A = 0 GROUP BY t1.C HAVING t1.C < 0.2)
 
 val q5 = (SELECT t1.A, t2.D FROM t1, t2 GROUP BY t2.D, t1.A)
-val q6 = (SELECT t1.A, t2.D FROM t1, t2 WHERE t1.C = 0.0 GROUP BY t2.D, t1.A HAVING t1.A = 0 AND t2.D = 17)
+val q6 = (SELECT t1.A, t2.D FROM t1, t2 WHERE t1.C = 0.0 GROUP BY t2.D, t1.A HAVING t1.A = t1.A AND t2.D = 17)