diff src/iflow.sml @ 1202:509a6d7b60fb

Iflow tested with positive and negative cases
author Adam Chlipala <adamc@hcoop.net>
date Sun, 04 Apr 2010 16:17:23 -0400
parents 8793fd48968c
children a75c66dd2aeb
line wrap: on
line diff
--- a/src/iflow.sml	Sun Apr 04 15:17:57 2010 -0400
+++ b/src/iflow.sml	Sun Apr 04 16:17:23 2010 -0400
@@ -29,6 +29,8 @@
 
 open Mono
 
+structure IM = IntBinaryMap
+
 structure SS = BinarySetFn(struct
                            type ord_key = string
                            val compare = String.compare
@@ -75,7 +77,7 @@
        | Select of int * lvar * lvar * prop * exp
 
 local
-    val count = ref 0
+    val count = ref 1
 in
 fun newLvar () =
     let
@@ -116,17 +118,6 @@
         sub
     end
 
-fun eq' (e1, e2) =
-    case (e1, e2) of
-        (Const p1, Const p2) => Prim.equal (p1, p2)
-      | (Var n1, Var n2) => n1 = n2
-      | (Lvar n1, Lvar n2) => n1 = n2
-      | (Func (f1, es1), Func (f2, es2)) => f1 = f2 andalso ListPair.allEq eq' (es1, es2)
-      | (Recd xes1, Recd xes2) => ListPair.allEq (fn ((x1, e1), (x2, e2)) => x1 = x2 andalso eq' (e1, e2)) (xes1, xes2)
-      | (Proj (e1, s1), Proj (e2, s2)) => eq' (e1, e2) andalso s1 = s2
-      | (Finish, Finish) => true
-      | _ => false
-
 fun isKnown e =
     case e of
         Const _ => true
@@ -174,14 +165,12 @@
                  Proj (e', s))
       | Finish => Finish
 
-fun eq (e1, e2) = eq' (simplify e1, simplify e2)
-
-fun decomp or =
+fun decomp fals or =
     let
         fun decomp p k =
             case p of
                 True => k []
-              | False => true
+              | False => fals
               | Unknown => k []
               | And (p1, p2) => 
                 decomp p1 (fn ps1 =>
@@ -195,22 +184,155 @@
         decomp
     end
 
-fun rimp ((r1 : reln, es1), (r2, es2)) =
-    r1 = r2 andalso ListPair.allEq eq (es1, es2)
+val unif = ref (IM.empty : exp IM.map)
 
-fun imp (p1, p2) =
-    decomp (fn (e1, e2) => e1 andalso e2 ()) p1
-    (fn hyps =>
-        decomp (fn (e1, e2) => e1 orelse e2 ()) p2
-        (fn goals =>
-            List.all (fn r2 => List.exists (fn r1 => rimp (r1, r2)) hyps) goals))
+fun lvarIn lv =
+    let
+        fun lvi e =
+            case e of
+                Const _ => false
+              | Var _ => false
+              | Lvar lv' => lv' = lv
+              | Func (_, es) => List.exists lvi es
+              | Recd xes => List.exists (lvi o #2) xes
+              | Proj (e, _) => lvi e
+              | Finish => false
+    in
+        lvi
+    end
+
+fun eq' (e1, e2) =
+    case (e1, e2) of
+        (Const p1, Const p2) => Prim.equal (p1, p2)
+      | (Var n1, Var n2) => n1 = n2
+
+      | (Lvar n1, _) =>
+        (case IM.find (!unif, n1) of
+             SOME e1 => eq' (e1, e2)
+           | NONE =>
+             case e2 of
+                 Lvar n2 =>
+                 (case IM.find (!unif, n2) of
+                      SOME e2 => eq' (e1, e2)
+                    | NONE => n1 = n2
+                              orelse (unif := IM.insert (!unif, n1, e2);
+                                      true))
+               | _ =>
+                 if lvarIn n1 e2 then
+                     false
+                 else
+                     (unif := IM.insert (!unif, n1, e2);
+                      true))
+
+      | (_, Lvar n2) =>
+        (case IM.find (!unif, n2) of
+             SOME e2 => eq' (e1, e2)
+           | NONE =>
+             if lvarIn n2 e1 then
+                 false
+             else
+                 (unif := IM.insert (!unif, n2, e1);
+                  true))
+                                       
+      | (Func (f1, es1), Func (f2, es2)) => f1 = f2 andalso ListPair.allEq eq' (es1, es2)
+      | (Recd xes1, Recd xes2) => ListPair.allEq (fn ((x1, e1), (x2, e2)) => x1 = x2 andalso eq' (e1, e2)) (xes1, xes2)
+      | (Proj (e1, s1), Proj (e2, s2)) => eq' (e1, e2) andalso s1 = s2
+      | (Finish, Finish) => true
+      | _ => false
+
+fun eq (e1, e2) =
+    let
+        val saved = !unif
+    in
+        if eq' (simplify e1, simplify e2) then
+            true
+        else
+            (unif := saved;
+             false)
+    end
+
+exception Imply of prop * prop
+
+fun rimp ((r1, es1), (r2, es2)) =
+    case (r1, r2) of
+        (Sql r1', Sql r2') =>
+        r1' = r2' andalso
+        (case (es1, es2) of
+             ([Recd xes1], [Recd xes2]) =>
+             let
+                 val saved = !unif
+             in
+                 (*print ("Go: " ^ r1' ^ "\n");*)
+                 (*raise Imply (Reln (r1, es1), Reln (r2, es2));*)
+                 if List.all (fn (f, e2) =>
+                                 List.exists (fn (f', e1) =>
+                                                 f' = f andalso eq (e1, e2)) xes1) xes2 then
+                     true
+                 else
+                     (unif := saved;
+                      false)
+             end
+           | _ => false)
+      | (Eq, Eq) =>
+        (case (es1, es2) of
+             ([x1, y1], [x2, y2]) =>
+             let
+                 val saved = !unif
+             in
+                 if eq (x1, x2) andalso eq (y1, y2) then
+                     true
+                 else
+                     (unif := saved;
+                      (*raise Imply (Reln (Eq, es1), Reln (Eq, es2));*)
+                      eq (x1, y2) andalso eq (y1, x2))
+             end
+           | _ => false)
+      | _ => false
+
+fun imply (p1, p2) =
+    (unif := IM.empty;
+     (*raise (Imply (p1, p2));*)
+     decomp true (fn (e1, e2) => e1 andalso e2 ()) p1
+            (fn hyps =>
+                decomp false (fn (e1, e2) => e1 orelse e2 ()) p2
+                (fn goals =>
+                    let
+                        fun gls goals onFail =
+                            case goals of
+                                [] => true
+                              | g :: goals =>
+                                let
+                                    fun hps hyps =
+                                        case hyps of
+                                            [] => onFail ()
+                                          | h :: hyps =>
+                                            let
+                                                val saved = !unif
+                                            in
+                                                if rimp (h, g) then
+                                                    let
+                                                        val changed = IM.numItems (!unif) = IM.numItems saved
+                                                    in
+                                                        gls goals (fn () => (unif := saved;
+                                                                             changed andalso hps hyps))
+                                                    end
+                                                else
+                                                    hps hyps
+                                            end
+                                in
+                                    hps hyps
+                                end
+                    in
+                        gls goals (fn () => false)
+                    end)))
+
 
 fun patCon pc =
     case pc of
         PConVar n => "C" ^ Int.toString n
       | PConFfi {mod = m, datatyp = d, con = c, ...} => m ^ "." ^ d ^ "." ^ c
 
-exception Summaries of (string * exp * prop * (exp * prop) list) list
+
 
 datatype chunk =
          String of string
@@ -226,8 +348,8 @@
 
 fun always v chs = SOME (v, chs)
 
-fun parse p chs =
-    case p chs of
+fun parse p s =
+    case p (chunkify s) of
         SOME (v, []) => SOME v
       | _ => NONE
 
@@ -325,21 +447,33 @@
 val query = wrap (follow select from)
             (fn (fs, ts) => {Select = fs, From = ts})
 
-fun queryProp rv e =
-    case parse query (chunkify e) of
+fun queryProp rv oe e =
+    case parse query e of
         NONE => Unknown
       | SOME r =>
-        foldl (fn ((t, v), p) =>
-                  And (p,
-                       Reln (Sql t,
-                             [Recd (foldl (fn ((v', f), fs) =>
-                                              if v' = v then
-                                                  (f, Proj (Proj (Lvar rv, v), f)) :: fs
-                                             else
-                                                 fs) [] (#Select r))])))
-              True (#From r)
+        let
+            val p =
+                foldl (fn ((t, v), p) =>
+                          And (p,
+                               Reln (Sql t,
+                                     [Recd (foldl (fn ((v', f), fs) =>
+                                                      if v' = v then
+                                                          (f, Proj (Proj (Lvar rv, v), f)) :: fs
+                                                      else
+                                                          fs) [] (#Select r))])))
+                      True (#From r)
+        in
+            case oe of
+                NONE => p
+              | SOME oe =>
+                And (p,
+                     foldl (fn ((v, f), p) =>
+                               Or (p,
+                                   Reln (Eq, [oe, Proj (Proj (Lvar rv, v), f)])))
+                           False (#Select r))
+        end
 
-fun evalExp env (e : Mono.exp, st as (nv, p, sent)) =
+fun evalExp env (e as (_, loc), st as (nv, p, sent)) =
     let
         fun default () =
             (Var nv, (nv+1, p, sent))
@@ -348,7 +482,7 @@
             if isKnown e then
                 sent
             else
-                (e, p) :: sent
+                (loc, e, p) :: sent
     in
         case #1 e of
             EPrim p => (Const p, st)
@@ -476,7 +610,7 @@
 
                 val r' = newLvar ()
                 val acc' = newLvar ()
-                val qp = queryProp r' q
+                val qp = queryProp r' NONE q
 
                 val doSubExp = subExp (r, r') o subExp (acc, acc')
                 val doSubProp = subProp (r, r') o subProp (acc, acc')
@@ -485,7 +619,9 @@
                 val p = And (p, qp)
                 val p = Select (r, r', acc', p, doSubExp b)
             in
-                (Var r, (#1 st + 1, And (#2 st, p), map (fn (e, p) => (doSubExp e, And (qp, doSubProp p))) (#3 st')))
+                (Var r, (#1 st + 1, And (#2 st, p), map (fn (loc, e, p) => (loc,
+                                                                            doSubExp e,
+                                                                            And (qp, doSubProp p))) (#3 st')))
             end
           | EDml _ => default ()
           | ENextval _ => default ()
@@ -504,7 +640,7 @@
 
 fun check file =
     let
-        fun decl ((d, _), summaries) =
+        fun decl ((d, _), (vals, pols)) =
             case d of
                 DVal (x, _, _, e, _) =>
                 let
@@ -513,15 +649,31 @@
                             EAbs (_, _, _, e) => deAbs (e, Var nv :: env, nv + 1)
                           | _ => (e, env, nv)
 
-                    val (e, env, nv) = deAbs (e, [], 0)
+                    val (e, env, nv) = deAbs (e, [], 1)
 
                     val (e, (_, p, sent)) = evalExp env (e, (nv, True, []))
                 in
-                    (x, e, p, sent) :: summaries
+                    ((x, e, p, sent) :: vals, pols)
                 end
-              | _ => summaries
+
+              | DPolicy (PolQuery e) => (vals, queryProp 0 (SOME (Var 0)) e :: pols)
+
+              | _ => (vals, pols)
+
+        val () = unif := IM.empty
+
+        val (vals, pols) = foldl decl ([], []) file
     in
-        raise Summaries (foldl decl [] file)
+        app (fn (name, _, _, sent) =>
+                app (fn (loc, e, p) =>
+                        let
+                            val p = And (p, Reln (Eq, [Var 0, e]))
+                        in
+                            if List.exists (fn pol => imply (p, pol)) pols then
+                                ()
+                            else
+                                ErrorMsg.errorAt loc "The information flow policy may be violated here."
+                        end) sent) vals
     end
 
 end