diff src/iflow.sml @ 1211:1d4d65245dd3

About to try removing Select predicate
author Adam Chlipala <adamc@hcoop.net>
date Tue, 06 Apr 2010 13:59:16 -0400
parents c5bd970e77a5
children fc33072c4d33
line wrap: on
line diff
--- a/src/iflow.sml	Tue Apr 06 12:04:08 2010 -0400
+++ b/src/iflow.sml	Tue Apr 06 13:59:16 2010 -0400
@@ -67,6 +67,7 @@
 datatype reln =
          Known
        | Sql of string
+       | DtCon of string
        | Eq
        | Ne
        | Lt
@@ -127,6 +128,9 @@
       | Sql s => box [string (s ^ "("),
                       p_list p_exp es,
                       string ")"]
+      | DtCon s => box [string (s ^ "("),
+                        p_list p_exp es,
+                        string ")"]
       | Eq => p_bop "=" es
       | Ne => p_bop "<>" es
       | Lt => p_bop "<" es
@@ -241,11 +245,20 @@
       | Func (f, es) =>
         let
             val es = map simplify es
+
+            fun default () = Func (f, es)
         in
             if List.exists isFinish es then
                 Finish
+            else if String.isPrefix "un" f then
+                case es of
+                    [Func (f', [e])] => if f' = String.extract (f, 2, NONE) then
+                                            e
+                                        else
+                                            default ()
+                  | _ => default ()
             else
-                Func (f, es)
+                default ()
         end
       | Recd xes =>
         let
@@ -351,10 +364,21 @@
              false)
     end
 
-exception Imply of prop * prop
-
 val debug = ref false
 
+fun eeq (e1, e2) =
+    case (e1, e2) of
+        (Const p1, Const p2) => Prim.equal (p1, p2)
+      | (Var n1, Var n2) => n1 = n2
+      | (Lvar n1, Lvar n2) => n1 = n2
+      | (Func (f1, es1), Func (f2, es2)) => f1 = f2 andalso ListPair.allEq eeq (es1, es2)
+      | (Recd xes1, Recd xes2) => length xes1 = length xes2 andalso
+                                  List.all (fn (x2, e2) =>
+                                               List.exists (fn (x1, e1) => x1 = x2 andalso eeq (e1, e2)) xes2) xes1
+      | (Proj (e1, x1), Proj (e2, x2)) => eeq (e1, e2) andalso x1 = x2
+      | (Finish, Finish) => true
+      | _ => false
+             
 (* Congruence closure *)
 structure Cc :> sig
     type t
@@ -364,20 +388,7 @@
     val allPeers : t * exp -> exp list
 end = struct
 
-fun eq' (e1, e2) =
-    case (e1, e2) of
-        (Const p1, Const p2) => Prim.equal (p1, p2)
-      | (Var n1, Var n2) => n1 = n2
-      | (Lvar n1, Lvar n2) => n1 = n2
-      | (Func (f1, es1), Func (f2, es2)) => f1 = f2 andalso ListPair.allEq eq' (es1, es2)
-      | (Recd xes1, Recd xes2) => length xes1 = length xes2 andalso
-                                  List.all (fn (x2, e2) =>
-                                               List.exists (fn (x1, e1) => x1 = x2 andalso eq' (e1, e2)) xes2) xes1
-      | (Proj (e1, x1), Proj (e2, x2)) => eq' (e1, e2) andalso x1 = x2
-      | (Finish, Finish) => true
-      | _ => false
-
-fun eq (e1, e2) = eq' (simplify e1, simplify e2)
+fun eq (e1, e2) = eeq (simplify e1, simplify e2)
 
 type t = (exp * exp) list
 
@@ -475,6 +486,7 @@
                      case e of
                          Var v' => v' = v
                        | Proj (e, _) => matches e
+                       | Func (f, [e]) => String.isPrefix "un" f andalso matches e
                        | _ => false
              in
                  List.exists matches (Cc.allPeers (cc, e2))
@@ -528,7 +540,23 @@
                                                       orelse hps hyps
                                                   end
                                   in
-                                      gls goals (fn () => false)
+                                      if List.exists (fn (DtCon c1, [e]) =>
+                                                         List.exists (fn (DtCon c2, [e']) =>
+                                                                         c1 <> c2 andalso
+                                                                         Cc.query (cc, e, e')
+                                                                       | _ => false) hyps
+                                                         orelse List.exists (fn Func (c2, []) => c1 <> c2
+                                                                              | Finish => true
+                                                                              | _ => false)
+                                                                            (Cc.allPeers (cc, e))
+                                                       | _ => false) hyps
+                                         orelse gls goals (fn () => false) then
+                                          true
+                                      else
+                                          (Print.prefaces "Can't prove"
+                                                          [("hyps", Print.p_list (fn x => p_prop (Reln x)) hyps),
+                                                           ("goals", Print.p_list (fn x => p_prop (Reln x)) goals)];
+                                           false)
                                   end))
     in
         reset ();
@@ -668,17 +696,17 @@
 
 val ident = keep (fn ch => Char.isAlphaNum ch orelse ch = #"_")
 
-val t_ident = wrap ident (fn s => if String.isPrefix "T_" s then
-                                      String.extract (s, 2, NONE)
-                                  else
-                                      raise Fail "Iflow: Bad table variable")
-val uw_ident = wrap ident (fn s => if String.isPrefix "uw_" s andalso size s >= 4 then
-                                       str (Char.toUpper (String.sub (s, 3)))
-                                       ^ String.extract (s, 4, NONE)
+val t_ident = wrapP ident (fn s => if String.isPrefix "T_" s then
+                                       SOME (String.extract (s, 2, NONE))
                                    else
-                                       raise Fail "Iflow: Bad uw_* variable")
+                                       NONE)
+val uw_ident = wrapP ident (fn s => if String.isPrefix "uw_" s andalso size s >= 4 then
+                                        SOME (str (Char.toUpper (String.sub (s, 3)))
+                                              ^ String.extract (s, 4, NONE))
+                                    else
+                                        NONE)
 
-val sitem = wrap (follow t_ident
+val field = wrap (follow t_ident
                          (follow (const ".")
                                  uw_ident))
                  (fn (t, ((), f)) => (t, f))
@@ -693,6 +721,8 @@
        | Binop of Rel * sqexp * sqexp
        | SqKnown of sqexp
        | Inj of Mono.exp
+       | SqFunc of string * sqexp
+       | Count
 
 fun cmp s r = wrap (const s) (fn () => Exps (fn (e1, e2) => Reln (r, [e1, e2])))
 
@@ -758,12 +788,25 @@
             NONE
       | _ => NONE
 
+fun constK s = wrap (const s) (fn () => s)
+
+val funcName = altL [constK "COUNT",
+                     constK "MIN",
+                     constK "MAX",
+                     constK "SUM",
+                     constK "AVG"]
+
 fun sqexp chs =
     log "sqexp"
     (altL [wrap prim SqConst,
-           wrap sitem Field,
+           wrap field Field,
            wrap known SqKnown,
+           wrap func SqFunc,
+           wrap (const "COUNT(*)") (fn () => Count),
            wrap sqlify Inj,
+           wrap (follow (const "COALESCE(") (follow sqexp (follow (const ",")
+                                                                  (follow (keep (fn ch => ch <> #")")) (const ")")))))
+                (fn ((), (e, _)) => e),
            wrap (follow (ws (const "("))
                         (follow (wrap
                                      (follow sqexp
@@ -782,7 +825,18 @@
     chs
 
 and known chs = wrap (follow known' (follow (const "(") (follow sqexp (const ")"))))
-                (fn ((), ((), (e, ()))) => e) chs
+                     (fn ((), ((), (e, ()))) => e) chs
+                
+and func chs = wrap (follow funcName (follow (const "(") (follow sqexp (const ")"))))
+                    (fn (f, ((), (e, ()))) => (f, e)) chs
+
+datatype sitem =
+         SqField of string * string
+       | SqExp of sqexp * string
+
+val sitem = alt (wrap field SqField)
+            (wrap (follow sqexp (follow (const " AS ") uw_ident))
+             (fn (e, ((), s)) => SqExp (e, s)))
 
 val select = log "select"
              (wrap (follow (const "SELECT ") (list sitem))
@@ -804,6 +858,19 @@
                 (wrap (follow (follow select from) (opt wher))
                       (fn ((fs, ts), wher) => {Select = fs, From = ts, Where = wher}))
 
+fun removeDups ls =
+    case ls of
+        [] => []
+      | x :: ls =>
+        let
+            val ls = removeDups ls
+        in
+            if List.exists (fn x' => x' = x) ls then
+                ls  
+            else
+                x :: ls
+        end
+
 fun queryProp env rv oe e =
     case parse query e of
         NONE => (print ("Warning: Information flow checker can't parse SQL query at "
@@ -811,6 +878,21 @@
                  (Unknown, []))
       | SOME r =>
         let
+            fun usedFields e =
+                case e of
+                    SqConst _ => []
+                  | Field (v, f) => [(v, f)]
+                  | Binop (_, e1, e2) => removeDups (usedFields e1 @ usedFields e2)
+                  | SqKnown _ => []
+                  | Inj _ => []
+                  | SqFunc (_, e) => usedFields e
+                  | Count => []
+
+            val allUsed = removeDups (List.mapPartial (fn SqField x => SOME x | _ => NONE) (#Select r)
+                                      @ (case #Where r of
+                                             NONE => []
+                                           | SOME e => usedFields e))
+
             val p =
                 foldl (fn ((t, v), p) =>
                           And (p,
@@ -819,7 +901,7 @@
                                                       if v' = v then
                                                           (f, Proj (Proj (Lvar rv, v), f)) :: fs
                                                       else
-                                                          fs) [] (#Select r))])))
+                                                          fs) [] allUsed)])))
                       True (#From r)
 
             fun expIn e =
@@ -845,14 +927,11 @@
                     in
                         inl (deinj e)
                     end
-
-            fun usedFields e =
-                case e of
-                    SqConst _ => []
-                  | Field (v, f) => [Proj (Proj (Lvar rv, v), f)]
-                  | Binop (_, e1, e2) => usedFields e1 @ usedFields e2
-                  | SqKnown _ => []
-                  | Inj _ => []
+                  | SqFunc (f, e) =>
+                    inl (case expIn e of
+                         inl e => Func (f, [e])
+                       | _ => raise Fail ("Iflow: non-expresion passed to function " ^ f))
+                  | Count => inl (Func ("COUNT", []))
 
             val p = case #Where r of
                         NONE => p
@@ -864,16 +943,79 @@
             (case oe of
                  NONE => p
                | SOME oe =>
-                 And (p, foldl (fn ((v, f), p) =>
-                                   Or (p,
-                                       Reln (Eq, [oe, Proj (Proj (Lvar rv, v), f)])))
+                 And (p, foldl (fn (si, p) =>
+                                   let
+                                       val p' = case si of
+                                                    SqField (v, f) => Reln (Eq, [oe, Proj (Proj (Lvar rv, v), f)])
+                                                  | SqExp (e, f) =>
+                                                    case expIn e of
+                                                        inr _ => Unknown
+                                                      | inl e => Reln (Eq, [oe, e])
+                                   in
+                                       Or (p, p')
+                                   end)
                                False (#Select r)),
              
              case #Where r of
                  NONE => []
-               | SOME e => usedFields e)
+               | SOME e => map (fn (v, f) => Proj (Proj (Lvar rv, v), f)) (usedFields e))
         end
 
+fun evalPat env e (pt, _) =
+    case pt of
+        PWild => (env, True)
+      | PVar _ => (e :: env, True)
+      | PPrim _ => (env, True)
+      | PCon (_, pc, NONE) => (env, Reln (DtCon (patCon pc), [e]))
+      | PCon (_, pc, SOME pt) =>
+        let
+            val (env, p) = evalPat env (Func ("un" ^ patCon pc, [e])) pt
+        in
+            (env, And (p, Reln (DtCon (patCon pc), [e])))
+        end
+      | PRecord xpts =>
+        foldl (fn ((x, pt, _), (env, p)) =>
+                  let
+                      val (env, p') = evalPat env (Proj (e, x)) pt
+                  in
+                      (env, And (p', p))
+                  end) (env, True) xpts
+      | PNone _ => (env, Reln (DtCon "None", [e]))
+      | PSome (_, pt) =>
+        let
+            val (env, p) = evalPat env (Func ("unSome", [e])) pt
+        in
+            (env, And (p, Reln (DtCon "Some", [e])))
+        end
+
+fun peq (p1, p2) =
+    case (p1, p2) of
+        (True, True) => true
+      | (False, False) => true
+      | (Unknown, Unknown) => true
+      | (And (x1, y1), And (x2, y2)) => peq (x1, x2) andalso peq (y1, y2)
+      | (Or (x1, y1), Or (x2, y2)) => peq (x1, x2) andalso peq (y1, y2)
+      | (Reln (r1, es1), Reln (r2, es2)) => r1 = r2 andalso ListPair.allEq eeq (es1, es2)
+      | (Select (n1, n2, n3, p1, e1), Select (n1', n2', n3', p2, e2)) =>
+        n1 = n1' andalso n2 = n2' andalso n3 = n3'
+        andalso peq (p1, p2) andalso eeq (e1, e2)
+      | _ => false
+
+fun removeRedundant p1 =
+    let
+        fun rr p2 =
+            if peq (p1, p2) then
+                True
+            else
+                case p2 of
+                    And (x, y) => And (rr x, rr y)
+                  | Or (x, y) => Or (rr x, rr y)
+                  | Select (n1, n2, n3, p, e) => Select (n1, n2, n3, rr p, e)
+                  | _ => p2
+    in
+        rr
+    end
+
 fun evalExp env (e as (_, loc), st as (nv, p, sent)) =
     let
         fun default () =
@@ -951,7 +1093,25 @@
             in
                 (Proj (e, s), st)
             end
-          | ECase _ => default ()
+          | ECase (e, pes, _) =>
+            let
+                val (e, st) = evalExp env (e, st)
+                val r = #1 st
+                val st = (r + 1, #2 st, #3 st)
+                val orig = #2 st
+
+                val st = foldl (fn ((pt, pe), st) =>
+                                   let
+                                       val (env, pp) = evalPat env e pt
+                                       val (pe, st') = evalExp env (pe, (#1 st, And (orig, pp), #3 st))
+                                                       
+                                       val this = And (removeRedundant orig (#2 st'), Reln (Eq, [Var r, pe]))
+                                   in
+                                       (#1 st', Or (#2 st, this), #3 st')
+                                   end) (#1 st, False, #3 st) pes
+            in
+                (Var r, (#1 st, And (orig, #2 st), #3 st))
+            end
           | EStrcat (e1, e2) =>
             let
                 val (e1, st) = evalExp env (e1, st)