diff src/iflow.sml @ 1244:1eedc9086e6c

Use key information in more places, and catch cases where one key completion depends on another having happened already
author Adam Chlipala <adamc@hcoop.net>
date Sun, 18 Apr 2010 13:00:36 -0400
parents e754dc92c47c
children 5c2555dfce8f
line wrap: on
line diff
--- a/src/iflow.sml	Sun Apr 18 10:56:39 2010 -0400
+++ b/src/iflow.sml	Sun Apr 18 13:00:36 2010 -0400
@@ -248,6 +248,18 @@
     val p_repOf : database -> exp Print.printer
 end = struct
 
+local
+    val count = ref 0
+in
+fun nodeId () =
+    let
+        val n = !count
+    in
+        count := n + 1;
+        n
+    end
+end
+
 exception Contradiction
 exception Undetermined
 
@@ -256,7 +268,8 @@
                            val compare = Prim.compare
                            end)
 
-datatype node = Node of {Rep : node ref option ref,
+datatype node = Node of {Id : int,
+                         Rep : node ref option ref,
                          Cons : node ref SM.map ref,
                          Variety : variety,
                          Known : bool ref}
@@ -300,8 +313,8 @@
     case !(#Rep (unNode n)) of
         SOME n => p_rep n
       | NONE =>
-        box [(*string (Int.toString (Unsafe.cast n) ^ ":"),
-             space,*)
+        box [string (Int.toString (#Id (unNode n)) ^ ":"),
+             space,
              case #Variety (unNode n) of
                  Nothing => string "?"
                | Dt0 s => string ("Dt0(" ^ s ^ ")")
@@ -372,7 +385,8 @@
                                 SOME r => repOf r
                               | NONE =>
                                 let
-                                    val r = ref (Node {Rep = ref NONE,
+                                    val r = ref (Node {Id = nodeId (),
+                                                       Rep = ref NONE,
                                                        Cons = ref SM.empty,
                                                        Variety = Prim p,
                                                        Known = ref true})
@@ -384,7 +398,8 @@
                               SOME r => repOf r
                             | NONE =>
                               let
-                                  val r = ref (Node {Rep = ref NONE,
+                                  val r = ref (Node {Id = nodeId (),
+                                                     Rep = ref NONE,
                                                      Cons = ref SM.empty,
                                                      Variety = Nothing,
                                                      Known = ref false})
@@ -397,7 +412,8 @@
                                             SOME r => repOf r
                                           | NONE =>
                                             let
-                                                val r = ref (Node {Rep = ref NONE,
+                                                val r = ref (Node {Id = nodeId (),
+                                                                   Rep = ref NONE,
                                                                    Cons = ref SM.empty,
                                                                    Variety = Dt0 f,
                                                                    Known = ref true})
@@ -414,7 +430,8 @@
                         SOME r => repOf r
                       | NONE =>
                         let
-                            val r' = ref (Node {Rep = ref NONE,
+                            val r' = ref (Node {Id = nodeId (),
+                                                Rep = ref NONE,
                                                 Cons = ref SM.empty,
                                                 Variety = Dt1 (f, r),
                                                 Known = ref (!(#Known (unNode r)))})
@@ -436,12 +453,14 @@
                       | Nothing =>
                         let
                             val cons = ref SM.empty
-                            val r' = ref (Node {Rep = ref NONE,
+                            val r' = ref (Node {Id = nodeId (),
+                                                Rep = ref NONE,
                                                 Cons = cons,
                                                 Variety = Nothing,
                                                 Known = ref (!(#Known (unNode r)))})
 
-                            val r'' = ref (Node {Rep = ref NONE,
+                            val r'' = ref (Node {Id = nodeId (),
+                                                 Rep = ref NONE,
                                                  Cons = #Cons (unNode r),
                                                  Variety = Dt1 (f, r'),
                                                  Known = #Known (unNode r)})
@@ -460,7 +479,8 @@
                     case List.find (fn (x : string * representative list, _) => x = (f, rs)) (!(#Funcs db)) of
                         NONE =>
                         let
-                            val r = ref (Node {Rep = ref NONE,
+                            val r = ref (Node {Id = nodeId (),
+                                               Rep = ref NONE,
                                                Cons = ref SM.empty,
                                                Variety = Nothing,
                                                Known = ref false})
@@ -487,7 +507,8 @@
                         let
                             val xes = foldl SM.insert' SM.empty xes
 
-                            val r' = ref (Node {Rep = ref NONE,
+                            val r' = ref (Node {Id = nodeId (),
+                                                Rep = ref NONE,
                                                 Cons = ref SM.empty,
                                                 Variety = Recrd (ref xes, true),
                                                 Known = ref false})
@@ -505,7 +526,8 @@
                         (case SM.find (!xes, f) of
                              SOME r => repOf r
                            | NONE => let
-                                  val r = ref (Node {Rep = ref NONE,
+                                  val r = ref (Node {Id = nodeId (),
+                                                     Rep = ref NONE,
                                                      Cons = ref SM.empty,
                                                      Variety = Nothing,
                                                      Known = ref (!(#Known (unNode r)))})
@@ -515,12 +537,14 @@
                               end)
                       | Nothing =>
                         let
-                            val r' = ref (Node {Rep = ref NONE,
+                            val r' = ref (Node {Id = nodeId (),
+                                                Rep = ref NONE,
                                                 Cons = ref SM.empty,
                                                 Variety = Nothing,
                                                 Known = ref (!(#Known (unNode r)))})
                                      
-                            val r'' = ref (Node {Rep = ref NONE,
+                            val r'' = ref (Node {Id = nodeId (),
+                                                 Rep = ref NONE,
                                                  Cons = #Cons (unNode r),
                                                  Variety = Recrd (ref (SM.insert (SM.empty, f, r')), false),
                                                  Known = #Known (unNode r)})
@@ -635,7 +659,8 @@
                              SOME r' => markEq (r, r')
                            | NONE =>
                              let
-                                 val r' = ref (Node {Rep = ref NONE,
+                                 val r' = ref (Node {Id = nodeId (),
+                                                     Rep = ref NONE,
                                                      Cons = ref SM.empty,
                                                      Variety = Dt0 f,
                                                      Known = ref false})
@@ -656,12 +681,14 @@
                                             raise Contradiction
                       | Nothing =>
                         let
-                            val r'' = ref (Node {Rep = ref NONE,
+                            val r'' = ref (Node {Id = nodeId (),
+                                                 Rep = ref NONE,
                                                  Cons = ref SM.empty,
                                                  Variety = Nothing,
                                                  Known = ref (!(#Known (unNode r)))})
 
-                            val r' = ref (Node {Rep = ref NONE,
+                            val r' = ref (Node {Id = nodeId (),
+                                                Rep = ref NONE,
                                                 Cons = ref SM.empty,
                                                 Variety = Dt1 (f, r''),
                                                 Known = #Known (unNode r)})
@@ -744,65 +771,6 @@
 
 val tabs = ref (SM.empty : (string list * string list list) SM.map)
 
-fun ccOf hyps =
-    let
-        val cc = Cc.database ()
-        val () = app (fn a => Cc.assert (cc, a)) hyps
-
-        (* Take advantage of table key information *)
-        fun findKeys hyps =
-            case hyps of
-                [] => ()
-              | AReln (Sql tab, [r1]) :: hyps =>
-                (case SM.find (!tabs, tab) of
-                     NONE => findKeys hyps
-                   | SOME (_, []) => findKeys hyps
-                   | SOME (_, ks) =>
-                     let
-                         fun finder hyps =
-                             case hyps of
-                                 [] => ()
-                               | AReln (Sql tab', [r2]) :: hyps =>
-                                 (if tab' = tab andalso
-                                     List.exists (List.all (fn f =>
-                                                               let
-                                                                   val r =
-                                                                       Cc.check (cc,
-                                                                                 AReln (Eq, [Proj (r1, f),
-                                                                                             Proj (r2, f)]))
-                                                               in
-                                                                   (*Print.prefaces "Fs"
-                                                                                    [("tab",
-                                                                                      Print.PD.string tab),
-                                                                                     ("r1",
-                                                                                      p_exp (Proj (r1, f))),
-                                                                                     ("r2",
-                                                                                      p_exp (Proj (r2, f))),
-                                                                                     ("r",
-                                                                                      Print.PD.string
-                                                                                          (Bool.toString r))];*)
-                                                                   r
-                                                               end)) ks then
-                                      ((*Print.prefaces "Key match" [("tab", Print.PD.string tab),
-                                                                     ("r1", p_exp r1),
-                                                                     ("r2", p_exp r2),
-                                                                     ("rp1", Cc.p_repOf cc r1),
-                                                                     ("rp2", Cc.p_repOf cc r2)];*)
-                                       Cc.assert (cc, AReln (Eq, [r1, r2])))
-                                  else
-                                      ();
-                                  finder hyps)
-                               | _ :: hyps => finder hyps
-                     in
-                         finder hyps;
-                         findKeys hyps
-                     end)
-              | _ :: hyps => findKeys hyps
-    in
-        findKeys hyps;
-        cc
-    end
-
 fun patCon pc =
     case pc of
         PConVar n => "C" ^ Int.toString n
@@ -1212,27 +1180,105 @@
 
 val hnames = ref 1
 
-type hyps = int * atom list
+type hyps = int * atom list * bool ref
 
 val db = Cc.database ()
-val path = ref ([] : (hyps * check) option ref list)
-val hyps = ref (0, [] : atom list)
+val path = ref ([] : ((int * atom list) * check) option ref list)
+val hyps = ref (0, [] : atom list, ref false)
 val nvar = ref 0
 
-fun setHyps (h as (n', hs)) =
+fun setHyps (n', hs) =
     let
-        val (n, _) = !hyps
+        val (n, _, _) = !hyps
     in
         if n' = n then
             ()
         else
-            (hyps := h;
+            (hyps := (n', hs, ref false);
              Cc.clear db;
              app (fn a => Cc.assert (db, a)) hs)
     end    
 
-type stashed = int * (hyps * check) option ref list * (int * atom list)
-fun stash () = (!nvar, !path, !hyps)
+fun useKeys () =
+    let
+        val changed = ref false
+
+        fun findKeys (hyps, acc) =
+            case hyps of
+                [] => acc
+              | (a as AReln (Sql tab, [r1])) :: hyps =>
+                (case SM.find (!tabs, tab) of
+                     NONE => findKeys (hyps, a :: acc)
+                   | SOME (_, []) => findKeys (hyps, a :: acc)
+                   | SOME (_, ks) =>
+                     let
+                         fun finder (hyps, acc) =
+                             case hyps of
+                                 [] => acc
+                               | (a as AReln (Sql tab', [r2])) :: hyps =>
+                                 if tab' = tab andalso
+                                    List.exists (List.all (fn f =>
+                                                              let
+                                                                  val r =
+                                                                      Cc.check (db,
+                                                                                AReln (Eq, [Proj (r1, f),
+                                                                                            Proj (r2, f)]))
+                                                              in
+                                                                  (*Print.prefaces "Fs"
+                                                                                   [("tab",
+                                                                                     Print.PD.string tab),
+                                                                                    ("r1",
+                                                                                     p_exp (Proj (r1, f))),
+                                                                                    ("r2",
+                                                                                     p_exp (Proj (r2, f))),
+                                                                                    ("r",
+                                                                                     Print.PD.string
+                                                                                         (Bool.toString r))];*)
+                                                                  r
+                                                              end)) ks then
+                                     (changed := true;
+                                      Cc.assert (db, AReln (Eq, [r1, r2]));
+                                      finder (hyps, acc))
+                                 else
+                                     finder (hyps, a :: acc)
+                               | a :: hyps => finder (hyps, a :: acc)
+
+                         val hyps = finder (hyps, [])
+                     in
+                         findKeys (hyps, acc)
+                     end)
+              | a :: hyps => findKeys (hyps, a :: acc)
+
+        fun loop hs =
+            let
+                val hs = findKeys (hs, [])
+            in
+                if !changed then
+                    (changed := false;
+                     loop hs)
+                else
+                    ()
+            end
+
+        val (_, hs, _) = !hyps
+    in
+        (*print "findKeys\n";*)
+        loop hs
+    end
+
+fun complete () =
+    let
+        val (_, _, bf) = !hyps
+    in
+        if !bf then
+            ()
+        else
+            (bf := true;
+             useKeys ())
+    end
+
+type stashed = int * ((int * atom list) * check) option ref list * (int * atom list)
+fun stash () = (!nvar, !path, (#1 (!hyps), #2 (!hyps)))
 fun reinstate (nv, p, h) =
     (nvar := nv;
      path := p;
@@ -1249,14 +1295,14 @@
 fun assert ats =
     let
         val n = !hnames
-        val (_, hs) = !hyps
+        val (_, hs, _) = !hyps
     in
         hnames := n + 1;
-        hyps := (n, ats @ hs);
+        hyps := (n, ats @ hs, ref false);
         app (fn a => Cc.assert (db, a)) ats
     end
 
-fun addPath c = path := ref (SOME (!hyps, c)) :: !path
+fun addPath c = path := ref (SOME ((#1 (!hyps), #2 (!hyps)), c)) :: !path
 
 val sendable = ref ([] : (atom list * exp list) list)
 
@@ -1268,7 +1314,7 @@
               | AReln (Sql tab, [Lvar lv]) :: goals =>
                 let
                     val saved = stash ()
-                    val (_, hyps) = !hyps
+                    val (_, hyps, _) = !hyps
 
                     fun tryAll unifs hyps =
                         case hyps of
@@ -1282,70 +1328,14 @@
                     tryAll unifs hyps
                 end
               | (g as AReln (r, es)) :: goals =>
-                Cc.check (db, AReln (r, map (simplify unifs) es))
-                andalso checkGoals goals unifs
+                (complete ();
+                 Cc.check (db, AReln (r, map (simplify unifs) es))
+                 andalso checkGoals goals unifs)
               | ACond _ :: _ => false
     in
         checkGoals goals IM.empty
     end
 
-fun useKeys () =
-    let
-        fun findKeys hyps =
-            case hyps of
-                [] => ()
-              | AReln (Sql tab, [r1]) :: hyps =>
-                (case SM.find (!tabs, tab) of
-                     NONE => findKeys hyps
-                   | SOME (_, []) => findKeys hyps
-                   | SOME (_, ks) =>
-                     let
-                         fun finder hyps =
-                             case hyps of
-                                 [] => ()
-                               | AReln (Sql tab', [r2]) :: hyps =>
-                                 (if tab' = tab andalso
-                                     List.exists (List.all (fn f =>
-                                                               let
-                                                                   val r =
-                                                                       Cc.check (db,
-                                                                                 AReln (Eq, [Proj (r1, f),
-                                                                                             Proj (r2, f)]))
-                                                               in
-                                                                   (*Print.prefaces "Fs"
-                                                                                    [("tab",
-                                                                                      Print.PD.string tab),
-                                                                                     ("r1",
-                                                                                      p_exp (Proj (r1, f))),
-                                                                                     ("r2",
-                                                                                      p_exp (Proj (r2, f))),
-                                                                                     ("r",
-                                                                                      Print.PD.string
-                                                                                          (Bool.toString r))];*)
-                                                                   r
-                                                               end)) ks then
-                                      ((*Print.prefaces "Key match" [("tab", Print.PD.string tab),
-                                                                     ("r1", p_exp r1),
-                                                                     ("r2", p_exp r2),
-                                                                     ("rp1", Cc.p_repOf cc r1),
-                                                                     ("rp2", Cc.p_repOf cc r2)];*)
-                                       Cc.assert (db, AReln (Eq, [r1, r2])))
-                                  else
-                                      ();
-                                  finder hyps)
-                               | _ :: hyps => finder hyps
-                     in
-                         finder hyps;
-                         findKeys hyps
-                     end)
-              | _ :: hyps => findKeys hyps
-
-        val (_, hs) = !hyps
-    in
-        (*print "findKeys\n";*)
-        findKeys hs
-    end
-
 fun buildable uk (e, loc) =
     let
         fun doPols pols acc =
@@ -1358,23 +1348,23 @@
                 checkGoals goals (fn unifs => doPols pols (map (simplify unifs) es @ acc))
                 orelse doPols pols acc
     in
-        useKeys ();
         if doPols (!sendable) [] then
             ()
         else
             let
-                val (_, hs) = !hyps
+                val (_, hs, _) = !hyps
             in
                 ErrorMsg.errorAt loc "The information flow policy may be violated here.";
                 Print.prefaces "Situation" [("Hypotheses", Print.p_list p_atom hs),
-                                            ("User learns", p_exp e),
-                                            ("E-graph", Cc.p_database db)]
+                                            ("User learns", p_exp e)(*,
+                                            ("E-graph", Cc.p_database db)*)]
             end
     end       
 
 fun checkPaths () =
     let
-        val hs = !hyps
+        val (n, hs, _) = !hyps
+        val hs = (n, hs)
     in
         app (fn r =>
                 case !r of
@@ -1391,6 +1381,7 @@
                    sendable := v :: !sendable)
 
 fun send uk (e, loc) = ((*Print.preface ("Send", p_exp e);*)
+                        complete ();
                         checkPaths ();
                         if isKnown e then
                             ()
@@ -1401,6 +1392,7 @@
     let
         val pols = !pols
     in
+        complete ();
         if List.exists (fn goals =>
                            if checkGoals goals (fn _ => true) then
                                ((*Print.prefaces "Match" [("goals", Print.p_list p_atom goals),
@@ -1413,7 +1405,7 @@
             ()
         else
             let
-                val (_, hs) = !hyps
+                val (_, hs, _) = !hyps
             in
                 ErrorMsg.errorAt loc "The database update policy may be violated here.";
                 Print.preface ("Hypotheses", Print.p_list p_atom hs)
@@ -1434,7 +1426,7 @@
 
 fun reset () = (Cc.clear db;
                 path := [];
-                hyps := (0, []);
+                hyps := (0, [], ref false);
                 nvar := 0;
                 sendable := [];
                 insertable := [];
@@ -1444,15 +1436,15 @@
 fun havocReln r =
     let
         val n = !hnames
-        val (_, hs) = !hyps
+        val (_, hs, _) = !hyps
     in
         hnames := n + 1;
-        hyps := (n, List.filter (fn AReln (r', _) => r' <> r | _ => true) hs)
+        hyps := (n, List.filter (fn AReln (r', _) => r' <> r | _ => true) hs, ref false)
     end
 
 fun debug () =
     let
-        val (_, hs) = !hyps
+        val (_, hs, _) = !hyps
     in
         Print.preface ("Hyps", Print.p_list p_atom hs)
     end