diff src/disjoint.sml @ 88:7bab29834cd6

Constraints in modules
author Adam Chlipala <adamc@hcoop.net>
date Tue, 01 Jul 2008 15:58:02 -0400
parents e86370850c30
children 94ef20a31550
line wrap: on
line diff
--- a/src/disjoint.sml	Tue Jul 01 13:23:46 2008 -0400
+++ b/src/disjoint.sml	Tue Jul 01 15:58:02 2008 -0400
@@ -30,70 +30,86 @@
 open Elab
 open ElabOps
 
-structure SS = BinarySetFn(struct
-                           type ord_key = string
-                           val compare = String.compare
-                           end)
-
-structure IS = IntBinarySet
-structure IM = IntBinaryMap
-
-type name_ineqs = {
-     namesC : SS.set,
-     namesR : IS.set,
-     namesN : IS.set
-}
-
-val name_default = {
-    namesC = SS.empty,
-    namesR = IS.empty,
-    namesN = IS.empty
-}
-
-type row_ineqs = {
-     namesC : SS.set,
-     namesR : IS.set,
-     namesN : IS.set,
-     rowsR : IS.set,
-     rowsN : IS.set
-}
-
-val row_default = {
-    namesC = SS.empty,
-    namesR = IS.empty,
-    namesN = IS.empty,
-    rowsR = IS.empty,
-    rowsN = IS.empty
-}
-
-fun nameToRow_ineqs {namesC, namesR, namesN} =
-    {namesC = namesC,
-     namesR = namesR,
-     namesN = namesN,
-     rowsR = IS.empty,
-     rowsN = IS.empty}
-
-type env = {
-     namesR : name_ineqs IM.map,
-     namesN : name_ineqs IM.map,
-     rowsR : row_ineqs IM.map,
-     rowsN : row_ineqs IM.map
-}
-
-val empty = {
-    namesR = IM.empty,
-    namesN = IM.empty,
-    rowsR = IM.empty,
-    rowsN = IM.empty
-}
-
 datatype piece =
          NameC of string
        | NameR of int
        | NameN of int
+       | NameM of int * string list * string
        | RowR of int
        | RowN of int
-       | Unknown
+       | RowM of int * string list * string
+
+fun p2s p =
+    case p of
+        NameC s => "NameC(" ^ s ^ ")"
+      | NameR n => "NameR(" ^ Int.toString n ^ ")"
+      | NameN n => "NameN(" ^ Int.toString n ^ ")"
+      | NameM (n, _, s) => "NameR(" ^ Int.toString n ^ ", " ^ s ^ ")"
+      | RowR n => "RowR(" ^ Int.toString n ^ ")"
+      | RowN n => "RowN(" ^ Int.toString n ^ ")"
+      | RowM (n, _, s) => "RowR(" ^ Int.toString n ^ ", " ^ s ^ ")"
+
+fun pp p = print (p2s p ^ "\n")
+
+structure PK = struct
+
+type ord_key = piece
+
+fun join (o1, o2) =
+    case o1 of
+        EQUAL => o2 ()
+      | v => v
+
+fun joinL f (os1, os2) =
+    case (os1, os2) of
+        (nil, nil) => EQUAL
+      | (nil, _) => LESS
+      | (h1 :: t1, h2 :: t2) =>
+        join (f (h1, h2), fn () => joinL f (t1, t2))
+      | (_ :: _, nil) => GREATER
+
+fun compare (p1, p2) =
+    case (p1, p2) of
+        (NameC s1, NameC s2) => String.compare (s1, s2)
+      | (NameR n1, NameR n2) => Int.compare (n1, n2)
+      | (NameN n1, NameN n2) => Int.compare (n1, n2)
+      | (NameM (n1, ss1, s1), NameM (n2, ss2, s2)) =>
+        join (Int.compare (n1, n2),
+           fn () => join (String.compare (s1, s2), fn () =>
+                                                      joinL String.compare (ss1, ss2)))
+      | (RowR n1, RowR n2) => Int.compare (n1, n2)
+      | (RowN n1, RowN n2) => Int.compare (n1, n2)
+      | (RowM (n1, ss1, s1), RowM (n2, ss2, s2)) =>
+        join (Int.compare (n1, n2),
+           fn () => join (String.compare (s1, s2), fn () =>
+                                                      joinL String.compare (ss1, ss2)))
+
+      | (NameC _, _) => LESS
+      | (_, NameC _) => GREATER
+
+      | (NameR _, _) => LESS
+      | (_, NameR _) => GREATER
+
+      | (NameN _, _) => LESS
+      | (_, NameN _) => GREATER
+
+      | (NameM _, _) => LESS
+      | (_, NameM _) => GREATER
+
+      | (RowR _, _) => LESS
+      | (_, RowR _) => GREATER
+
+      | (RowN _, _) => LESS
+      | (_, RowN _) => GREATER
+
+end
+
+structure PS = BinarySetFn(PK)
+structure PM = BinaryMapFn(PK)
+
+type env = PS.set PM.map
+
+val empty = PM.empty
 
 fun nameToRow (c, loc) =
     (CRecord ((KUnit, loc), [((c, loc), (CUnit, loc))]), loc)
