changeset 90:94ef20a31550

Fancier head normalization pushed inside of Disjoint
author Adam Chlipala <adamc@hcoop.net>
date Thu, 03 Jul 2008 11:04:25 -0400 (2008-07-03)
parents d3ee072fa609
children 4327abd52997
files src/disjoint.sig src/disjoint.sml src/elaborate.sml tests/cfold_disj.lac
diffstat 4 files changed, 148 insertions(+), 103 deletions(-) [+]
line wrap: on
line diff
--- a/src/disjoint.sig	Tue Jul 01 16:06:58 2008 -0400
+++ b/src/disjoint.sig	Thu Jul 03 11:04:25 2008 -0400
@@ -30,9 +30,14 @@
     type env
 
     val empty : env
-    val assert : ElabEnv.env -> env -> Elab.con * Elab.con -> env
     val enter : env -> env
 
-    val prove : ElabEnv.env -> env -> Elab.con * Elab.con * ErrorMsg.span -> (Elab.con * Elab.con) list
+    type goal = ErrorMsg.span * ElabEnv.env * env * Elab.con * Elab.con
+
+    val assert : ElabEnv.env -> env -> Elab.con * Elab.con -> env * goal list
+
+    val prove : ElabEnv.env -> env -> Elab.con * Elab.con * ErrorMsg.span -> goal list
+
+    val hnormCon : ElabEnv.env * env -> Elab.con -> Elab.con * goal list
 
 end
--- a/src/disjoint.sml	Tue Jul 01 16:06:58 2008 -0400
+++ b/src/disjoint.sml	Thu Jul 03 11:04:25 2008 -0400
@@ -109,6 +109,8 @@
 
 type env = PS.set PM.map
 
+type goal = ErrorMsg.span * ElabEnv.env * env * Elab.con * Elab.con
+
 val empty = PM.empty
 
 fun nameToRow (c, loc) =
@@ -128,55 +130,6 @@
          Piece of piece
        | Unknown of con
 
-fun decomposeRow env c =
-    let
-        fun decomposeName (c, acc) =
-            case #1 (hnormCon env c) of
-                CName s => Piece (NameC s) :: acc
-              | CRel n => Piece (NameR n) :: acc
-              | CNamed n => Piece (NameN n) :: acc
-              | CModProj (m1, ms, x) => Piece (NameM (m1, ms, x)) :: acc
-              | _ => Unknown c :: acc
-
-        fun decomposeRow (c, acc) =
-            case #1 (hnormCon env c) of
-                CRecord (_, xcs) => foldl (fn ((x, _), acc) => decomposeName (x, acc)) acc xcs
-              | CConcat (c1, c2) => decomposeRow (c1, decomposeRow (c2, acc))
-              | CRel n => Piece (RowR n) :: acc
-              | CNamed n => Piece (RowN n) :: acc
-              | CModProj (m1, ms, x) => Piece (RowM (m1, ms, x)) :: acc
-              | _ => Unknown c :: acc
-    in
-        decomposeRow (c, [])
-    end
-
-fun assert env denv (c1, c2) =
-    let
-        val ps1 = decomposeRow env c1
-        val ps2 = decomposeRow env c2
-
-        val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
-        val ps1 = unUnknown ps1
-        val ps2 = unUnknown ps2
-
-        (*val () = print "APieces1:\n"
-        val () = app pp ps1
-        val () = print "APieces2:\n"
-        val () = app pp ps2*)
-
-        fun assertPiece ps (p, denv) =
-            let
-                val pset = Option.getOpt (PM.find (denv, p), PS.empty)
-                val pset = PS.addList (pset, ps)
-            in
-                PM.insert (denv, p, pset)
-            end
-
-        val denv = foldl (assertPiece ps2) denv ps1
-    in
-        foldl (assertPiece ps1) denv ps2
-    end
-
 fun pieceEnter p =
     case p of
         NameR n => NameR (n + 1)
@@ -196,16 +149,79 @@
             NONE => false
           | SOME pset => PS.member (pset, p2)
 
-fun prove env denv (c1, c2, loc) =
+fun decomposeRow (env, denv) c =
     let
