diff src/iflow.sml @ 1208:b5a4c5407ae0

Checking known() correctly, according to a pair of examples
author Adam Chlipala <adamc@hcoop.net>
date Tue, 06 Apr 2010 10:39:15 -0400
parents ae3036773768
children 775357041e48
line wrap: on
line diff
--- a/src/iflow.sml	Tue Apr 06 09:51:36 2010 -0400
+++ b/src/iflow.sml	Tue Apr 06 10:39:15 2010 -0400
@@ -212,11 +212,20 @@
         Finish => true
       | _ => false
 
+val unif = ref (IM.empty : exp IM.map)
+
+fun reset () = unif := IM.empty
+fun save () = !unif
+fun restore x = unif := x
+
 fun simplify e =
     case e of
         Const _ => e
       | Var _ => e
-      | Lvar _ => e
+      | Lvar n =>
+        (case IM.find (!unif, n) of
+             NONE => e
+           | SOME e => simplify e)
       | Func (f, es) =>
         let
             val es = map simplify es
@@ -265,12 +274,6 @@
         decomp
     end
 
-val unif = ref (IM.empty : exp IM.map)
-
-fun reset () = unif := IM.empty
-fun save () = !unif
-fun restore x = unif := x
-
 fun lvarIn lv =
     let
         fun lvi e =
