diff src/elaborate.sml @ 86:7f9bcc8bfa1e

More with disjointness assumptions
author Adam Chlipala <adamc@hcoop.net>
date Tue, 01 Jul 2008 13:19:14 -0400
parents 1f85890c9846
children 7bab29834cd6
line wrap: on
line diff
--- a/src/elaborate.sml	Tue Jul 01 12:25:12 2008 -0400
+++ b/src/elaborate.sml	Tue Jul 01 13:19:14 2008 -0400
@@ -545,42 +545,73 @@
       | L'.CError => kerror
       | L'.CUnif (_, k, _, _) => k
 
-fun unifyRecordCons env (c1, c2) =
+fun hnormCon (env, denv) c =
+    let
+        val cAll as (c, loc) = ElabOps.hnormCon env c
+
+        fun doDisj (c1, c2, c) =
+            let
+                val (c, gs) = hnormCon (env, denv) c
+            in
+                (c,
+                 map (fn cs => (loc, env, denv, cs)) (D.prove env denv (c1, c2, loc)) @ gs)
+            end
+    in
+        case c of
+            L'.CDisjoint cs => doDisj cs
+          | L'.TDisjoint cs => doDisj cs
+          | _ => (cAll, [])
+    end
+
+fun unifyRecordCons (env, denv) (c1, c2) =
     let
         val k1 = kindof env c1
         val k2 = kindof env c2
+
+        val (r1, gs1) = recordSummary (env, denv) c1
+        val (r2, gs2) = recordSummary (env, denv) c2
     in
         unifyKinds k1 k2;
-        unifySummaries env (k1, recordSummary env c1, recordSummary env c2)
+        unifySummaries (env, denv) (k1, r1, r2);
+        gs1 @ gs2
     end
 
