diff src/disjoint.sml @ 207:cc68da3801bc

Non-star SELECT
author Adam Chlipala <adamc@hcoop.net>
date Thu, 14 Aug 2008 18:35:08 -0400
parents 94ef20a31550
children 326fb4686f60
line wrap: on
line diff
--- a/src/disjoint.sml	Thu Aug 14 15:27:35 2008 -0400
+++ b/src/disjoint.sml	Thu Aug 14 18:35:08 2008 -0400
@@ -30,7 +30,7 @@
 open Elab
 open ElabOps
 
-datatype piece =
+datatype piece_fst =
          NameC of string
        | NameR of int
        | NameN of int
@@ -39,6 +39,8 @@
        | RowN of int
        | RowM of int * string list * string
 
+type piece = piece_fst * int list
+
 fun p2s p =
     case p of
         NameC s => "NameC(" ^ s ^ ")"
@@ -55,20 +57,9 @@
 
 type ord_key = piece
 
-fun join (o1, o2) =
-    case o1 of
-        EQUAL => o2 ()
-      | v => v
+open Order
 
-fun joinL f (os1, os2) =
-    case (os1, os2) of
-        (nil, nil) => EQUAL
-      | (nil, _) => LESS
-      | (h1 :: t1, h2 :: t2) =>
-        join (f (h1, h2), fn () => joinL f (t1, t2))
-      | (_ :: _, nil) => GREATER
-
-fun compare (p1, p2) =
+fun compare' (p1, p2) =
     case (p1, p2) of
         (NameC s1, NameC s2) => String.compare (s1, s2)
       | (NameR n1, NameR n2) => Int.compare (n1, n2)
@@ -102,6 +93,10 @@
       | (RowN _, _) => LESS
       | (_, RowN _) => GREATER
 
+fun compare ((p1, ns1), (p2, ns2)) =
+    join (compare' (p1, p2),
+          fn () => joinL Int.compare (ns1, ns2))
+
 end
 
 structure PS = BinarySetFn(PK)
@@ -116,7 +111,7 @@
 fun nameToRow (c, loc) =
     (CRecord ((KUnit, loc), [((c, loc), (CUnit, loc))]), loc)
 
-fun pieceToRow (p, loc) =
+fun pieceToRow' (p, loc) =
     case p of
         NameC s => nameToRow (CName s, loc)
       | NameR n => nameToRow (CRel n, loc)
@@ -126,16 +121,21 @@
       | RowN n => (CNamed n, loc)
       | RowM (n, xs, x) => (CModProj (n, xs, x), loc)
 
+fun pieceToRow ((p, ns), loc) =
+    foldl (fn (n, c) => (CProj (c, n), loc)) (pieceToRow' (p, loc)) ns
+
 datatype piece' =
          Piece of piece
        | Unknown of con
 
-fun pieceEnter p =
+fun pieceEnter' p =
     case p of
         NameR n => NameR (n + 1)
       | RowR n => RowR (n + 1)
       | _ => p
 
+fun pieceEnter (p, n) = (pieceEnter' p, n)
+
 fun enter denv =
     PM.foldli (fn (p, pset, denv') =>
                   PM.insert (denv', pieceEnter p, PS.map pieceEnter pset))
@@ -143,7 +143,7 @@
 
 fun prove1 denv (p1, p2) =
     case (p1, p2) of
-        (NameC s1, NameC s2) => s1 <> s2
+        ((NameC s1, _), (NameC s2, _)) => s1 <> s2
       | _ =>
         case PM.find (denv, p1) of
             NONE => false
@@ -151,15 +151,29 @@
 
 fun decomposeRow (env, denv) c =
     let
+        fun decomposeProj c =
+            let
+                val (c, gs) = hnormCon (env, denv) c
+            in
+                case #1 c of
+                    CProj (c, n) =>
+                    let
+                        val (c', ns, gs') = decomposeProj c
+                    in
+                        (c', ns @ [n], gs @ gs')
+                    end
+                  | _ => (c, [], gs)
+            end
+
         fun decomposeName (c, (acc, gs)) =
             let
-                val (cAll as (c, _), gs') = hnormCon (env, denv) c
+                val (cAll as (c, _), ns, gs') = decomposeProj c
 
                 val acc = case c of
-                              CName s => Piece (NameC s) :: acc
-                            | CRel n => Piece (NameR n) :: acc
-                            | CNamed n => Piece (NameN n) :: acc
-                            | CModProj (m1, ms, x) => Piece (NameM (m1, ms, x)) :: acc
+                              CName s => Piece (NameC s, ns) :: acc
+                            | CRel n => Piece (NameR n, ns) :: acc
+                            | CNamed n => Piece (NameN n, ns) :: acc
+                            | CModProj (m1, ms, x) => Piece (NameM (m1, ms, x), ns) :: acc
                             | _ => Unknown cAll :: acc
             in
                 (acc, gs' @ gs)
@@ -167,15 +181,15 @@
 
         fun decomposeRow (c, (acc, gs)) =
             let
-                val (cAll as (c, _), gs') = hnormCon (env, denv) c
+                val (cAll as (c, _), ns, gs') = decomposeProj c
                 val gs = gs' @ gs
             in
                 case c of
                     CRecord (_, xcs) => foldl (fn ((x, _), acc_gs) => decomposeName (x, acc_gs)) (acc, gs) xcs
                   | CConcat (c1, c2) => decomposeRow (c1, decomposeRow (c2, (acc, gs)))
-                  | CRel n => (Piece (RowR n) :: acc, gs)
-                  | CNamed n => (Piece (RowN n) :: acc, gs)
-                  | CModProj (m1, ms, x) => (Piece (RowM (m1, ms, x)) :: acc, gs)
+                  | CRel n => (Piece (RowR n, ns) :: acc, gs)
+                  | CNamed n => (Piece (RowN n, ns) :: acc, gs)
+                  | CModProj (m1, ms, x) => (Piece (RowM (m1, ms, x), ns) :: acc, gs)
                   | _ => (Unknown cAll :: acc, gs)
             end
     in
@@ -200,7 +214,7 @@
             let
                 val pset = Option.getOpt (PM.find (denv, p), PS.empty)
                 val ps = case p of
-                             NameC _ => List.filter (fn NameC _ => false | _ => true) ps
+                             (NameC _, _) => List.filter (fn (NameC _, _) => false | _ => true) ps
                            | _ => ps
                 val pset = PS.addList (pset, ps)
             in