diff src/iflow.sml @ 1212:fc33072c4d33

Replaced Select predicate with special-case handling for one-or-no-rows queries
author Adam Chlipala <adamc@hcoop.net>
date Tue, 06 Apr 2010 15:17:28 -0400
parents 1d4d65245dd3
children e791d93d4616
line wrap: on
line diff
--- a/src/iflow.sml	Tue Apr 06 13:59:16 2010 -0400
+++ b/src/iflow.sml	Tue Apr 06 15:17:28 2010 -0400
@@ -82,7 +82,7 @@
        | And of prop * prop
        | Or of prop * prop
        | Reln of reln * exp list
-       | Select of int * lvar * lvar * prop * exp
+       | Cond of exp * prop
 
 local
     open Print
@@ -162,14 +162,13 @@
                             p_prop p2,
                             string ")"]
       | Reln (r, es) => p_reln r es
-      | Select (n1, n2, n3, p, e) => box [string ("select(x" ^ Int.toString n1
-                                                  ^ ",X" ^ Int.toString n2
-                                                  ^ ",X" ^ Int.toString n3
-                                                  ^ "){"),
-                                          p_prop p,
-                                          string "}{",
-                                          p_exp e,
-                                          string "}"]
+      | Cond (e, p) => box [string "(",
+                            p_exp e,
+                            space,
+                            string "==",
+                            space,
+                            p_prop p,
+                            string ")"]
 
 end
 
@@ -185,36 +184,6 @@
     end
 end
 
-fun subExp (v, lv) =
-    let
-        fun sub e =
-            case e of
-                Const _ => e
-              | Var v' => if v' = v then Lvar lv else e
-              | Lvar _ => e
-              | Func (f, es) => Func (f, map sub es)
-              | Recd xes => Recd (map (fn (x, e) => (x, sub e)) xes)
-              | Proj (e, s) => Proj (sub e, s)
-              | Finish => Finish
-    in
-        sub
-    end
-
-fun subProp (v, lv) =
-    let
-        fun sub p =
-            case p of
-                True => p
-              | False => p
-              | Unknown => p
-              | And (p1, p2) => And (sub p1, sub p2)
-              | Or (p1, p2) => Or (sub p1, sub p2)
-              | Reln (r, es) => Reln (r, map (subExp (v, lv)) es)
-              | Select (v1, lv1, lv2, p, e) => Select (v1, lv1, lv2, sub p, subExp (v, lv) e)
-    in
-        sub
-    end
-
 fun isKnown e =
     case e of
         Const _ => true
@@ -280,6 +249,15 @@
                  Proj (e', s))
       | Finish => Finish
 
+datatype atom =
+         AReln of reln * exp list
+       | ACond of exp * prop
+
+fun p_atom a =
+    p_prop (case a of
+                AReln x => Reln x
+              | ACond x => Cond x)
+
 fun decomp fals or =
     let
         fun decomp p k =
@@ -293,8 +271,8 @@
                                             k (ps1 @ ps2)))
               | Or (p1, p2) =>
                 or (decomp p1 k, fn () => decomp p2 k)
-              | Reln x => k [x]
-              | Select _ => k []
+              | Reln x => k [AReln x]
+              | Cond x => k [ACond x]
     in
         decomp
     end
@@ -314,6 +292,51 @@
         lvi
     end
 