-        val ps1 = decomposeRow env c1
-        val ps2 = decomposeRow env c2
+        fun decomposeName (c, (acc, gs)) =
+            let
+                val (cAll as (c, _), gs') = hnormCon (env, denv) c
+
+                val acc = case c of
+                              CName s => Piece (NameC s) :: acc
+                            | CRel n => Piece (NameR n) :: acc
+                            | CNamed n => Piece (NameN n) :: acc
+                            | CModProj (m1, ms, x) => Piece (NameM (m1, ms, x)) :: acc
+                            | _ => Unknown cAll :: acc
+            in
+                (acc, gs' @ gs)
+            end
+
+        fun decomposeRow (c, (acc, gs)) =
+            let
+                val (cAll as (c, _), gs') = hnormCon (env, denv) c
+                val gs = gs' @ gs
+            in
+                case c of
+                    CRecord (_, xcs) => foldl (fn ((x, _), acc_gs) => decomposeName (x, acc_gs)) (acc, gs) xcs
+                  | CConcat (c1, c2) => decomposeRow (c1, decomposeRow (c2, (acc, gs)))
+                  | CRel n => (Piece (RowR n) :: acc, gs)
+                  | CNamed n => (Piece (RowN n) :: acc, gs)
+                  | CModProj (m1, ms, x) => (Piece (RowM (m1, ms, x)) :: acc, gs)
+                  | _ => (Unknown cAll :: acc, gs)
+            end
+    in
+        decomposeRow (c, ([], []))
+    end
+
+and assert env denv (c1, c2) =
+    let
+        val (ps1, gs1) = decomposeRow (env, denv) c1
+        val (ps2, gs2) = decomposeRow (env, denv) c2
+
+        val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
+        val ps1 = unUnknown ps1
+        val ps2 = unUnknown ps2
+
+        (*val () = print "APieces1:\n"
+        val () = app pp ps1
+        val () = print "APieces2:\n"
+        val () = app pp ps2*)
+
+        fun assertPiece ps (p, denv) =
+            let
+                val pset = Option.getOpt (PM.find (denv, p), PS.empty)
+                val ps = case p of
+                             NameC _ => List.filter (fn NameC _ => false | _ => true) ps
+                           | _ => ps
+                val pset = PS.addList (pset, ps)
+            in
+                PM.insert (denv, p, pset)
+            end
+
+        val denv = foldl (assertPiece ps2) denv ps1
+    in
+        (foldl (assertPiece ps1) denv ps2, gs1 @ gs2)
+    end
+
+and prove env denv (c1, c2, loc) =
+    let
+        val (ps1, gs1) = decomposeRow (env, denv) c1
+        val (ps2, gs2) = decomposeRow (env, denv) c2
 
         val hasUnknown = List.exists (fn Unknown _ => true | _ => false)
         val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
     in
         if hasUnknown ps1 orelse hasUnknown ps2 then
-            [(c1, c2)]
+            [(loc, env, denv, c1, c2)]
         else
             let
                 val ps1 = unUnknown ps1
@@ -222,9 +238,26 @@
                                     if prove1 denv (p1, p2) then
                                         rem
                                     else
-                                        (pieceToRow (p1, loc), pieceToRow (p2, loc)) :: rem) rem ps2)
-                      [] ps1
+                                        (loc, env, denv, pieceToRow (p1, loc), pieceToRow (p2, loc)) :: rem) rem ps2)
+                      (gs1 @ gs2) ps1
             end
     end
 
+and hnormCon (env, denv) c =
+    let
+        val cAll as (c, loc) = ElabOps.hnormCon env c
+
+        fun doDisj (c1, c2, c) =
+            let
+                val (c, gs) = hnormCon (env, denv) c
+            in
+                (c, prove env denv (c1, c2, loc) @ gs)
+            end
+    in
+        case c of
+            CDisjoint cs => doDisj cs
+          | TDisjoint cs => doDisj cs
+          | _ => (cAll, [])
+    end
+
 end
--- a/src/elaborate.sml	Tue Jul 01 16:06:58 2008 -0400
+++ b/src/elaborate.sml	Thu Jul 03 11:04:25 2008 -0400
@@ -251,13 +251,13 @@
             val ku1 = kunif loc
             val ku2 = kunif loc
 
-            val denv' = D.assert env denv (c1', c2')
-            val (c', k, gs3) = elabCon (env, denv') c
+            val (denv', gs3) = D.assert env denv (c1', c2')
+            val (c', k, gs4) = elabCon (env, denv') c
         in
             checkKind env c1' k1 (L'.KRecord ku1, loc);
             checkKind env c2' k2 (L'.KRecord ku2, loc);
 
-            ((L'.TDisjoint (c1', c2', c'), loc), k, gs1 @ gs2 @ gs3)
+            ((L'.TDisjoint (c1', c2', c'), loc), k, gs1 @ gs2 @ gs3 @ gs4)
         end
       | L.TRecord c =>
         let
@@ -330,13 +330,13 @@
             val ku1 = kunif loc
             val ku2 = kunif loc
 