@@ -300,7 +303,7 @@
                  (case IM.find (!unif, n2) of
                       SOME e2 => eq' (e1, e2)
                     | NONE => n1 = n2
-                              orelse (unif := IM.insert (!unif, n1, e2);
+                              orelse (unif := IM.insert (!unif, n2, e1);
                                       true))
                | _ =>
                  if lvarIn n1 e2 then
@@ -338,7 +341,85 @@
 
 exception Imply of prop * prop
 
-fun rimp ((r1, es1), (r2, es2)) =
+val debug = ref false
+
+(* Congruence closure *)
+structure Cc :> sig
+    type t
+    val empty : t
+    val assert : t * exp * exp -> t
+    val query : t * exp * exp -> bool
+    val allPeers : t * exp -> exp list
+end = struct
+
+fun eq' (e1, e2) =
+    case (e1, e2) of
+        (Const p1, Const p2) => Prim.equal (p1, p2)
+      | (Var n1, Var n2) => n1 = n2
+      | (Lvar n1, Lvar n2) => n1 = n2
+      | (Func (f1, es1), Func (f2, es2)) => f1 = f2 andalso ListPair.allEq eq' (es1, es2)
+      | (Recd xes1, Recd xes2) => length xes1 = length xes2 andalso
+                                  List.all (fn (x2, e2) =>
+                                               List.exists (fn (x1, e1) => x1 = x2 andalso eq' (e1, e2)) xes2) xes1
+      | (Proj (e1, x1), Proj (e2, x2)) => eq' (e1, e2) andalso x1 = x2
+      | (Finish, Finish) => true
+      | _ => false
+
+fun eq (e1, e2) = eq' (simplify e1, simplify e2)
+
+type t = (exp * exp) list
+
+val empty = []
+
+fun lookup (t, e) =
+    case List.find (fn (e', _) => eq (e', e)) t of
+        NONE => e
+      | SOME (_, e2) => lookup (t, e2)
+
+fun assert (t, e1, e2) =
+    let
+        val r1 = lookup (t, e1)
+        val r2 = lookup (t, e2)
+    in
+        if eq (r1, r2) then
+            t
+        else
+            (r1, r2) :: t
+    end
+
+open Print
+
+fun query (t, e1, e2) =
+    (if !debug then
+         prefaces "CC query" [("e1", p_exp (simplify e1)),
+                              ("e2", p_exp (simplify e2)),
+                              ("t", p_list (fn (e1, e2) => box [p_exp (simplify e1),
+                                                                space,
+                                                                PD.string "->",
+                                                                space,
+                                                                p_exp (simplify e2)]) t)]
+     else
+         ();
+     eq (lookup (t, e1), lookup (t, e2)))
+
+fun allPeers (t, e) =
+    let
+        val r = lookup (t, e)
+    in
+        r :: List.mapPartial (fn (e1, e2) =>
+                                 let
+                                     val r' = lookup (t, e2)
+                                 in
+                                     if eq (r, r') then
+                                         SOME e1
+                                     else
+                                         NONE
+                                 end) t
+    end
+
+end
+
+fun rimp cc ((r1, es1), (r2, es2)) =
     case (r1, r2) of
         (Sql r1', Sql r2') =>
         r1' = r2' andalso
@@ -367,62 +448,81 @@
                      true
                  else
                      (restore saved;
-                      (*raise Imply (Reln (Eq, es1), Reln (Eq, es2));*)
-                      eq (x1, y2) andalso eq (y1, x2))
+                      if eq (x1, y2) andalso eq (y1, x2) then
+                          true
+                      else
+                          (restore saved;
+                           false))
              end
            | _ => false)
       | (Known, Known) =>
         (case (es1, es2) of
-             ([e1], [e2]) =>
+             ([Var v], [e2]) =>
              let
-                 fun walk e2 =
-                     eq (e1, e2) orelse
-                     case e2 of
-                         Proj (e2, _) => walk e2
+                 fun matches e =
+                     case e of
+                         Var v' => v' = v
+                       | Proj (e, _) => matches e
                        | _ => false
              in
-                 walk e2
+                 List.exists matches (Cc.allPeers (cc, e2))
              end
            | _ => false)
       | _ => false
 
 fun imply (p1, p2) =
-    (reset ();
-     (*raise (Imply (p1, p2));*)
-     decomp true (fn (e1, e2) => e1 andalso e2 ()) p1
-            (fn hyps =>
-                decomp false (fn (e1, e2) => e1 orelse e2 ()) p2
-                (fn goals =>
-                    let
-                        fun gls goals onFail =
-                            case goals of
-                                [] => true
-                              | g :: goals =>
-                                let
-                                    fun hps hyps =
-                                        case hyps of
-                                            [] => onFail ()
-                                          | h :: hyps =>
-                                            let
-                                                val saved = save ()
-                                            in
-                                                if rimp (h, g) then
-                                                    let
-                                                        val changed = IM.numItems (!unif) <> IM.numItems saved
-                                                    in
-                                                        gls goals (fn () => (restore saved;
-                                                                             changed andalso hps hyps))
-                                                    end
-                                                else
-                                                    hps hyps
-                                            end
-                                in
-                                    hps hyps
-                                end
-                    in
-                        gls goals (fn () => false)
-                    end)))
+    let
+        fun doOne doKnown =
+            decomp true (fn (e1, e2) => e1 andalso e2 ()) p1
+                   (fn hyps =>
+                       decomp false (fn (e1, e2) => e1 orelse e2 ()) p2
+                              (fn goals =>
+                                  let
+                                      val cc = foldl (fn (p, cc) =>
+                                                         case p of
+                                                             (Eq, [e1, e2]) => Cc.assert (cc, e1, e2)
+                                                           | _ => cc) Cc.empty hyps
 
+                                      fun gls goals onFail =
+                                          case goals of
+                                              [] => true
+                                            | g :: goals =>
+                                              case (doKnown, g) of
+                                                  (false, (Known, _)) => gls goals onFail
+                                                | _ =>
+                                                  let
+                                                      fun hps hyps =
+                                                          case hyps of
+                                                              [] => onFail ()
+                                                            | h :: hyps =>
+                                                              let
+                                                                  val saved = save ()
+                                                              in
+                                                                  if rimp cc (h, g) then
+                                                                      let
+                                                                          val changed = IM.numItems (!unif)
+                                                                                        <> IM.numItems saved
+                                                                      in
+                                                                          gls goals (fn () => (restore saved;
+                                                                                               changed andalso hps hyps))
+                                                                      end
+                                                                  else
+                                                                      hps hyps
+                                                              end
+                                                  in
+                                                      (case g of
+                                                           (Eq, [e1, e2]) => Cc.query (cc, e1, e2)
+                                                         | _ => false)
+                                                      orelse hps hyps
+                                                  end
+                                  in
+                                      gls goals (fn () => false)
+                                  end))
+    in
+        reset ();
+        doOne false;
+        doOne true
+    end
 
 fun patCon pc =
     case pc of
@@ -531,8 +631,6 @@
 fun ws p = wrap (follow (skip (fn ch => ch = #" "))
                         (follow p (skip (fn ch => ch = #" ")))) (#1 o #2)
 
-val debug = ref false
-
 fun log name p chs =
     (if !debug then
          case chs of
@@ -924,7 +1022,16 @@
                 let
                     val p = And (p, Reln (Eq, [Var 0, e]))
                 in
-                    if List.exists (fn pol => imply (p, pol)) pols then
+                    if List.exists (fn pol => if imply (p, pol) then
+                                                  (if !debug then
+                                                       Print.prefaces "Match"
+                                                       [("Hyp", p_prop p),
+                                                        ("Goal", p_prop pol)]
+                                                   else
+                                                       ();
+                                                   true)
+                                              else
+                                                  false) pols then
                         ()
                     else
                         (ErrorMsg.errorAt loc "The information flow policy may be violated here.";