@@ -103,190 +119,112 @@
         NameC s => nameToRow (CName s, loc)
       | NameR n => nameToRow (CRel n, loc)
       | NameN n => nameToRow (CNamed n, loc)
+      | NameM (n, xs, x) => nameToRow (CModProj (n, xs, x), loc)
       | RowR n => (CRel n, loc)
-      | RowN n => (CRel n, loc)
-      | Unknown => raise Fail "Unknown to row"
+      | RowN n => (CNamed n, loc)
+      | RowM (n, xs, x) => (CModProj (n, xs, x), loc)
+
+datatype piece' =
+         Piece of piece
+       | Unknown of con
 
 fun decomposeRow env c =
     let
         fun decomposeName (c, acc) =
             case #1 (hnormCon env c) of
-                CName s => NameC s :: acc
-              | CRel n => NameR n :: acc
-              | CNamed n => NameN n :: acc
-              | _ => (print "Unknown name\n"; Unknown :: acc)
-                     
+                CName s => Piece (NameC s) :: acc
+              | CRel n => Piece (NameR n) :: acc
+              | CNamed n => Piece (NameN n) :: acc
+              | CModProj (m1, ms, x) => Piece (NameM (m1, ms, x)) :: acc
+              | _ => Unknown c :: acc
+
         fun decomposeRow (c, acc) =
             case #1 (hnormCon env c) of
                 CRecord (_, xcs) => foldl (fn ((x, _), acc) => decomposeName (x, acc)) acc xcs
               | CConcat (c1, c2) => decomposeRow (c1, decomposeRow (c2, acc))
-              | CRel n => RowR n :: acc
-              | CNamed n => RowN n :: acc
-              | _ => (print "Unknown row\n"; Unknown :: acc)
+              | CRel n => Piece (RowR n) :: acc
+              | CNamed n => Piece (RowN n) :: acc
+              | CModProj (m1, ms, x) => Piece (RowM (m1, ms, x)) :: acc
+              | _ => Unknown c :: acc
     in
         decomposeRow (c, [])
     end
 