-            val denv' = D.assert env denv (c1', c2')
-            val (c', k, gs3) = elabCon (env, denv') c
+            val (denv', gs3) = D.assert env denv (c1', c2')
+            val (c', k, gs4) = elabCon (env, denv') c
         in
             checkKind env c1' k1 (L'.KRecord ku1, loc);
             checkKind env c2' k2 (L'.KRecord ku2, loc);
 
-            ((L'.CDisjoint (c1', c2', c'), loc), k, gs1 @ gs2 @ gs3)
+            ((L'.CDisjoint (c1', c2', c'), loc), k, gs1 @ gs2 @ gs3 @ gs4)
         end
 
       | L.CName s =>
@@ -369,8 +369,7 @@
                                            let
                                                val r2 = (L'.CRecord (k, [xc']), loc)
                                            in
-                                               map (fn cs => (loc, env, denv, cs)) (D.prove env denv (r1, r2, loc))
-                                               @ ds
+                                               D.prove env denv (r1, r2, loc) @ ds
                                            end)
                                  ds rest
                     in
@@ -389,7 +388,7 @@
             checkKind env c1' k1 k;
             checkKind env c2' k2 k;
             ((L'.CConcat (c1', c2'), loc), k,
-             map (fn cs => (loc, env, denv, cs)) (D.prove env denv (c1', c2', loc)) @ gs1 @ gs2)
+             D.prove env denv (c1', c2', loc) @ gs1 @ gs2)
         end
       | L.CFold =>
         let
@@ -545,23 +544,7 @@
       | L'.CError => kerror
       | L'.CUnif (_, k, _, _) => k
 
-fun hnormCon (env, denv) c =
-    let
-        val cAll as (c, loc) = ElabOps.hnormCon env c
-
-        fun doDisj (c1, c2, c) =
-            let
-                val (c, gs) = hnormCon (env, denv) c
-            in
-                (c,
-                 map (fn cs => (loc, env, denv, cs)) (D.prove env denv (c1, c2, loc)) @ gs)
-            end
-    in
-        case c of
-            L'.CDisjoint cs => doDisj cs
-          | L'.TDisjoint cs => doDisj cs
-          | _ => (cAll, [])
-    end
+val hnormCon = D.hnormCon
 
 fun unifyRecordCons (env, denv) (c1, c2) =
     let
@@ -703,9 +686,9 @@
     let
         val (c1, gs1) = hnormCon (env, denv) c1
         val (c2, gs2) = hnormCon (env, denv) c2
+        val gs3 = unifyCons'' (env, denv) c1 c2
     in
-        unifyCons'' (env, denv) c1 c2;
-        gs1 @ gs2
+        gs1 @ gs2 @ gs3
     end
     
 and unifyCons'' (env, denv) (c1All as (c1, _)) (c2All as (c2, _)) =
@@ -1040,13 +1023,13 @@
             val ku1 = kunif loc
             val ku2 = kunif loc
 
-            val denv' = D.assert env denv (c1', c2')
-            val (e', t, gs3) = elabExp (env, denv') e
+            val (denv', gs3) = D.assert env denv (c1', c2')
+            val (e', t, gs4) = elabExp (env, denv') e
         in
             checkKind env c1' k1 (L'.KRecord ku1, loc);
             checkKind env c2' k2 (L'.KRecord ku2, loc);
 
-            (e', (L'.TDisjoint (c1', c2', t), loc), gs1 @ gs2 @ gs3)
+            (e', (L'.TDisjoint (c1', c2', t), loc), gs1 @ gs2 @ gs3 @ gs4)
         end
 
       | L.ERecord xes =>
@@ -1075,8 +1058,7 @@
                                                val xc' = (x', t')
                                                val r2 = (L'.CRecord (k, [xc']), loc)
                                            in
-                                               map (fn cs => (loc, env, denv, cs)) (D.prove env denv (r1, r2, loc))
-                                               @ gs
+                                               D.prove env denv (r1, r2, loc) @ gs
                                            end)
                                  gs rest
                     in
@@ -1100,9 +1082,7 @@
             val gs3 =
                 checkCon (env, denv) e' et
                          (L'.TRecord (L'.CConcat (first, rest), loc), loc)