+fun lvarInP lv =
+    let
+        fun lvi p =
+            case p of
+                True => false
+              | False => false
+              | Unknown => true
+              | And (p1, p2) => lvi p1 orelse lvi p2
+              | Or (p1, p2) => lvi p1 orelse lvi p2
+              | Reln (_, es) => List.exists (lvarIn lv) es
+              | Cond (e, p) => lvarIn lv e orelse lvi p
+    in
+        lvi
+    end
+
+fun varIn lv =
+    let
+        fun lvi e =
+            case e of
+                Const _ => false
+              | Lvar _ => false
+              | Var lv' => lv' = lv
+              | Func (_, es) => List.exists lvi es
+              | Recd xes => List.exists (lvi o #2) xes
+              | Proj (e, _) => lvi e
+              | Finish => false
+    in
+        lvi
+    end
+
+fun varInP lv =
+    let
+        fun lvi p =
+            case p of
+                True => false
+              | False => false
+              | Unknown => false
+              | And (p1, p2) => lvi p1 orelse lvi p2
+              | Or (p1, p2) => lvi p1 orelse lvi p2
+              | Reln (_, es) => List.exists (varIn lv) es
+              | Cond (e, p) => varIn lv e orelse lvi p
+    in
+        lvi
+    end
+
 fun eq' (e1, e2) =
     case (e1, e2) of
         (Const p1, Const p2) => Prim.equal (p1, p2)
@@ -399,32 +422,6 @@
         NONE => e
       | SOME (_, e2) => lookup (t, e2)
 
-fun assert (t, e1, e2) =
-    let
-        val r1 = lookup (t, e1)
-        val r2 = lookup (t, e2)
-    in
-        if eq (r1, r2) then
-            t
-        else
-            (r1, r2) :: t
-    end
-
-open Print
-
-fun query (t, e1, e2) =
-    (if !debug then
-         prefaces "CC query" [("e1", p_exp (simplify e1)),
-                              ("e2", p_exp (simplify e2)),
-                              ("t", p_list (fn (e1, e2) => box [p_exp (simplify e1),
-                                                                space,
-                                                                PD.string "->",
-                                                                space,
-                                                                p_exp (simplify e2)]) t)]
-     else
-         ();
-     eq (lookup (t, e1), lookup (t, e2)))
-
 fun allPeers (t, e) =
     let
         val r = lookup (t, e)
@@ -440,6 +437,49 @@
                                  end) t
     end
 
