diff src/iflow.sml @ 1214:648e6b087dfb

Change query_policy to sendClient; all arguments passed to SQL predicates are variables
author Adam Chlipala <adamc@hcoop.net>
date Thu, 08 Apr 2010 09:57:37 -0400
parents e791d93d4616
children 360f1ed0a969
line wrap: on
line diff
--- a/src/iflow.sml	Tue Apr 06 16:14:19 2010 -0400
+++ b/src/iflow.sml	Thu Apr 08 09:57:37 2010 -0400
@@ -412,6 +412,7 @@
     val assert : t * exp * exp -> t
     val query : t * exp * exp -> bool
     val allPeers : t * exp -> exp list
+    val p_t : t Print.printer
 end = struct
 
 fun eq (e1, e2) = eeq (simplify e1, simplify e2)
@@ -440,50 +441,102 @@
                                  end) t
     end
 
+open Print
+
+val p_t = p_list (fn (e1, e2) => box [p_exp (simplify e1),
+                                      space,
+                                      PD.string "->",
+                                      space,
+                                      p_exp (simplify e2)])
+
+fun query (t, e1, e2) =
+    let
+        fun doUn e =
+            case e of
+                Func (f, [e1]) =>
+                if String.isPrefix "un" f then
+                    let
+                        val s = String.extract (f, 2, NONE)
+                    in
+                        case ListUtil.search (fn e =>
+                                                 case e of
+                                                     Func (f', [e']) =>
+                                                     if f' = s then
+                                                         SOME e'
+                                                     else
+                                                         NONE
+                                                   | _ => NONE) (allPeers (t, e1)) of
+                            NONE => e
+                          | SOME e => doUn e
+                    end
+                else
+                    e
+              | _ => e
+
+        val e1' = doUn (lookup (t, doUn (simplify e1)))
+        val e2' = doUn (lookup (t, doUn (simplify e2)))
+    in
+        (*prefaces "CC query" [("e1", p_exp (simplify e1)),
+                             ("e2", p_exp (simplify e2)),
+                             ("e1'", p_exp (simplify e1')),
+                             ("e2'", p_exp (simplify e2')),
+                             ("t", p_t t)];*)
+        eq (e1', e2')
+    end
+
 fun assert (t, e1, e2) =
     let
         val r1 = lookup (t, e1)
         val r2 = lookup (t, e2)
-
-        fun doUn k (t', e1, e2) =
-            case e2 of
-                Func (f, [e]) =>
-                if String.isPrefix "un" f then
-                    let
-                        val f' = String.extract (f, 2, NONE)
-                    in
-                        foldl (fn (e', t') =>
-                                  case e' of
-                                      Func (f'', [e'']) =>
-                                      if f'' = f' then
-                                          (lookup (t', e1), k e'') :: t'
-                                      else
-                                          t'
-                                    | _ => t') t' (allPeers (t, e))
-                    end
-                else
-                    t'
-              | Proj (e2, f) => doUn (fn e' => k (Proj (e', f))) (t', e1, e2)
-              | _ => t'
     in
         if eq (r1, r2) then
             t
         else
-            doUn (fn x => x) (doUn (fn x => x) ((r1, r2) :: t, e1, e2), e2, e1)
+            let
+                fun doUn (t, e1, e2) =
+                    case e1 of
+                        Func (f, [e]) => if String.isPrefix "un" f then
+                                             let
+                                                 val s = String.extract (f, 2, NONE)
+                                             in
+                                                 foldl (fn (e', t) =>
+                                                           case e' of
+                                                               Func (f', [e']) =>
+                                                               if f' = s then
+                                                                   assert (assert (t, e', e1), e', e2)
+                                                               else
+                                                                   t
+                                                             | _ => t) t (allPeers (t, e))
+                                             end
+                                         else
+                                             t
+                      | _ => t
+
+                fun doProj (t, e1, e2) =
+                    foldl (fn ((e1', e2'), t) =>
+                              let
+                                  fun doOne (e, t) =
+                                      case e of
+                                          Proj (e', f) =>
+                                          if query (t, e1, e') then
+                                              assert (t, e, Proj (e2, f))
+                                          else
+                                              t
+                                        | _ => t
+                              in
+                                  doOne (e1', doOne (e2', t))
+                              end) t t
+
+                val t = (r1, r2) :: t
+                val t = doUn (t, r1, r2)
+                val t = doUn (t, r2, r1)
+                val t = doProj (t, r1, r2)
+                val t = doProj (t, r2, r1)
+            in
+                t
+            end
     end
 
-open Print
-
-fun query (t, e1, e2) =
-    ((*prefaces "CC query" [("e1", p_exp (simplify e1)),
-                          ("e2", p_exp (simplify e2)),
-                          ("t", p_list (fn (e1, e2) => box [p_exp (simplify e1),
-                                                            space,
-                                                            PD.string "->",
-                                                            space,
-                                                            p_exp (simplify e2)]) t)];*)
-     eq (lookup (t, e1), lookup (t, e2)))
-
 end
 
 fun rimp cc ((r1, es1), (r2, es2)) =
@@ -491,19 +544,7 @@
         (Sql r1', Sql r2') =>
         r1' = r2' andalso
         (case (es1, es2) of
-             ([Recd xes1], [Recd xes2]) =>
-             let
-                 val saved = save ()
-             in
-                 if List.all (fn (f, e2) =>
-                                 case ListUtil.search (fn (f', e1) => if f' = f then SOME e1 else NONE) xes1 of
-                                     NONE => true
-                                   | SOME e1 => eq (e1, e2)) xes2 then
-                     true
-                 else
-                     (restore saved;
-                      false)
-             end
+             ([e1], [e2]) => eq (e1, e2)
            | _ => false)
       | (Eq, Eq) =>
         (case (es1, es2) of
@@ -533,6 +574,9 @@
                        | Func (f, [e]) => String.isPrefix "un" f andalso matches e
                        | _ => false
              in
+                 (*Print.prefaces "Checking peers" [("e2", p_exp e2),
+                                                  ("peers", Print.p_list p_exp (Cc.allPeers (cc, e2))),
+                                                  ("db", Cc.p_t cc)];*)
                  List.exists matches (Cc.allPeers (cc, e2))
              end
            | _ => false)
@@ -562,7 +606,8 @@
                                                   let
                                                       fun hps hyps =
                                                           case hyps of
-                                                              [] => ((*Print.preface ("Fail", p_prop (Reln g));*)
+                                                              [] => ((*Print.prefaces "Fail" [("g", p_prop (Reln g)),
+                                                                                            ("db", Cc.p_t cc)];*)
                                                                      onFail ())
                                                             | ACond _ :: hyps => hps hyps
                                                             | AReln h :: hyps =>
@@ -925,13 +970,27 @@
          SomeCol of exp
        | AllCols of exp
 
-fun queryProp env rv oe e =
+fun queryProp env rvN rv oe e =
     case parse query e of
         NONE => (print ("Warning: Information flow checker can't parse SQL query at "
                         ^ ErrorMsg.spanToString (#2 e) ^ "\n");
-                 (Unknown, []))
+                 (rvN, Var 0, Unknown, []))
       | SOME r =>
         let
+            val (rvN, count) = rv rvN
+
+            val (rvs, rvN) = ListUtil.foldlMap (fn ((_, v), rvN) =>
+                                                   let
+                                                       val (rvN, e) = rv rvN
+                                                   in
+                                                       ((v, e), rvN)
+                                                   end) rvN (#From r)
+
+            fun rvOf v =
+                case List.find (fn (v', _) => v' = v) rvs of
+                    NONE => raise Fail "Iflow.queryProp: Bad table variable"
+                  | SOME (_, e) => e
+
             fun usedFields e =
                 case e of
                     SqConst _ => []
@@ -942,26 +1001,13 @@
                   | SqFunc (_, e) => usedFields e
                   | Count => []
 
-            val allUsed = removeDups (List.mapPartial (fn SqField x => SOME x | _ => NONE) (#Select r)
-                                      @ (case #Where r of
-                                             NONE => []
-                                           | SOME e => usedFields e))
-
             val p =
-                foldl (fn ((t, v), p) =>
-                          And (p,
-                               Reln (Sql t,
-                                     [Recd (foldl (fn ((v', f), fs) =>
-                                                      if v' = v then
-                                                          (f, Proj (Proj (rv, v), f)) :: fs
-                                                      else
-                                                          fs) [] allUsed)])))
-                      True (#From r)
+                foldl (fn ((t, v), p) => And (p, Reln (Sql t, [rvOf v]))) True (#From r)
 
             fun expIn e =
                 case e of
                     SqConst p => inl (Const p)
-                  | Field (v, f) => inl (Proj (Proj (rv, v), f))
+                  | Field (v, f) => inl (Proj (rvOf v, f))
                   | Binop (bo, e1, e2) =>
                     inr (case (bo, expIn e1, expIn e2) of
                              (Exps f, inl e1, inl e2) => f (e1, e2)
@@ -985,7 +1031,7 @@
                     inl (case expIn e of
                          inl e => Func (f, [e])
                        | _ => raise Fail ("Iflow: non-expresion passed to function " ^ f))
-                  | Count => inl (Proj (rv, "$COUNT"))
+                  | Count => inl count
 
             val p = case #Where r of
                         NONE => p
@@ -994,12 +1040,14 @@
                             inr p' => And (p, p')
                           | _ => p
         in
-            (And (p, case oe of
+            (rvN,
+             count,
+             And (p, case oe of
                          SomeCol oe =>
                          foldl (fn (si, p) =>
                                    let
                                        val p' = case si of
-                                                    SqField (v, f) => Reln (Eq, [oe, Proj (Proj (rv, v), f)])
+                                                    SqField (v, f) => Reln (Eq, [oe, Proj (rvOf v, f)])
                                                   | SqExp (e, f) =>
                                                     case expIn e of
                                                         inr _ => Unknown
@@ -1013,7 +1061,7 @@
                                    let
                                        val p' = case si of
                                                     SqField (v, f) => Reln (Eq, [Proj (Proj (oe, v), f),
-                                                                                 Proj (Proj (rv, v), f)])
+                                                                                 Proj (rvOf v, f)])
                                                   | SqExp (e, f) =>
                                                     case expIn e of
                                                         inr p => Cond (Proj (oe, f), p)
@@ -1025,7 +1073,7 @@
              
              case #Where r of
                  NONE => []
-               | SOME e => map (fn (v, f) => Proj (Proj (rv, v), f)) (usedFields e))
+               | SOME e => map (fn (v, f) => Proj (rvOf v, f)) (usedFields e))
         end
 
 fun evalPat env e (pt, _) =
@@ -1118,7 +1166,7 @@
                 let
                     val (es, st) = ListUtil.foldlMap (evalExp env) st es
                 in
-                    (Func ("unit", []), (#1 st, p, foldl (fn (e, sent) => addSent (#2 st, e, sent)) sent es))
+                    (Recd [], (#1 st, p, foldl (fn (e, sent) => addSent (#2 st, e, sent)) sent es))
                 end
             else if Settings.isEffectful (m, s) andalso not (Settings.isBenignEffectful (m, s)) then
                 default ()
@@ -1213,7 +1261,7 @@
             let
                 val (e, st) = evalExp env (e, st)
             in
-                (Func ("unit", []), (#1 st, p, addSent (#2 st, e, sent)))
+                (Recd [], (#1 st, p, addSent (#2 st, e, sent)))
             end
           | ESeq (e1, e2) =>
             let
@@ -1240,13 +1288,15 @@
                 val (i, st) = evalExp env (i, st)
 
                 val r = #1 st
-                val rv = #1 st + 1
-                val acc = #1 st + 2
-                val st' = (#1 st + 3, #2 st, #3 st)
+                val acc = #1 st + 1
+                val st' = (#1 st + 2, #2 st, #3 st)
 
                 val (b, st') = evalExp (Var acc :: Var r :: env) (b, st')
 
-                val (qp, used) = queryProp env (Var rv) (AllCols (Var r)) q
+                val (rvN, count, qp, used) =
+                    queryProp env
+                              (#1 st') (fn rvN => (rvN + 1, Var rvN))
+                              (AllCols (Var r)) q
 
                 val p' = And (qp, #2 st')
 
@@ -1254,11 +1304,11 @@
                                         (#1 st + 1, #2 st, Var r)
                                     else
                                         let
-                                            val out = #1 st'
+                                            val out = rvN
 
                                             val p = Or (Reln (Eq, [Var out, i]),
                                                         And (Reln (Eq, [Var out, b]),
-                                                             And (Reln (Gt, [Proj (Var rv, "$COUNT"),
+                                                             And (Reln (Gt, [count,
                                                                              Const (Prim.Int 0)]),
                                                                   p')))
                                         in
@@ -1323,8 +1373,9 @@
                     (sent @ vals, pols)
                 end
 
-              | DPolicy (PolQuery e) => (vals, #1 (queryProp [] (Lvar 0) (SomeCol (Var 0)) e) :: pols)
-
+              | DPolicy (PolClient e) => (vals, #3 (queryProp [] 0 (fn rvN => (rvN + 1, Lvar rvN))
+                                                              (SomeCol (Var 0)) e) :: pols)
+                                        
               | _ => (vals, pols)
 
         val () = reset ()