-            val gs4 =
-                map (fn cs => (loc, env, denv, cs))
-                (D.prove env denv (first, rest, loc))
+            val gs4 = D.prove env denv (first, rest, loc)
         in
             ((L'.EField (e', c', {field = ft, rest = rest}), loc), ft, gs1 @ gs2 @ gs3 @ gs4)
         end
@@ -1287,12 +1267,12 @@
             val (c1', k1, gs1) = elabCon (env, denv) c1
             val (c2', k2, gs2) = elabCon (env, denv) c2
 
-            val denv = D.assert env denv (c1', c2')
+            val (denv, gs3) = D.assert env denv (c1', c2')
         in
             checkKind env c1' k1 (L'.KRecord (kunif loc), loc);
             checkKind env c2' k2 (L'.KRecord (kunif loc), loc);
 
-            ([(L'.SgiConstraint (c1', c2'), loc)], (env, denv, gs1 @ gs2))
+            ([(L'.SgiConstraint (c1', c2'), loc)], (env, denv, gs1 @ gs2 @ gs3))
         end
 
 and elabSgn (env, denv) (sgn, loc) =
@@ -1484,7 +1464,16 @@
             val denv = case cso of
                            NONE => (strError env (UnboundStr (loc, str));
                                     denv)
-                         | SOME cs => foldl (fn ((c1, c2), denv) => D.assert env denv (c1, c2)) denv cs
+                         | SOME cs => foldl (fn ((c1, c2), denv) =>
+                                                let
+                                                    val (denv, gs) = D.assert env denv (c1, c2)
+                                                in
+                                                    case gs of
+                                                        [] => ()
+                                                      | _ => raise Fail "dopenConstraints: Sub-constraints remain";
+
+                                                    denv
+                                                end) denv cs
         in
             denv
         end
@@ -1500,7 +1489,10 @@
 
 fun sgiBindsD (env, denv) (sgi, _) =
     case sgi of
-        L'.SgiConstraint (c1, c2) => D.assert env denv (c1, c2)
+        L'.SgiConstraint (c1, c2) =>
+        (case D.assert env denv (c1, c2) of
+             (denv, []) => denv
+           | _ => raise Fail "sgiBindsD: Sub-constraints remain")
       | _ => denv
 
 fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) =
@@ -1634,7 +1626,15 @@
                                  case sgi1 of
                                      L'.SgiConstraint (c1, d1) =>
                                      if consEq (env, denv) (c1, c2) andalso consEq (env, denv) (d1, d2) then
-                                         SOME (env, D.assert env denv (c2, d2))
+                                         let
+                                             val (denv, gs) = D.assert env denv (c2, d2)
+                                         in
+                                             case gs of
+                                                 [] => ()
+                                               | _ => raise Fail "subSgn: Sub-constraints remain";
+
+                                             SOME (env, denv)
+                                         end
                                      else
                                          NONE
                                    | _ => NONE)
@@ -1793,14 +1793,14 @@
         let
             val (c1', k1, gs1) = elabCon (env, denv) c1
             val (c2', k2, gs2) = elabCon (env, denv) c2
-            val gs3 = map (fn cs => (loc, env, denv, cs)) (D.prove env denv (c1', c2', loc))
+            val gs3 = D.prove env denv (c1', c2', loc)
 
-            val denv' = D.assert env denv (c1', c2')
+            val (denv', gs4) = D.assert env denv (c1', c2')
         in
             checkKind env c1' k1 (L'.KRecord (kunif loc), loc);
             checkKind env c2' k2 (L'.KRecord (kunif loc), loc);
 
-            ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3))
+            ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3 @ gs4))
         end
 
       | L.DOpenConstraints (m, ms) =>
@@ -1982,13 +1982,15 @@
         if ErrorMsg.anyErrors () then
             ()
         else
-            app (fn (loc, env, denv, (c1, c2)) =>
+            app (fn (loc, env, denv, c1, c2) =>
                     case D.prove env denv (c1, c2, loc) of
                         [] => ()
                       | _ =>
                         (ErrorMsg.errorAt loc "Couldn't prove field name disjointness";
                          eprefaces' [("Con 1", p_con env c1),
-                                     ("Con 2", p_con env c2)])) gs;
+                                     ("Con 2", p_con env c2),
+                                     ("Hnormed 1", p_con env (ElabOps.hnormCon env c1)),
+                                     ("Hnormed 2", p_con env (ElabOps.hnormCon env c2))])) gs;
 
         (L'.DFfiStr ("Basis", basis_n, sgn), ErrorMsg.dummySpan) :: ds @ file
     end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/cfold_disj.lac	Thu Jul 03 11:04:25 2008 -0400
@@ -0,0 +1,5 @@
+con id = fold (fn nm => fn t :: Type => fn acc => [nm] ~ acc => [nm = t] ++ acc) []
+
+con idT = id [D = int, E = float]
+
+val idV = fn x : $idT => x.E