diff src/disjoint.sml @ 90:94ef20a31550

Fancier head normalization pushed inside of Disjoint
author Adam Chlipala <adamc@hcoop.net>
date Thu, 03 Jul 2008 11:04:25 -0400
parents 7bab29834cd6
children cc68da3801bc
line wrap: on
line diff
--- a/src/disjoint.sml	Tue Jul 01 16:06:58 2008 -0400
+++ b/src/disjoint.sml	Thu Jul 03 11:04:25 2008 -0400
@@ -109,6 +109,8 @@
 
 type env = PS.set PM.map
 
+type goal = ErrorMsg.span * ElabEnv.env * env * Elab.con * Elab.con
+
 val empty = PM.empty
 
 fun nameToRow (c, loc) =
@@ -128,55 +130,6 @@
          Piece of piece
        | Unknown of con
 
-fun decomposeRow env c =
-    let
-        fun decomposeName (c, acc) =
-            case #1 (hnormCon env c) of
-                CName s => Piece (NameC s) :: acc
-              | CRel n => Piece (NameR n) :: acc
-              | CNamed n => Piece (NameN n) :: acc
-              | CModProj (m1, ms, x) => Piece (NameM (m1, ms, x)) :: acc
-              | _ => Unknown c :: acc
-
-        fun decomposeRow (c, acc) =
-            case #1 (hnormCon env c) of
-                CRecord (_, xcs) => foldl (fn ((x, _), acc) => decomposeName (x, acc)) acc xcs
-              | CConcat (c1, c2) => decomposeRow (c1, decomposeRow (c2, acc))
-              | CRel n => Piece (RowR n) :: acc
-              | CNamed n => Piece (RowN n) :: acc
-              | CModProj (m1, ms, x) => Piece (RowM (m1, ms, x)) :: acc
-              | _ => Unknown c :: acc
-    in
-        decomposeRow (c, [])
-    end
-
-fun assert env denv (c1, c2) =
-    let
-        val ps1 = decomposeRow env c1
-        val ps2 = decomposeRow env c2
-
-        val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
-        val ps1 = unUnknown ps1
-        val ps2 = unUnknown ps2
-
-        (*val () = print "APieces1:\n"
-        val () = app pp ps1
-        val () = print "APieces2:\n"
-        val () = app pp ps2*)
-
-        fun assertPiece ps (p, denv) =
-            let
-                val pset = Option.getOpt (PM.find (denv, p), PS.empty)
-                val pset = PS.addList (pset, ps)
-            in
-                PM.insert (denv, p, pset)
-            end
-
-        val denv = foldl (assertPiece ps2) denv ps1
-    in
-        foldl (assertPiece ps1) denv ps2
-    end
-
 fun pieceEnter p =
     case p of
         NameR n => NameR (n + 1)
@@ -196,16 +149,79 @@
             NONE => false
           | SOME pset => PS.member (pset, p2)
 
-fun prove env denv (c1, c2, loc) =
+fun decomposeRow (env, denv) c =
     let
-        val ps1 = decomposeRow env c1
-        val ps2 = decomposeRow env c2
+        fun decomposeName (c, (acc, gs)) =
+            let
+                val (cAll as (c, _), gs') = hnormCon (env, denv) c
+
+                val acc = case c of
+                              CName s => Piece (NameC s) :: acc
+                            | CRel n => Piece (NameR n) :: acc
+                            | CNamed n => Piece (NameN n) :: acc
+                            | CModProj (m1, ms, x) => Piece (NameM (m1, ms, x)) :: acc
+                            | _ => Unknown cAll :: acc
+            in
+                (acc, gs' @ gs)
+            end
+
+        fun decomposeRow (c, (acc, gs)) =
+            let
+                val (cAll as (c, _), gs') = hnormCon (env, denv) c
+                val gs = gs' @ gs
+            in
+                case c of
+                    CRecord (_, xcs) => foldl (fn ((x, _), acc_gs) => decomposeName (x, acc_gs)) (acc, gs) xcs
+                  | CConcat (c1, c2) => decomposeRow (c1, decomposeRow (c2, (acc, gs)))
+                  | CRel n => (Piece (RowR n) :: acc, gs)
+                  | CNamed n => (Piece (RowN n) :: acc, gs)
+                  | CModProj (m1, ms, x) => (Piece (RowM (m1, ms, x)) :: acc, gs)
+                  | _ => (Unknown cAll :: acc, gs)
+            end
+    in
+        decomposeRow (c, ([], []))
+    end
+
+and assert env denv (c1, c2) =
+    let
+        val (ps1, gs1) = decomposeRow (env, denv) c1
+        val (ps2, gs2) = decomposeRow (env, denv) c2
+
+        val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
+        val ps1 = unUnknown ps1
+        val ps2 = unUnknown ps2
+
+        (*val () = print "APieces1:\n"
+        val () = app pp ps1
+        val () = print "APieces2:\n"
+        val () = app pp ps2*)
+
+        fun assertPiece ps (p, denv) =
+            let
+                val pset = Option.getOpt (PM.find (denv, p), PS.empty)
+                val ps = case p of
+                             NameC _ => List.filter (fn NameC _ => false | _ => true) ps
+                           | _ => ps
+                val pset = PS.addList (pset, ps)
+            in
+                PM.insert (denv, p, pset)
+            end
+
+        val denv = foldl (assertPiece ps2) denv ps1
+    in
+        (foldl (assertPiece ps1) denv ps2, gs1 @ gs2)
+    end
+
+and prove env denv (c1, c2, loc) =
+    let
+        val (ps1, gs1) = decomposeRow (env, denv) c1
+        val (ps2, gs2) = decomposeRow (env, denv) c2
 
         val hasUnknown = List.exists (fn Unknown _ => true | _ => false)
         val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
     in
         if hasUnknown ps1 orelse hasUnknown ps2 then
-            [(c1, c2)]
+            [(loc, env, denv, c1, c2)]
         else
             let
                 val ps1 = unUnknown ps1
@@ -222,9 +238,26 @@
                                     if prove1 denv (p1, p2) then
                                         rem
                                     else
-                                        (pieceToRow (p1, loc), pieceToRow (p2, loc)) :: rem) rem ps2)
-                      [] ps1
+                                        (loc, env, denv, pieceToRow (p1, loc), pieceToRow (p2, loc)) :: rem) rem ps2)
+                      (gs1 @ gs2) ps1
             end
     end
 
+and hnormCon (env, denv) c =
+    let
+        val cAll as (c, loc) = ElabOps.hnormCon env c
+
+        fun doDisj (c1, c2, c) =
+            let
+                val (c, gs) = hnormCon (env, denv) c
+            in
+                (c, prove env denv (c1, c2, loc) @ gs)
+            end
+    in
+        case c of
+            CDisjoint cs => doDisj cs
+          | TDisjoint cs => doDisj cs
+          | _ => (cAll, [])
+    end
+
 end