+fun assert (t, e1, e2) =
+    let
+        val r1 = lookup (t, e1)
+        val r2 = lookup (t, e2)
+
+        fun doUn (t', e1, e2) =
+            case e2 of
+                Func (f, [e]) =>
+                if String.isPrefix "un" f then
+                    let
+                        val f' = String.extract (f, 2, NONE)
+                    in
+                        foldl (fn (e', t') =>
+                                  case e' of
+                                      Func (f'', [e'']) =>
+                                      if f'' = f' then
+                                          (lookup (t', e1), e'') :: t'
+                                      else
+                                          t'
+                                    | _ => t') t' (allPeers (t, e))
+                    end
+                else
+                    t'
+              | _ => t'
+    in
+        if eq (r1, r2) then
+            t
+        else
+            doUn (doUn ((r1, r2) :: t, e1, e2), e2, e1)
+    end
+
+open Print
+
+fun query (t, e1, e2) =
+    ((*prefaces "CC query" [("e1", p_exp (simplify e1)),
+                          ("e2", p_exp (simplify e2)),
+                          ("t", p_list (fn (e1, e2) => box [p_exp (simplify e1),
+                                                            space,
+                                                            PD.string "->",
+                                                            space,
+                                                            p_exp (simplify e2)]) t)];*)
+     eq (lookup (t, e1), lookup (t, e2)))
+
 end
 
 fun rimp cc ((r1, es1), (r2, es2)) =
@@ -504,13 +544,14 @@
                                   let
                                       val cc = foldl (fn (p, cc) =>
                                                          case p of
-                                                             (Eq, [e1, e2]) => Cc.assert (cc, e1, e2)
+                                                             AReln (Eq, [e1, e2]) => Cc.assert (cc, e1, e2)
                                                            | _ => cc) Cc.empty hyps
 
                                       fun gls goals onFail =
                                           case goals of
                                               [] => true
-                                            | g :: goals =>
+                                            | ACond _ :: _ => false
+                                            | AReln g :: goals =>
                                               case (doKnown, g) of
                                                   (false, (Known, _)) => gls goals onFail
                                                 | _ =>
@@ -518,7 +559,8 @@
                                                       fun hps hyps =
                                                           case hyps of
                                                               [] => onFail ()
-                                                            | h :: hyps =>
+                                                            | ACond _ :: hyps => hps hyps
+                                                            | AReln h :: hyps =>
                                                               let
                                                                   val saved = save ()
                                                               in
@@ -540,8 +582,8 @@
                                                       orelse hps hyps
                                                   end
                                   in
-                                      if List.exists (fn (DtCon c1, [e]) =>
-                                                         List.exists (fn (DtCon c2, [e']) =>
+                                      if List.exists (fn AReln (DtCon c1, [e]) =>
+                                                         List.exists (fn AReln (DtCon c2, [e']) =>
                                                                          c1 <> c2 andalso
                                                                          Cc.query (cc, e, e')
                                                                        | _ => false) hyps
@@ -553,9 +595,9 @@
                                          orelse gls goals (fn () => false) then
                                           true
                                       else
-                                          (Print.prefaces "Can't prove"
-                                                          [("hyps", Print.p_list (fn x => p_prop (Reln x)) hyps),
-                                                           ("goals", Print.p_list (fn x => p_prop (Reln x)) goals)];
+                                          ((*Print.prefaces "Can't prove"
+                                                          [("hyps", Print.p_list p_atom hyps),
+                                                           ("goals", Print.p_list p_atom goals)];*)
                                            false)
                                   end))
     in
@@ -569,8 +611,6 @@
         PConVar n => "C" ^ Int.toString n
       | PConFfi {mod = m, datatyp = d, con = c, ...} => m ^ "." ^ d ^ "." ^ c
 
-
-
 datatype chunk =
          String of string
        | Exp of Mono.exp
@@ -871,6 +911,10 @@
                 x :: ls
         end
 
+datatype queryMode =
+         SomeCol of exp
+       | AllCols of exp
+
 fun queryProp env rv oe e =
     case parse query e of
         NONE => (print ("Warning: Information flow checker can't parse SQL query at "
@@ -899,7 +943,7 @@
                                Reln (Sql t,
                                      [Recd (foldl (fn ((v', f), fs) =>
                                                       if v' = v then
-                                                          (f, Proj (Proj (Lvar rv, v), f)) :: fs
+                                                          (f, Proj (Proj (rv, v), f)) :: fs
                                                       else
                                                           fs) [] allUsed)])))
                       True (#From r)
@@ -907,7 +951,7 @@
             fun expIn e =
                 case e of
                     SqConst p => inl (Const p)
-                  | Field (v, f) => inl (Proj (Proj (Lvar rv, v), f))
+                  | Field (v, f) => inl (Proj (Proj (rv, v), f))
                   | Binop (bo, e1, e2) =>
                     inr (case (bo, expIn e1, expIn e2) of
                              (Exps f, inl e1, inl e2) => f (e1, e2)
@@ -931,7 +975,7 @@
                     inl (case expIn e of
                          inl e => Func (f, [e])
                        | _ => raise Fail ("Iflow: non-expresion passed to function " ^ f))
-                  | Count => inl (Func ("COUNT", []))
+                  | Count => inl (Proj (rv, "$COUNT"))
 
             val p = case #Where r of
                         NONE => p
@@ -940,13 +984,12 @@
                             inr p' => And (p, p')
                           | _ => p
         in
-            (case oe of
-                 NONE => p
-               | SOME oe =>
-                 And (p, foldl (fn (si, p) =>
+            (And (p, case oe of
+                         SomeCol oe =>
+                         foldl (fn (si, p) =>
                                    let
                                        val p' = case si of
-                                                    SqField (v, f) => Reln (Eq, [oe, Proj (Proj (Lvar rv, v), f)])
+                                                    SqField (v, f) => Reln (Eq, [oe, Proj (Proj (rv, v), f)])
                                                   | SqExp (e, f) =>
                                                     case expIn e of
                                                         inr _ => Unknown
@@ -954,11 +997,25 @@
                                    in
                                        Or (p, p')
                                    end)
-                               False (#Select r)),
+                               False (#Select r)
+                       | AllCols oe =>
+                         foldl (fn (si, p) =>
+                                   let
+                                       val p' = case si of
+                                                    SqField (v, f) => Reln (Eq, [Proj (Proj (oe, v), f),
+                                                                                 Proj (Proj (rv, v), f)])
+                                                  | SqExp (e, f) =>
+                                                    case expIn e of
+                                                        inr p => Cond (Proj (oe, f), p)
+                                                      | inl e => Reln (Eq, [Proj (oe, f), e])
+                                   in
+                                       And (p, p')
+                                   end)
+                               True (#Select r)),
              
              case #Where r of
                  NONE => []
-               | SOME e => map (fn (v, f) => Proj (Proj (Lvar rv, v), f)) (usedFields e))
+               | SOME e => map (fn (v, f) => Proj (Proj (rv, v), f)) (usedFields e))
         end
 
 fun evalPat env e (pt, _) =
@@ -996,9 +1053,7 @@
       | (And (x1, y1), And (x2, y2)) => peq (x1, x2) andalso peq (y1, y2)
       | (Or (x1, y1), Or (x2, y2)) => peq (x1, x2) andalso peq (y1, y2)
       | (Reln (r1, es1), Reln (r2, es2)) => r1 = r2 andalso ListPair.allEq eeq (es1, es2)
-      | (Select (n1, n2, n3, p1, e1), Select (n1', n2', n3', p2, e2)) =>
-        n1 = n1' andalso n2 = n2' andalso n3 = n3'
-        andalso peq (p1, p2) andalso eeq (e1, e2)
+      | (Cond (e1, p1), Cond (e2, p2)) => eeq (e1, e2) andalso peq (p1, p2)
       | _ => false
 
 fun removeRedundant p1 =
@@ -1010,7 +1065,6 @@
                 case p2 of
                     And (x, y) => And (rr x, rr y)
                   | Or (x, y) => Or (rr x, rr y)
-                  | Select (n1, n2, n3, p, e) => Select (n1, n2, n3, rr p, e)
                   | _ => p2
     in
         rr
@@ -1164,28 +1218,35 @@
                 val (i, st) = evalExp env (i, st)
 
                 val r = #1 st
-                val acc = #1 st + 1
-                val st' = (#1 st + 2, #2 st, #3 st)
+                val rv = #1 st + 1
+                val acc = #1 st + 2
+                val st' = (#1 st + 3, #2 st, #3 st)
 
                 val (b, st') = evalExp (Var acc :: Var r :: env) (b, st')
 
-                val r' = newLvar ()
-                val acc' = newLvar ()
-                val (qp, used) = queryProp env r' NONE q
+                val (qp, used) = queryProp env (Var rv) (AllCols (Var r)) q
 
-                val doSubExp = subExp (r, r') o subExp (acc, acc')
-                val doSubProp = subProp (r, r') o subProp (acc, acc')
+                val p' = And (qp, #2 st')
 
-                val p = doSubProp (#2 st')
-                val p' = And (p, qp)
-                val p = Select (r, r', acc', p', doSubExp b)
+                val (nvs, p, res) = if varInP acc (#2 st') then
+                                        (#1 st + 1, #2 st, Var r)
+                                    else
+                                        let
+                                            val out = #1 st'
 
-                val sent = map (fn (loc, e, p) => (loc,
-                                                   doSubExp e,
-                                                   And (qp, doSubProp p))) (#3 st')
+                                            val p = Or (Reln (Eq, [Var out, i]),
+                                                        And (Reln (Eq, [Var out, b]),
+                                                             And (Reln (Gt, [Proj (Var rv, "$COUNT"),
+                                                                             Const (Prim.Int 0)]),
+                                                                  p')))
+                                        in
+                                            (out + 1, p, Var out)
+                                        end
+
+                val sent = map (fn (loc, e, p) => (loc, e, And (qp, p))) (#3 st')
                 val sent = map (fn e => (loc, e, p')) used @ sent
             in
-                (Var r, (#1 st + 1, And (#2 st, p), sent))
+                (res, (nvs, p, sent))
             end
           | EDml _ => default ()
           | ENextval _ => default ()
@@ -1231,7 +1292,7 @@
                     (sent @ vals, pols)
                 end
 
-              | DPolicy (PolQuery e) => (vals, #1 (queryProp [] 0 (SOME (Var 0)) e) :: pols)
+              | DPolicy (PolQuery e) => (vals, #1 (queryProp [] (Lvar 0) (SomeCol (Var 0)) e) :: pols)
 
               | _ => (vals, pols)