-fun assertPiece_name (ps, ineqs : name_ineqs) =
-    {namesC = foldl (fn (p', namesC) =>
-                        case p' of
-                            NameC s => SS.add (namesC, s)
-                          | _ => namesC) (#namesC ineqs) ps,
-     namesR = foldl (fn (p', namesR) =>
-                        case p' of
-                            NameR n => IS.add (namesR, n)
-                          | _ => namesR) (#namesR ineqs) ps,
-     namesN = foldl (fn (p', namesN) =>
-                        case p' of
-                            NameN n => IS.add (namesN, n)
-                          | _ => namesN) (#namesN ineqs) ps}
-
-fun assertPiece_row (ps, ineqs : row_ineqs) =
-    {namesC = foldl (fn (p', namesC) =>
-                        case p' of
-                            NameC s => SS.add (namesC, s)
-                          | _ => namesC) (#namesC ineqs) ps,
-     namesR = foldl (fn (p', namesR) =>
-                        case p' of
-                            NameR n => IS.add (namesR, n)
-                          | _ => namesR) (#namesR ineqs) ps,
-     namesN = foldl (fn (p', namesN) =>
-                        case p' of
-                            NameN n => IS.add (namesN, n)
-                          | _ => namesN) (#namesN ineqs) ps,
-     rowsR = foldl (fn (p', rowsR) =>
-                        case p' of
-                            RowR n => IS.add (rowsR, n)
-                          | _ => rowsR) (#rowsR ineqs) ps,
-     rowsN = foldl (fn (p', rowsN) =>
-                        case p' of
-                            RowN n => IS.add (rowsN, n)
-                          | _ => rowsN) (#rowsN ineqs) ps}
-
-fun assertPiece ps (p, denv) =
-    case p of
-        Unknown => denv
-      | NameC _ => denv
-
-      | NameR n =>
-        let
-            val ineqs = Option.getOpt (IM.find (#namesR denv, n), name_default)
-            val ineqs = assertPiece_name (ps, ineqs)
-        in
-            {namesR = IM.insert (#namesR denv, n, ineqs),
-             namesN = #namesN denv,
-             rowsR = #rowsR denv,
-             rowsN = #rowsN denv}
-        end
-
-      | NameN n =>
-        let
-            val ineqs = Option.getOpt (IM.find (#namesN denv, n), name_default)
-            val ineqs = assertPiece_name (ps, ineqs)
-        in
-            {namesR = #namesR denv,
-             namesN = IM.insert (#namesN denv, n, ineqs),
-             rowsR = #rowsR denv,
-             rowsN = #rowsN denv}
-        end
-
-      | RowR n =>
-        let
-            val ineqs = Option.getOpt (IM.find (#rowsR denv, n), row_default)
-            val ineqs = assertPiece_row (ps, ineqs)
-        in
-            {namesR = #namesR denv,
-             namesN = #namesN denv,
-             rowsR = IM.insert (#rowsR denv, n, ineqs),
-             rowsN = #rowsN denv}
-        end
-
-      | RowN n =>
-        let
-            val ineqs = Option.getOpt (IM.find (#rowsN denv, n), row_default)
-            val ineqs = assertPiece_row (ps, ineqs)
-        in
-            {namesR = #namesR denv,
-             namesN = #namesN denv,
-             rowsR = #rowsR denv,
-             rowsN = IM.insert (#rowsN denv, n, ineqs)}
-        end
-
 fun assert env denv (c1, c2) =
     let
         val ps1 = decomposeRow env c1
         val ps2 = decomposeRow env c2
 
+        val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
+        val ps1 = unUnknown ps1
+        val ps2 = unUnknown ps2
+
+        (*val () = print "APieces1:\n"
+        val () = app pp ps1
+        val () = print "APieces2:\n"
+        val () = app pp ps2*)
+
+        fun assertPiece ps (p, denv) =
+            let
+                val pset = Option.getOpt (PM.find (denv, p), PS.empty)
+                val pset = PS.addList (pset, ps)
+            in
+                PM.insert (denv, p, pset)
+            end
+
         val denv = foldl (assertPiece ps2) denv ps1
     in
         foldl (assertPiece ps1) denv ps2
     end
 
-fun nameEnter {namesC, namesR, namesN} =
-    {namesC = namesC,
-     namesR = IS.map (fn n => n + 1) namesR,
-     namesN = namesN}
+fun pieceEnter p =
+    case p of
+        NameR n => NameR (n + 1)
+      | RowR n => RowR (n + 1)
+      | _ => p
 
-fun rowEnter {namesC, namesR, namesN, rowsR, rowsN} =
-    {namesC = namesC,
-     namesR = IS.map (fn n => n + 1) namesR,
-     namesN = namesN,
-     rowsR = IS.map (fn n => n + 1) rowsR,
-     rowsN = rowsN}
-
-fun enter {namesR, namesN, rowsR, rowsN} =
-    {namesR = IM.foldli (fn (n, ineqs, namesR) => IM.insert (namesR, n+1, nameEnter ineqs)) IM.empty namesR,
-     namesN = IM.map nameEnter namesN,
-     rowsR = IM.foldli (fn (n, ineqs, rowsR) => IM.insert (rowsR, n+1, rowEnter ineqs)) IM.empty rowsR,
-     rowsN = IM.map rowEnter rowsN}
-
-fun getIneqs (denv : env) p =
-    case p of
-        Unknown => raise Fail "getIneqs: Unknown"
-      | NameC _ => raise Fail "getIneqs: NameC"
-      | NameR n => nameToRow_ineqs (Option.getOpt (IM.find (#namesR denv, n), name_default))
-      | NameN n => nameToRow_ineqs (Option.getOpt (IM.find (#namesN denv, n), name_default))
-      | RowR n => Option.getOpt (IM.find (#rowsR denv, n), row_default)
-      | RowN n => Option.getOpt (IM.find (#rowsN denv, n), row_default)
-
-fun prove1' denv (p1, p2) =
-    let
-        val {namesC, namesR, namesN, rowsR, rowsN} = getIneqs denv p1
-    in
-        case p2 of
-            Unknown => raise Fail "prove1': Unknown"
-          | NameC s => SS.member (namesC, s)
-          | NameR n => IS.member (namesR, n)
-          | NameN n => IS.member (namesN, n)
-          | RowR n => IS.member (rowsR, n)
-          | RowN n => IS.member (rowsN, n)
-    end
+fun enter denv =
+    PM.foldli (fn (p, pset, denv') =>
+                  PM.insert (denv', pieceEnter p, PS.map pieceEnter pset))
+    PM.empty denv
 
 fun prove1 denv (p1, p2) =
     case (p1, p2) of
         (NameC s1, NameC s2) => s1 <> s2
-      | (NameC _, _) => prove1' denv (p2, p1)
-      | (_, RowR _) => prove1' denv (p2, p1)
-      | (_, RowN _) => prove1' denv (p2, p1)
-      | _ => prove1' denv (p1, p2)
+      | _ =>
+        case PM.find (denv, p1) of
+            NONE => false
+          | SOME pset => PS.member (pset, p2)
 
 fun prove env denv (c1, c2, loc) =
     let
         val ps1 = decomposeRow env c1
         val ps2 = decomposeRow env c2
 
-        val hasUnknown = List.exists (fn p => p = Unknown)
+        val hasUnknown = List.exists (fn Unknown _ => true | _ => false)
+        val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
     in
         if hasUnknown ps1 orelse hasUnknown ps2 then
             [(c1, c2)]
         else
-            foldl (fn (p1, rem) =>
-                      foldl (fn (p2, rem) =>
-                                if prove1 denv (p1, p2) then
-                                    rem
-                                else
-                                    (pieceToRow (p1, loc), pieceToRow (p2, loc)) :: rem) rem ps2)
-            [] ps1
+            let
+                val ps1 = unUnknown ps1
+                val ps2 = unUnknown ps2
+
+            in
+                (*print "Pieces1:\n";
+                app pp ps1;
+                print "Pieces2:\n";
+                app pp ps2;*)
+
+                foldl (fn (p1, rem) =>
+                          foldl (fn (p2, rem) =>
+                                    if prove1 denv (p1, p2) then
+                                        rem
+                                    else
+                                        (pieceToRow (p1, loc), pieceToRow (p2, loc)) :: rem) rem ps2)
+                      [] ps1
+            end
     end
 
 end