-and recordSummary env c : record_summary =
-    case hnormCon env c of
-        (L'.CRecord (_, xcs), _) => {fields = xcs, unifs = [], others = []}
-      | (L'.CConcat (c1, c2), _) =>
-        let
-            val s1 = recordSummary env c1
-            val s2 = recordSummary env c2
-        in
-            {fields = #fields s1 @ #fields s2,
-             unifs = #unifs s1 @ #unifs s2,
-             others = #others s1 @ #others s2}
-        end
-      | (L'.CUnif (_, _, _, ref (SOME c)), _) => recordSummary env c
-      | c' as (L'.CUnif (_, _, _, r), _) => {fields = [], unifs = [(c', r)], others = []}
-      | c' => {fields = [], unifs = [], others = [c']}
+and recordSummary (env, denv) c =
+    let
+        val (c, gs) = hnormCon (env, denv) c
 
-and consEq env (c1, c2) =
-    (unifyCons env c1 c2;
-     true)
+        val (sum, gs') =
+            case c of
+                (L'.CRecord (_, xcs), _) => ({fields = xcs, unifs = [], others = []}, [])
+              | (L'.CConcat (c1, c2), _) =>
+                let
+                    val (s1, gs1) = recordSummary (env, denv) c1
+                    val (s2, gs2) = recordSummary (env, denv) c2
+                in
+                    ({fields = #fields s1 @ #fields s2,
+                      unifs = #unifs s1 @ #unifs s2,
+                      others = #others s1 @ #others s2},
+                     gs1 @ gs2)
+                end
+              | (L'.CUnif (_, _, _, ref (SOME c)), _) => recordSummary (env, denv) c
+              | c' as (L'.CUnif (_, _, _, r), _) => ({fields = [], unifs = [(c', r)], others = []}, [])
+              | c' => ({fields = [], unifs = [], others = [c']}, [])
+    in
+        (sum, gs @ gs')
+    end
+
+and consEq (env, denv) (c1, c2) =
+    (case unifyCons (env, denv) c1 c2 of
+         [] => true
+       | _ => false)
     handle CUnify _ => false
 
 and consNeq env (c1, c2) =
-    case (#1 (hnormCon env c1), #1 (hnormCon env c2)) of
+    case (#1 (ElabOps.hnormCon env c1), #1 (ElabOps.hnormCon env c2)) of
         (L'.CName x1, L'.CName x2) => x1 <> x2
       | _ => false
 
-and unifySummaries env (k, s1 : record_summary, s2 : record_summary) =
+and unifySummaries (env, denv) (k, s1 : record_summary, s2 : record_summary) =
     let
         (*val () = eprefaces "Summaries" [("#1", p_summary env s1),
                                           ("#2", p_summary env s2)]*)
@@ -609,13 +640,13 @@
 
         val (fs1, fs2) = eatMatching (fn ((x1, c1), (x2, c2)) =>
                                          not (consNeq env (x1, x2))
-                                         andalso consEq env (c1, c2)
-                                         andalso consEq env (x1, x2))
+                                         andalso consEq (env, denv) (c1, c2)
+                                         andalso consEq (env, denv) (x1, x2))
                                      (#fields s1, #fields s2)
         (*val () = eprefaces "Summaries2" [("#1", p_summary env {fields = fs1, unifs = #unifs s1, others = #others s1}),
                                            ("#2", p_summary env {fields = fs2, unifs = #unifs s2, others = #others s2})]*)
         val (unifs1, unifs2) = eatMatching (fn ((_, r1), (_, r2)) => r1 = r2) (#unifs s1, #unifs s2)
-        val (others1, others2) = eatMatching (consEq env) (#others s1, #others s2)
+        val (others1, others2) = eatMatching (consEq (env, denv)) (#others s1, #others s2)
 
         fun unifFields (fs, others, unifs) =
             case (fs, others, unifs) of
@@ -645,22 +676,19 @@
         val (fs1, others1, unifs2) = unifFields (fs1, others1, unifs2)
         val (fs2, others2, unifs1) = unifFields (fs2, others2, unifs1)
 
-        val clear1 = case (fs1, others1) of
-                         ([], []) => true
-                       | _ => false
-        val clear2 = case (fs2, others2) of
-                         ([], []) => true
+        val clear = case (fs1, others1, fs2, others2) of
+                         ([], [], [], []) => true
                        | _ => false
         val empty = (L'.CRecord (k, []), dummy)
         fun pairOffUnifs (unifs1, unifs2) =
             case (unifs1, unifs2) of
                 ([], _) =>
-                if clear1 then
+                if clear then
                     List.app (fn (_, r) => r := SOME empty) unifs2
                 else
                     raise CUnify' CRecordFailure
               | (_, []) =>
-                if clear2 then
+                if clear then
                     List.app (fn (_, r) => r := SOME empty) unifs1
                 else
                     raise CUnify' CRecordFailure
@@ -671,81 +699,89 @@
         pairOffUnifs (unifs1, unifs2)
     end
 
-
-and unifyCons' env c1 c2 =
-    unifyCons'' env (hnormCon env c1) (hnormCon env c2)
+and unifyCons' (env, denv) c1 c2 =
+    let
+        val (c1, gs1) = hnormCon (env, denv) c1
+        val (c2, gs2) = hnormCon (env, denv) c2
+    in
+        unifyCons'' (env, denv) c1 c2;
+        gs1 @ gs2
+    end
     
-and unifyCons'' env (c1All as (c1, _)) (c2All as (c2, _)) =
+and unifyCons'' (env, denv) (c1All as (c1, _)) (c2All as (c2, _)) =
     let
         fun err f = raise CUnify' (f (c1All, c2All))
 
-        fun isRecord () = unifyRecordCons env (c1All, c2All)
+        fun isRecord () = unifyRecordCons (env, denv) (c1All, c2All)
     in
         case (c1, c2) of
             (L'.TFun (d1, r1), L'.TFun (d2, r2)) =>
-            (unifyCons' env d1 d2;
-             unifyCons' env r1 r2)
+            unifyCons' (env, denv) d1 d2
+            @ unifyCons' (env, denv) r1 r2
           | (L'.TCFun (expl1, x1, d1, r1), L'.TCFun (expl2, _, d2, r2)) =>
             if expl1 <> expl2 then
                 err CExplicitness
             else
                 (unifyKinds d1 d2;
-                 unifyCons' (E.pushCRel env x1 d1) r1 r2)
-          | (L'.TRecord r1, L'.TRecord r2) => unifyCons' env r1 r2
+                 unifyCons' (E.pushCRel env x1 d1, D.enter denv) r1 r2)
+          | (L'.TRecord r1, L'.TRecord r2) => unifyCons' (env, denv) r1 r2
 
           | (L'.CRel n1, L'.CRel n2) =>
             if n1 = n2 then
-                ()
+                []
             else
                 err CIncompatible
           | (L'.CNamed n1, L'.CNamed n2) =>
             if n1 = n2 then
-                ()
+                []
             else
                 err CIncompatible
 
           | (L'.CApp (d1, r1), L'.CApp (d2, r2)) =>
-            (unifyCons' env d1 d2;
-             unifyCons' env r1 r2)
+            (unifyCons' (env, denv) d1 d2;
+             unifyCons' (env, denv) r1 r2)
           | (L'.CAbs (x1, k1, c1), L'.CAbs (_, k2, c2)) =>
             (unifyKinds k1 k2;
-             unifyCons' (E.pushCRel env x1 k1) c1 c2)
+             unifyCons' (E.pushCRel env x1 k1, D.enter denv) c1 c2)
 
           | (L'.CName n1, L'.CName n2) =>
             if n1 = n2 then
-                ()
+                []
             else
                 err CIncompatible
 
           | (L'.CModProj (n1, ms1, x1), L'.CModProj (n2, ms2, x2)) =>
             if n1 = n2 andalso ms1 = ms2 andalso x1 = x2 then
-                ()
+                []
             else
                 err CIncompatible
 
-          | (L'.CError, _) => ()
-          | (_, L'.CError) => ()
+          | (L'.CError, _) => []
+          | (_, L'.CError) => []
 
-          | (L'.CUnif (_, _, _, ref (SOME c1All)), _) => unifyCons' env c1All c2All
-          | (_, L'.CUnif (_, _, _, ref (SOME c2All))) => unifyCons' env c1All c2All
+          | (L'.CUnif (_, _, _, ref (SOME c1All)), _) => unifyCons' (env, denv) c1All c2All
+          | (_, L'.CUnif (_, _, _, ref (SOME c2All))) => unifyCons' (env, denv) c1All c2All
 
           | (L'.CUnif (_, k1, _, r1), L'.CUnif (_, k2, _, r2)) =>
             if r1 = r2 then
-                ()
+                []
             else
                 (unifyKinds k1 k2;
-                 r1 := SOME c2All)
+                 r1 := SOME c2All;
+                 [])
 
           | (L'.CUnif (_, _, _, r), _) =>
             if occursCon r c2All then
                 err COccursCheckFailed
             else
-                r := SOME c2All
+                (r := SOME c2All;
+                 [])
           | (_, L'.CUnif (_, _, _, r)) =>
             if occursCon r c1All then
                 err COccursCheckFailed
             else
-                r := SOME c1All
+                (r := SOME c1All;
+                 [])
 
           | (L'.CRecord _, _) => isRecord ()
           | (_, L'.CRecord _) => isRecord ()
@@ -754,13 +790,14 @@
 
           | (L'.CFold (dom1, ran1), L'.CFold (dom2, ran2)) =>
             (unifyKinds dom1 dom2;
-             unifyKinds ran1 ran2)
+             unifyKinds ran1 ran2;
+             [])
 
           | _ => err CIncompatible
     end
 
-and unifyCons env c1 c2 =
-    unifyCons' env c1 c2
+and unifyCons (env, denv) c1 c2 =
+    unifyCons' (env, denv) c1 c2
     handle CUnify' err => raise CUnify (c1, c2, err)
          | KUnify args => raise CUnify (c1, c2, CKind args)
 
@@ -791,10 +828,11 @@
          eprefaces' [("Expression", p_exp env e),
                      ("Type", p_con env t)])
 
-fun checkCon env e c1 c2 =
-    unifyCons env c1 c2
+fun checkCon (env, denv) e c1 c2 =
+    unifyCons (env, denv) c1 c2
     handle CUnify (c1, c2, err) =>
-           expError env (Unify (e, c1, c2, err))
+           (expError env (Unify (e, c1, c2, err));
+            [])
 
 fun primType env p =
     case p of
@@ -860,18 +898,24 @@
 
       | L'.EError => cerror
 
-fun elabHead env (e as (_, loc)) t =
+fun elabHead (env, denv) (e as (_, loc)) t =
     let
         fun unravel (t, e) =
-            case hnormCon env t of
-                (L'.TCFun (L'.Implicit, x, k, t'), _) =>
-                let
-                    val u = cunif (loc, k)
-                in
-                    unravel (subConInCon (0, u) t',
-                             (L'.ECApp (e, u), loc))
-                end
-              | _ => (e, t)
+            let
+                val (t, gs) = hnormCon (env, denv) t
+            in
+                case t of
+                    (L'.TCFun (L'.Implicit, x, k, t'), _) =>
+                    let
+                        val u = cunif (loc, k)
+
+                        val (e, t, gs') = unravel (subConInCon (0, u) t',
+                                                   (L'.ECApp (e, u), loc))
+                    in
+                        (e, t, gs @ gs')
+                    end
+                  | _ => (e, t, gs)
+            end
     in
         unravel (t, e)
     end
@@ -882,9 +926,9 @@
         let
             val (e', et, gs1) = elabExp (env, denv) e
             val (t', _, gs2) = elabCon (env, denv) t
+            val gs3 = checkCon (env, denv) e' et t'
         in
-            checkCon env e' et t';
-            (e', t', gs1 @ gs2)
+            (e', t', gs1 @ gs2 @ gs3)
         end
 
       | L.EPrim p => ((L'.EPrim p, loc), primType env p, [])
@@ -919,16 +963,17 @@
       | L.EApp (e1, e2) =>
         let
             val (e1', t1, gs1) = elabExp (env, denv) e1
-            val (e1', t1) = elabHead env e1' t1
-            val (e2', t2, gs2) = elabExp (env, denv) e2
+            val (e1', t1, gs2) = elabHead (env, denv) e1' t1
+            val (e2', t2, gs3) = elabExp (env, denv) e2
 
             val dom = cunif (loc, ktype)
             val ran = cunif (loc, ktype)
             val t = (L'.TFun (dom, ran), dummy)
+
+            val gs4 = checkCon (env, denv) e1' t1 t
+            val gs5 = checkCon (env, denv) e2' t2 dom
         in
-            checkCon env e1' t1 t;
-            checkCon env e2' t2 dom;
-            ((L'.EApp (e1', e2'), loc), ran, gs1 @ gs2)
+            ((L'.EApp (e1', e2'), loc), ran, gs1 @ gs2 @ gs3 @ gs4 @ gs5)
         end
       | L.EAbs (x, to, e) =>
         let
@@ -950,10 +995,11 @@
       | L.ECApp (e, c) =>
         let
             val (e', et, gs1) = elabExp (env, denv) e
-            val (e', et) = elabHead env e' et
-            val (c', ck, gs2) = elabCon (env, denv) c
+            val (e', et, gs2) = elabHead (env, denv) e' et
+            val (c', ck, gs3) = elabCon (env, denv) c
+            val ((et', _), gs4) = hnormCon (env, denv) et
         in
-            case #1 (hnormCon env et) of
+            case et' of
                 L'.CError => (eerror, cerror, [])
               | L'.TCFun (_, _, k, eb) =>
                 let
@@ -962,7 +1008,7 @@
                               handle SynUnif => (expError env (Unif ("substitution", eb));
                                                  cerror)
                 in
-                    ((L'.ECApp (e', c'), loc), eb', gs1 @ gs2)
+                    ((L'.ECApp (e', c'), loc), eb', gs1 @ gs2 @ gs3 @ gs4)
                 end
 
               | L'.CUnif _ =>
@@ -1012,10 +1058,32 @@
                                                        ((x', e', et), gs1 @ gs2 @ gs)
                                                    end)
                                                [] xes
+
+            val k = (L'.KType, loc)
+
+            fun prove (xets, gs) =
+                case xets of
+                    [] => gs
+                  | (x, _, t) :: rest =>
+                    let
+                        val xc = (x, t)
+                        val r1 = (L'.CRecord (k, [xc]), loc)
+                        val gs = foldl (fn ((x', _, t'), gs) =>
+                                           let
+                                               val xc' = (x', t')
+                                               val r2 = (L'.CRecord (k, [xc']), loc)
+                                           in
+                                               map (fn cs => (loc, env, denv, cs)) (D.prove env denv (r1, r2, loc))
+                                               @ gs
+                                           end)
+                                 gs rest
+                    in
+                        prove (rest, gs)
+                    end
         in
             ((L'.ERecord xes', loc),
              (L'.TRecord (L'.CRecord (ktype, map (fn (x', _, et) => (x', et)) xes'), loc), loc),
-             gs)
+             prove (xes', gs))
         end
 
       | L.EField (e, c) =>
@@ -1025,10 +1093,16 @@
 
             val ft = cunif (loc, ktype)
             val rest = cunif (loc, ktype_record)
+            val first = (L'.CRecord (ktype, [(c', ft)]), loc)
+                       
+            val gs3 =
+                checkCon (env, denv) e' et
+                         (L'.TRecord (L'.CConcat (first, rest), loc), loc)
+            val gs4 =
+                map (fn cs => (loc, env, denv, cs))
+                (D.prove env denv (first, rest, loc))
         in
-            checkKind env c' ck kname;
-            checkCon env e' et (L'.TRecord (L'.CConcat ((L'.CRecord (ktype, [(c', ft)]), loc), rest), loc), loc);
-            ((L'.EField (e', c', {field = ft, rest = rest}), loc), ft, gs1 @ gs2)
+            ((L'.EField (e', c', {field = ft, rest = rest}), loc), ft, gs1 @ gs2 @ gs3 @ gs4)
         end
 
       | L.EFold =>
@@ -1373,7 +1447,7 @@
       | L'.DStr (x, n, sgn, _) => (L'.SgiStr (x, n, sgn), loc)
       | L'.DFfiStr (x, n, sgn) => (L'.SgiStr (x, n, sgn), loc)
 
-fun subSgn env sgn1 (sgn2 as (_, loc2)) =
+fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) =
     case (#1 (hnormSgn env sgn1), #1 (hnormSgn env sgn2)) of
         (L'.SgnError, _) => ()
       | (_, L'.SgnError) => ()
@@ -1428,11 +1502,14 @@
                                      L'.SgiCon (x', n1, k1, c1) =>
                                      if x = x' then
                                          let
-                                             val () = unifyCons env c1 c2
-                                                 handle CUnify (c1, c2, err) =>
-                                                        sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err))
+                                             fun good () = SOME (E.pushCNamedAs env x n2 k2 (SOME c2))
                                          in
-                                             SOME (E.pushCNamedAs env x n2 k2 (SOME c2))
+                                             (case unifyCons (env, denv) c1 c2 of
+                                                  [] => good ()
+                                                | _ => NONE)
+                                             handle CUnify (c1, c2, err) =>
+                                                    (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err));
+                                                     good ())
                                          end
                                      else
                                          NONE
@@ -1443,13 +1520,12 @@
                                  case sgi1 of
                                      L'.SgiVal (x', n1, c1) =>
                                      if x = x' then
-                                         let
-                                             val () = unifyCons env c1 c2
-                                                 handle CUnify (c1, c2, err) =>
-                                                        sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err))
-                                         in
-                                             SOME env
-                                         end
+                                         (case unifyCons (env, denv) c1 c2 of
+                                              [] => SOME env
+                                            | _ => NONE)
+                                         handle CUnify (c1, c2, err) =>
+                                                (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err));
+                                                 SOME env)
                                      else
                                          NONE
                                    | _ => NONE)
@@ -1460,7 +1536,7 @@
                                      L'.SgiStr (x', n1, sgn1) =>
                                      if x = x' then
                                          let
-                                             val () = subSgn env sgn1 sgn2
+                                             val () = subSgn (env, denv) sgn1 sgn2
                                              val env = E.pushStrNamedAs env x n1 sgn1
                                              val env = if n1 = n2 then
                                                            env
@@ -1481,8 +1557,8 @@
                                      L'.SgiSgn (x', n1, sgn1) =>
                                      if x = x' then
                                          let
-                                             val () = subSgn env sgn1 sgn2
-                                             val () = subSgn env sgn2 sgn1
+                                             val () = subSgn (env, denv) sgn1 sgn2
+                                             val () = subSgn (env, denv) sgn2 sgn1
 
                                              val env = E.pushSgnNamedAs env x n2 sgn2
                                              val env = if n1 = n2 then
@@ -1508,8 +1584,8 @@
                 else
                     subStrInSgn (n1, n2) ran1
         in
-            subSgn env dom2 dom1;
-            subSgn (E.pushStrNamedAs env m2 n2 dom2) ran1 ran2
+            subSgn (env, denv) dom2 dom1;
+            subSgn (E.pushStrNamedAs env m2 n2 dom2, denv) ran1 ran2
         end
 
       | _ => sgnError env (SgnWrongForm (sgn1, sgn2))
@@ -1538,10 +1614,10 @@
 
             val (e', et, gs2) = elabExp (env, denv) e
             val (env', n) = E.pushENamed env x c'
+
+            val gs3 = checkCon (env, denv) e' et c'
         in
-            checkCon env e' et c';
-
-            ([(L'.DVal (x, n, c', e'), loc)], (env', gs1 @ gs2 @ gs))
+            ([(L'.DVal (x, n, c', e'), loc)], (env', gs1 @ gs2 @ gs3 @ gs))
         end
 
       | L.DSgn (x, sgn) =>
@@ -1602,7 +1678,7 @@
 
                         val (str', actual, gs2) = elabStr (env, denv) str
                     in
-                        subSgn env actual formal;
+                        subSgn (env, denv) actual formal;
                         (str', formal, gs1 @ gs2)
                     end
 
@@ -1739,7 +1815,7 @@
                     let
                         val (ran', gs) = elabSgn (env', denv) ran
                     in
-                        subSgn env' actual ran';
+                        subSgn (env', denv) actual ran';
                         (ran', gs)
                     end
         in
@@ -1755,7 +1831,7 @@
             case #1 (hnormSgn env sgn1) of
                 L'.SgnError => (strerror, sgnerror, [])
               | L'.SgnFun (m, n, dom, ran) =>
-                (subSgn env sgn2 dom;
+                (subSgn (env, denv) sgn2 dom;
                  case #1 (hnormSgn env ran) of
                      L'.SgnError => (strerror, sgnerror, [])
                    | L'.SgnConst sgis =>
@@ -1820,7 +1896,7 @@
                     case D.prove env denv (c1, c2, loc) of
                         [] => ()
                       | _ =>
-                        (ErrorMsg.errorAt loc "Remaining constraint";
+                        (ErrorMsg.errorAt loc "Couldn't prove field name disjointness";
                          eprefaces' [("Con 1", p_con env c1),
                                      ("Con 2", p_con env c2)])) gs;