diff src/iflow.sml @ 1218:48d2ca496d2c

Path conditions, used to track implicit flows
author Adam Chlipala <adamc@hcoop.net>
date Sat, 10 Apr 2010 13:02:15 -0400
parents 4d206e603300
children 3224faec752d
line wrap: on
line diff
--- a/src/iflow.sml	Sat Apr 10 10:24:13 2010 -0400
+++ b/src/iflow.sml	Sat Apr 10 13:02:15 2010 -0400
@@ -380,18 +380,18 @@
 (* Congruence closure *)
 structure Cc :> sig
     type database
-    type representative
 
     exception Contradiction
     exception Undetermined
 
     val database : unit -> database
-    val representative : database * exp -> representative
 
     val assert : database * atom -> unit
     val check : database * atom -> bool
 
     val p_database : database Print.printer
+
+    val builtFrom : database * {Base : exp list, Derived : exp} -> bool
 end = struct
 
 exception Contradiction
@@ -420,7 +420,7 @@
 val finish = ref (Node {Rep = ref NONE,
                         Cons = ref SM.empty,
                         Variety = VFinish,
-                        Known = ref false})
+                        Known = ref true})
 
 type database = {Vars : representative IM.map ref,
                  Consts : representative CM.map ref,
@@ -470,7 +470,12 @@
                                                space,
                                                string "=",
                                                space,
-                                               p_rep n]) (IM.listItemsi (!(#Vars db)))]
+                                               p_rep n,
+                                               if !(#Known (unNode n)) then
+                                                   box [space,
+                                                        string "(known)"]
+                                               else
+                                                   box []]) (IM.listItemsi (!(#Vars db)))]
 
 fun repOf (n : representative) : representative =
     case !(#Rep (unNode n)) of
@@ -484,11 +489,15 @@
         end
 
 fun markKnown r =
-    (#Known (unNode r) := true;
-     case #Variety (unNode r) of
-         Dt1 (_, r) => markKnown r
-       | Recrd xes => SM.app markKnown (!xes)
-       | _ => ())
+    if !(#Known (unNode r)) then
+        ()
+    else
+        (#Known (unNode r) := true;
+         SM.app markKnown (!(#Cons (unNode r)));
+         case #Variety (unNode r) of
+             Dt1 (_, r) => markKnown r
+           | Recrd xes => SM.app markKnown (!xes)
+           | _ => ())
 
 fun representative (db : database, e) =
     let
@@ -529,7 +538,7 @@
                                                 val r = ref (Node {Rep = ref NONE,
                                                                    Cons = ref SM.empty,
                                                                    Variety = Dt0 f,
-                                                                   Known = ref false})
+                                                                   Known = ref true})
                                             in
                                                 #Con0s db := SM.insert (!(#Con0s db), f, r);
                                                 r
@@ -747,24 +756,23 @@
                                  unif (!xes2, !xes1)
                              end
                            | (VFinish, VFinish) => ()
-                           | (Nothing, _) =>
-                             (#Rep (unNode r1) := SOME r2;
-                              if !(#Known (unNode r1)) andalso not (!(#Known (unNode r2))) then
-                                  markKnown r2
-                              else
-                                  ();
-                              #Cons (unNode r2) := SM.unionWith #1 (!(#Cons (unNode r2)), !(#Cons (unNode r1)));
-                              compactFuncs ())
-                           | (_, Nothing) =>
-                             (#Rep (unNode r2) := SOME r1;
-                              if !(#Known (unNode r2)) andalso not (!(#Known (unNode r1))) then
-                                  markKnown r1
-                              else
-                                  ();
-                              #Cons (unNode r1) := SM.unionWith #1 (!(#Cons (unNode r1)), !(#Cons (unNode r2)));
-                              compactFuncs ())
+                           | (Nothing, _) => mergeNodes (r1, r2)
+                           | (_, Nothing) => mergeNodes (r2, r1)
                            | _ => raise Contradiction
 
+                and mergeNodes (r1, r2) =
+                    (#Rep (unNode r1) := SOME r2;
+                     if !(#Known (unNode r1)) then
+                         markKnown r2
+                     else
+                         ();
+                     if !(#Known (unNode r2)) then
+                         markKnown r1
+                     else
+                         ();
+                     #Cons (unNode r2) := SM.unionWith #1 (!(#Cons (unNode r2)), !(#Cons (unNode r1)));
+                     compactFuncs ())
+
                 and compactFuncs () =
                     let
                         fun loop funcs =
@@ -815,6 +823,27 @@
             end
           | _ => false
 
+fun builtFrom (db, {Base = bs, Derived = d}) =
+    let
+        val bs = map (fn b => representative (db, b)) bs
+
+        fun loop d =
+            let
+                val d = repOf d
+            in
+                List.exists (fn b => repOf b = d) bs
+                orelse case #Variety (unNode d) of
+                           Dt0 _ => true
+                         | Dt1 (_, d) => loop d
+                         | Prim _ => true
+                         | Recrd xes => List.all loop (SM.listItems (!xes))
+                         | VFinish => true
+                         | Nothing => false
+            end
+    in
+        loop (representative (db, d))
+    end
+
 end
 
 fun decomp fals or =
@@ -836,67 +865,66 @@
         decomp
     end
 
-fun imply (p1, p2) =
-    decomp true (fn (e1, e2) => e1 andalso e2 ()) p1
-           (fn hyps =>
-               decomp false (fn (e1, e2) => e1 orelse e2 ()) p2
-                      (fn goals =>
-                          let
-                              fun gls goals onFail acc =
-                                  case goals of
-                                      [] =>
-                                      (let
-                                           val cc = Cc.database ()
-                                           val () = app (fn a => Cc.assert (cc, a)) hyps
-                                       in
-                                           (List.all (fn a =>
-                                                         if Cc.check (cc, a) then
-                                                             true
-                                                         else
-                                                             ((*Print.prefaces "Can't prove"
-                                                                             [("a", p_atom a),
-                                                                              ("hyps", Print.p_list p_atom hyps),
-                                                                              ("db", Cc.p_database cc)];*)
-                                                              false)) acc)
-                                           handle Cc.Contradiction => false
-                                       end handle Cc.Undetermined => false)
-                                      orelse onFail ()
-                                    | (g as AReln (Sql gf, [ge])) :: goals =>
-                                      let
-                                          fun hps hyps =
-                                              case hyps of
-                                                  [] => gls goals onFail (g :: acc)
-                                                | (h as AReln (Sql hf, [he])) :: hyps =>
-                                                  if gf = hf then
-                                                      let
-                                                          val saved = save ()
-                                                      in
-                                                          if eq (ge, he) then
-                                                              let
-                                                                  val changed = IM.numItems (!unif)
-                                                                                <> IM.numItems saved
-                                                              in
-                                                                  gls goals (fn () => (restore saved;
-                                                                                       changed
-                                                                                       andalso hps hyps))
-                                                                      acc
-                                                              end
-                                                          else
-                                                              hps hyps
-                                                      end
-                                                  else
-                                                      hps hyps
-                                                | _ :: hyps => hps hyps 
-                                      in
-                                          hps hyps
-                                      end
-                                    | g :: goals => gls goals onFail (g :: acc)
-                          in
-                              reset ();
-                              (*Print.prefaces "Big go" [("hyps", Print.p_list p_atom hyps),
-                                                       ("goals", Print.p_list p_atom goals)];*)
-                              gls goals (fn () => false) []
-                          end handle Cc.Contradiction => true))
+fun imply (hyps, goals, outs) =
+    let
+        fun gls goals onFail acc =
+            case goals of
+                [] =>
+                (let
+                     val cc = Cc.database ()
+                     val () = app (fn a => Cc.assert (cc, a)) hyps
+                 in
+                     (List.all (fn a =>
+                                   if Cc.check (cc, a) then
+                                       true
+                                   else
+                                       ((*Print.prefaces "Can't prove"
+                                                       [("a", p_atom a),
+                                                        ("hyps", Print.p_list p_atom hyps),
+                                                        ("db", Cc.p_database cc)];*)
+                                        false)) acc
+                      (*andalso (Print.preface ("Finding", Cc.p_database cc); true)*)
+                      andalso Cc.builtFrom (cc, {Derived = Var 0,
+                                                 Base = outs}))
+                     handle Cc.Contradiction => false
+                 end handle Cc.Undetermined => false)
+                orelse onFail ()
+              | (g as AReln (Sql gf, [ge])) :: goals =>
+                let
+                    fun hps hyps =
+                        case hyps of
+                            [] => gls goals onFail (g :: acc)
+                          | (h as AReln (Sql hf, [he])) :: hyps =>
+                            if gf = hf then
+                                let
+                                    val saved = save ()
+                                in
+                                    if eq (ge, he) then
+                                        let
+                                            val changed = IM.numItems (!unif)
+                                                          <> IM.numItems saved
+                                        in
+                                            gls goals (fn () => (restore saved;
+                                                                 changed
+                                                                 andalso hps hyps))
+                                                acc
+                                        end
+                                    else
+                                        hps hyps
+                                end
+                            else
+                                hps hyps
+                          | _ :: hyps => hps hyps 
+                in
+                    hps hyps
+                end
+              | g :: goals => gls goals onFail (g :: acc)
+    in
+        reset ();
+        (*Print.prefaces "Big go" [("hyps", Print.p_list p_atom hyps),
+                                   ("goals", Print.p_list p_atom goals)];*)
+        gls goals (fn () => false) []
+    end handle Cc.Contradiction => true
 
 fun patCon pc =
     case pc of
@@ -1204,7 +1232,7 @@
         end
 
 datatype queryMode =
-         SomeCol of exp
+         SomeCol
        | AllCols of exp
 
 exception Default
@@ -1213,7 +1241,7 @@
     let
         fun default () = (print ("Warning: Information flow checker can't parse SQL query at "
                                  ^ ErrorMsg.spanToString (#2 e) ^ "\n");
-                          (rvN, Unknown, Unknown, []))
+                          (rvN, Unknown, Unknown, [], []))
     in
         case parse query e of
             NONE => default ()
@@ -1281,57 +1309,66 @@
                               | _ => p
 
                 fun normal () =
-                    (And (p, case oe of
-                                 SomeCol oe =>
-                                 foldl (fn (si, p) =>
-                                           let
-                                               val p' = case si of
-                                                            SqField (v, f) => Reln (Eq, [oe, Proj (rvOf v, f)])
-                                                          | SqExp (e, f) =>
-                                                            case expIn e of
-                                                                inr _ => Unknown
-                                                              | inl e => Reln (Eq, [oe, e])
-                                           in
-                                               Or (p, p')
-                                           end)
-                                       False (#Select r)
-                               | AllCols oe =>
-                                 foldl (fn (si, p) =>
-                                           let
-                                               val p' = case si of
-                                                            SqField (v, f) => Reln (Eq, [Proj (Proj (oe, v), f),
-                                                                                         Proj (rvOf v, f)])
-                                                          | SqExp (e, f) =>
-                                                            case expIn e of
-                                                                inr p => Cond (Proj (oe, f), p)
-                                                              | inl e => Reln (Eq, [Proj (oe, f), e])
-                                           in
+                    case oe of
+                        SomeCol =>
+                        (rvN, p, True,
+                         List.mapPartial (fn si =>
+                                             case si of
+                                                 SqField (v, f) => SOME (Proj (rvOf v, f))
+                                               | SqExp (e, f) =>
+                                                 case expIn e of
+                                                     inr _ => NONE
+                                                   | inl e => SOME e) (#Select r))
+                      | AllCols oe =>
+                        (rvN, And (p, foldl (fn (si, p) =>
+                                                let
+                                                    val p' = case si of
+                                                                 SqField (v, f) => Reln (Eq, [Proj (Proj (oe, v), f),
+                                                                                              Proj (rvOf v, f)])
+                                                               | SqExp (e, f) =>
+                                                                 case expIn e of
+                                                                     inr p => Cond (Proj (oe, f), p)
+                                                                   | inl e => Reln (Eq, [Proj (oe, f), e])
+                                                in
                                                And (p, p')
-                                           end)
-                                       True (#Select r)),
-                     True)
+                                                end)
+                                            True (#Select r)),
+                         True, [])
 
-                val (p, wp) =
+                val (rvN, p, wp, outs) =
                     case #Select r of
                         [SqExp (Binop (Exps bo, Count, SqConst (Prim.Int 0)), f)] =>
                         (case bo (Const (Prim.Int 1), Const (Prim.Int 2)) of
                              Reln (Gt, [Const (Prim.Int 1), Const (Prim.Int 2)]) =>
-                             let
-                                 val oe = case oe of
-                                              SomeCol oe => oe
-                                            | AllCols oe => Proj (oe, f)
-                             in
-                                 (Or (Reln (Eq, [oe, Func (DtCon0 "Basis.bool.False", [])]),
-                                      And (Reln (Eq, [oe, Func (DtCon0 "Basis.bool.True", [])]),
-                                           p)),
-                                  Reln (Eq, [oe, Func (DtCon0 "Basis.bool.True", [])]))
-                             end
+                             (case oe of
+                                  SomeCol =>
+                                  let
+                                      val (rvN, oe) = rv rvN
+                                  in
+                                      (rvN,
+                                       Or (Reln (Eq, [oe, Func (DtCon0 "Basis.bool.False", [])]),
+                                           And (Reln (Eq, [oe, Func (DtCon0 "Basis.bool.True", [])]),
+                                                p)),
+                                       Reln (Eq, [oe, Func (DtCon0 "Basis.bool.True", [])]),
+                                       [oe])
+                                  end
+                                | AllCols oe =>
+                                  let
+                                      val oe = Proj (oe, f)
+                                  in
+                                      (rvN,
+                                       Or (Reln (Eq, [oe, Func (DtCon0 "Basis.bool.False", [])]),
+                                           And (Reln (Eq, [oe, Func (DtCon0 "Basis.bool.True", [])]),
+                                                p)),
+                                       Reln (Eq, [oe, Func (DtCon0 "Basis.bool.True", [])]),
+                                       [])
+                                  end)
                            | _ => normal ())
                       | _ => normal ()
             in
                 (rvN, p, wp, case #Where r of
                                  NONE => []
-                               | SOME e => map (fn (v, f) => Proj (rvOf v, f)) (usedFields e))
+                               | SOME e => map (fn (v, f) => Proj (rvOf v, f)) (usedFields e), outs)
             end
             handle Default => default ()
     end
@@ -1388,6 +1425,10 @@
         rr
     end
 
+datatype cflow = Case | Where
+datatype flow = Data | Control of cflow
+type check = ErrorMsg.span * exp * prop
+
 structure St :> sig
     type t
     val create : {Var : int,
@@ -1399,22 +1440,21 @@
     val ambient : t -> prop
     val setAmbient : t * prop -> t
 
-    type check = ErrorMsg.span * exp * prop
+    val paths : t -> (check * cflow) list
+    val addPath : t * (check * cflow) -> t
+    val addPaths : t * (check * cflow) list -> t
+    val clearPaths : t -> t
+    val setPaths : t * (check * cflow) list -> t
 
-    val path : t -> check list
-    val addPath : t * check -> t
-
-    val sent : t -> check list
-    val addSent : t * check -> t
-    val setSent : t * check list -> t
+    val sent : t -> (check * flow) list
+    val addSent : t * (check * flow) -> t
+    val setSent : t * (check * flow) list -> t
 end = struct
 
-type check = ErrorMsg.span * exp * prop
-
 type t = {Var : int,
           Ambient : prop,
-          Path : check list,
-          Sent : check list}
+          Path : (check * cflow) list,
+          Sent : (check * flow) list}
 
 fun create {Var = v, Ambient = p} = {Var = v,
                                      Ambient = p,
@@ -1433,11 +1473,23 @@
                              Path = #Path t,
                              Sent = #Sent t}
 
-fun path (t : t) = #Path t
+fun paths (t : t) = #Path t
 fun addPath (t : t, c) = {Var = #Var t,
                           Ambient = #Ambient t,
                           Path = c :: #Path t,
                           Sent = #Sent t}
+fun addPaths (t : t, cs) = {Var = #Var t,
+                            Ambient = #Ambient t,
+                            Path = cs @ #Path t,
+                            Sent = #Sent t}
+fun clearPaths (t : t) = {Var = #Var t,
+                          Ambient = #Ambient t,
+                          Path = [],
+                          Sent = #Sent t}
+fun setPaths (t : t, cs) = {Var = #Var t,
+                            Ambient = #Ambient t,
+                            Path = cs,
+                            Sent = #Sent t}
 
 fun sent (t : t) = #Sent t
 fun addSent (t : t, c) = {Var = #Var t,
@@ -1461,10 +1513,16 @@
             end
 
         fun addSent (p, e, st) =
-            if isKnown e then
-                st
-            else
-                St.addSent (st, (loc, e, p))
+            let
+                val st = if isKnown e then
+                             st
+                         else
+                             St.addSent (st, ((loc, e, p), Data))
+
+                val st = foldl (fn ((c, fl), st) => St.addSent (st, (c, Control fl))) st (St.paths st)
+            in
+                St.clearPaths st
+            end
     in
         case #1 e of
             EPrim p => (Const p, st)
@@ -1542,38 +1600,31 @@
             in
                 (Proj (e, s), st)
             end
-          | ECase (e, pes, _) =>
+          | ECase (e, pes, {result = res, ...}) =>
             let
                 val (e, st) = evalExp env (e, st)
                 val (st, r) = St.nextVar st
                 val orig = St.ambient st
+                val origPaths = St.paths st
 
-                val st = foldl (fn ((pt, pe), st) =>
-                                   let
-                                       val (env, pp) = evalPat env e pt
-                                       val (pe, st') = evalExp env (pe, St.setAmbient (st, And (orig, pp)))
-                                       (*val () = Print.prefaces "Case" [("loc", Print.PD.string
-                                                                                   (ErrorMsg.spanToString (#2 pt))),
-                                                                       ("env", Print.p_list p_exp env),
-                                                                       ("sent", Print.p_list_sep Print.PD.newline
-                                                                                (fn (loc, e, p) =>
-                                                                                    Print.box [Print.PD.string
-                                                                                               (ErrorMsg.spanToString loc),
-                                                                                               Print.PD.string ":",
-                                                                                               Print.space,
-                                                                                               p_exp e,
-                                                                                               Print.space,
-                                                                                               Print.PD.string "in",
-                                                                                               Print.space,
-                                                                                               p_prop p])
-                                                                                (List.take (#3 st', length (#3 st')
-                                                                                                    - length (#3 st))))]*)
-                                                       
-                                       val this = And (removeRedundant orig (St.ambient st'),
-                                                       Reln (Eq, [Var r, pe]))
-                                   in
-                                       St.setAmbient (st', Or (St.ambient st, this))
-                                   end) (St.setAmbient (st, False)) pes
+                val st = St.addPath (st, ((loc, e, orig), Case))
+
+                val (st, paths) =
+                    foldl (fn ((pt, pe), (st, paths)) =>
+                              let
+                                  val (env, pp) = evalPat env e pt
+                                  val (pe, st') = evalExp env (pe, St.setAmbient (st, And (orig, pp)))
+                                                  
+                                  val this = And (removeRedundant orig (St.ambient st'),
+                                                  Reln (Eq, [Var r, pe]))
+                              in
+                                  (St.setPaths (St.setAmbient (st', Or (St.ambient st, this)), origPaths),
+                                   St.paths st' @ paths)
+                              end) (St.setAmbient (st, False), []) pes
+
+                val st = case #1 res of
+                             TRecord [] => St.setPaths (st, origPaths)
+                           | _ => St.setPaths (st, paths)
             in
                 (Var r, St.setAmbient (st, And (orig, St.ambient st)))
             end
@@ -1633,7 +1684,7 @@
 
                 val (b, st') = evalExp (Var acc :: Var r :: env) (b, st')
 
-                val (st', qp, qwp, used) =
+                val (st', qp, qwp, used, _) =
                     queryProp env
                               st' (fn st' =>
                                       let
@@ -1662,12 +1713,12 @@
                                         (St.setAmbient (st, p), Var out)
                                     end
 
-                val sent = map (fn (loc, e, p) => (loc, e, And (qp, p))) (St.sent st')
+                val sent = map (fn ((loc, e, p), fl) => ((loc, e, And (qp, p)), fl)) (St.sent st')
 
                 val p' = And (p', qwp)
-                val sent = map (fn e => (loc, e, p')) used @ sent
+                val paths = map (fn e => ((loc, e, p'), Where)) used
             in
-                (res, St.setSent (st, sent))
+                (res, St.addPaths (St.setSent (st, sent), paths))
             end
           | EDml _ => default ()
           | ENextval _ => default ()
@@ -1728,8 +1779,12 @@
                     (St.sent st @ vals, pols)
                 end
 
-              | DPolicy (PolClient e) => (vals, #2 (queryProp [] 0 (fn rvN => (rvN + 1, Lvar rvN))
-                                                              (SomeCol (Var 0)) e) :: pols)
+              | DPolicy (PolClient e) =>
+                let
+                    val (_, p, _, _, outs) = queryProp [] 0 (fn rvN => (rvN + 1, Lvar rvN)) SomeCol e
+                in
+                    (vals, (p, outs) :: pols)
+                end
                                         
               | _ => (vals, pols)
 
@@ -1737,22 +1792,20 @@
 
         val (vals, pols) = foldl decl ([], []) file
     in
-        app (fn (loc, e, p) =>
+        app (fn ((loc, e, p), fl) =>
                 let
                     fun doOne e =
                         let
                             val p = And (p, Reln (Eq, [Var 0, e]))
                         in
-                            if List.exists (fn pol => if imply (p, pol) then
-                                                          (if !debug then
-                                                               Print.prefaces "Match"
-                                                                              [("Hyp", p_prop p),
-                                                                               ("Goal", p_prop pol)]
-                                                           else
-                                                               ();
-                                                           true)
-                                                      else
-                                                          false) pols then
+                            if decomp true (fn (e1, e2) => e1 andalso e2 ()) p
+                                      (fn hyps =>
+                                          (fl <> Control Where
+                                           andalso imply (hyps, [AReln (Known, [Var 0])], [Var 0]))
+                                          orelse List.exists (fn (p', outs) =>
+                                                                 decomp false (fn (e1, e2) => e1 orelse e2 ()) p'
+                                                                        (fn goals => imply (hyps, goals, outs)))
+                                                             pols) then
                                 ()
                             else
                                 (ErrorMsg.errorAt loc "The information flow policy may be violated here.";