diff src/sqlcache.sml @ 2294:f8903af753ff

Support nested queries but disable UrFlow for now.
author Ziv Scully <ziv@mit.edu>
date Thu, 19 Nov 2015 01:59:00 -0500
parents 50ad02829abd
children e6c5bb62fef8
line wrap: on
line diff
--- a/src/sqlcache.sml	Wed Nov 18 14:48:24 2015 -0500
+++ b/src/sqlcache.sml	Thu Nov 19 01:59:00 2015 -0500
@@ -30,11 +30,18 @@
 
 (* Option monad. *)
 fun obind (x, f) = Option.mapPartial f x
-fun oguard (b, x) = if b then x else NONE
+fun oguard (b, x) = if b then x () else NONE
 fun omap f = fn SOME x => SOME (f x) | _ => NONE
 fun omap2 f = fn (SOME x, SOME y) => SOME (f (x,y)) | _ => NONE
 fun osequence ys = List.foldr (omap2 op::) (SOME []) ys
 
+fun concatMap f xs = List.concat (map f xs)
+
+val rec cartesianProduct : 'a list list -> 'a list list =
+ fn [] => [[]]
+  | (xs :: xss) => concatMap (fn ys => concatMap (fn x => [x :: ys]) xs)
+                             (cartesianProduct xss)
+
 fun indexOf test =
     let
         fun f n =
@@ -104,10 +111,12 @@
 val dummyLoc = ErrorMsg.dummySpan
 
 (* DEBUG *)
-fun printExp msg exp = Print.preface ("SQLCACHE: " ^ msg ^ ":", MonoPrint.p_exp MonoEnv.empty exp)
-fun printExp' msg exp' = printExp msg (exp', dummyLoc)
-fun printTyp msg typ = Print.preface ("SQLCACHE: " ^ msg ^ ":", MonoPrint.p_typ MonoEnv.empty typ)
-fun printTyp' msg typ' = printTyp msg (typ', dummyLoc)
+fun printExp msg exp =
+    (Print.preface ("SQLCACHE: " ^ msg ^ ":", MonoPrint.p_exp MonoEnv.empty exp); exp)
+fun printExp' msg exp' = (printExp msg (exp', dummyLoc); exp')
+fun printTyp msg typ =
+    (Print.preface ("SQLCACHE: " ^ msg ^ ":", MonoPrint.p_typ MonoEnv.empty typ); typ)
+fun printTyp' msg typ' = (printTyp msg (typ', dummyLoc); typ')
 fun obindDebug printer (x, f) =
     case x of
         NONE => NONE
@@ -204,13 +213,6 @@
 
 val flipJt = fn Conj => Disj | Disj => Conj
 
-fun concatMap f xs = List.concat (map f xs)
-
-val rec cartesianProduct : 'a list list -> 'a list list =
- fn [] => [[]]
-  | (xs :: xss) => concatMap (fn ys => concatMap (fn x => [x :: ys]) xs)
-                             (cartesianProduct xss)
-
 (* Pushes all negation to the atoms.*)
 fun pushNegate (normalizeAtom : bool * 'atom -> 'atom) (negating : bool) =
  fn Atom x => Atom' (normalizeAtom (negating, x))
@@ -349,8 +351,12 @@
 structure AtomOptionKey = OptionKeyFn(AtomExpKey)
 
 val rec tablesOfQuery =
- fn Sql.Query1 {From = tablePairs, ...} => SS.fromList (map #1 tablePairs)
+ fn Sql.Query1 {From = fitems, ...} => List.foldl SS.union SS.empty (map tableOfFitem fitems)
   | Sql.Union (q1, q2) => SS.union (tablesOfQuery q1, tablesOfQuery q2)
+and tableOfFitem =
+ fn Sql.Table (t, _) => SS.singleton t
+  | Sql.Nested (q, _) => tablesOfQuery q
+  | Sql.Join (_, f1, f2, _) => SS.union (tableOfFitem f1, tableOfFitem f2)
 
 val tableOfDml =
  fn Sql.Insert (tab, _) => tab
@@ -489,43 +495,60 @@
 
     (* Need lift', etc. because we don't have rank-2 polymorphism. This should
        probably use a functor (an ML one, not Haskell) but works for now. *)
-    fun traverseSqexp (pure, _, lift, _, lift'', lift2, _) f =
+    fun traverseSqexp (pure, _, _, _, lift, lift', _, _, lift2, _, _, _, _, _) f =
         let
             val rec tr =
              fn Sql.SqNot se => lift Sql.SqNot (tr se)
               | Sql.Binop (r, se1, se2) =>
                 lift2 (fn (trse1, trse2) => Sql.Binop (r, trse1, trse2)) (tr se1, tr se2)
               | Sql.SqKnown se => lift Sql.SqKnown (tr se)
-              | Sql.Inj (e', loc) => lift'' (fn fe' => Sql.Inj (fe', loc)) (f e')
+              | Sql.Inj (e', loc) => lift' (fn fe' => Sql.Inj (fe', loc)) (f e')
               | Sql.SqFunc (s, se) => lift (fn trse => Sql.SqFunc (s, trse)) (tr se)
               | se => pure se
         in
             tr
         end
 
-    fun traverseQuery (ops as (_, pure', _, lift', _, _, lift2')) f =
+    fun traverseFitem (ops as (_, _, _, pure''', _, _, _, lift''', _, _, _, _, lift2'''', lift2''''')) f =
         let
-            val rec mp =
+            val rec tr =
+             fn Sql.Table t => pure''' (Sql.Table t)
+              | Sql.Join (jt, fi1, fi2, se) =>
+                lift2'''' (fn ((trfi1, trfi2), trse) => Sql.Join (jt, trfi1, trfi2, trse))
+                          (lift2''''' id (tr fi1, tr fi2), traverseSqexp ops f se)
+              | Sql.Nested (q, s) => lift''' (fn trq => Sql.Nested (trq, s))
+                                             (traverseQuery ops f q)
+        in
+            tr
+        end
+
+    and traverseQuery (ops as (_, pure', pure'', _, _, _, lift'', _, _, lift2', lift2'', lift2''', _, _)) f =
+        let
+            val rec seqList =
+             fn [] => pure'' []
+              | (x::xs) => lift2''' op:: (x, seqList xs)
+            val rec tr =
              fn Sql.Query1 q =>
-                (case #Where q of
-                     NONE => pure' (Sql.Query1 q)
-                   | SOME se =>
-                     lift' (fn mpse => Sql.Query1 {Select = #Select q,
-                                                   From = #From q,
-                                                   Where = SOME mpse})
-                           (traverseSqexp ops f se))
-              | Sql.Union (q1, q2) => lift2' Sql.Union (mp q1, mp q2)
+                (* TODO: make sure we don't need to traverse [#Select q]. *)
+                lift2' (fn (trfrom, trwher) => Sql.Query1 {Select = #Select q,
+                                                           From = trfrom,
+                                                           Where = trwher})
+                       (seqList (map (traverseFitem ops f) (#From q)),
+                        case #Where q of
+                            NONE => pure' NONE
+                          | SOME se => lift'' SOME (traverseSqexp ops f se))
+              | Sql.Union (q1, q2) => lift2'' Sql.Union (tr q1, tr q2)
         in
-            mp
+            tr
         end
 
     (* Include unused tuple elements in argument for convenience of using same
        argument as [traverseQuery]. *)
-    fun traverseIM (pure, _, _, _, _, lift2, _) f =
+    fun traverseIM (pure, _, _, _, _, _, _, _, _, lift2, _, _, _, _) f =
         IM.foldli (fn (k, v, acc) => lift2 (fn (acc, w) => IM.insert (acc, k, w)) (acc, f (k,v)))
                   (pure IM.empty)
 
-    fun traverseSubst (ops as (_, pure', lift, _, _, _, lift2')) f =
+    fun traverseSubst (ops as (_, pure', _, _, lift, _, _, _, _, lift2', _, _, _, _)) f =
         let
             fun mp ((n, fields), sqlify) =
                 lift (fn ((n', fields'), sqlify') =>
@@ -546,11 +569,14 @@
             traverseIM ops (fn (_, v) => mp v)
         end
 
-    fun monoidOps plus zero = (fn _ => zero, fn _ => zero,
-                               fn _ => fn x => x, fn _ => fn x => x, fn _ => fn x => x,
-                               fn _ => plus, fn _ => plus)
+    fun monoidOps plus zero =
+        (fn _ => zero, fn _ => zero, fn _ => zero, fn _ => zero,
+         fn _ => fn x => x, fn _ => fn x => x, fn _ => fn x => x, fn _ => fn x => x,
+         fn _ => plus, fn _ => plus, fn _ => plus, fn _ => plus, fn _ => plus, fn _ => plus)
 
-    val optionOps = (SOME, SOME, omap, omap, omap, omap2, omap2)
+    val optionOps = (SOME, SOME, SOME, SOME,
+                     omap, omap, omap, omap,
+                     omap2, omap2, omap2, omap2, omap2, omap2)
 
     fun foldMapQuery plus zero = traverseQuery (monoidOps plus zero)
     val omapQuery = traverseQuery optionOps
@@ -727,7 +753,7 @@
   | Sql.Null => raise Fail "Sqlcache: sqexpToFormula (Null)"
 
 fun mapSqexpFields f =
-    fn Sql.Field (t, v) => f (t, v)
+ fn Sql.Field (t, v) => f (t, v)
   | Sql.SqNot e => Sql.SqNot (mapSqexpFields f e)
   | Sql.Binop (r, e1, e2) => Sql.Binop (r, mapSqexpFields f e1, mapSqexpFields f e2)
   | Sql.SqKnown e => Sql.SqKnown (mapSqexpFields f e)
@@ -744,12 +770,102 @@
         mapSqexpFields (fn (t, f) => Sql.Field (rename t, f))
     end
 
-fun queryToFormula marker =
- fn Sql.Query1 {Select = sitems, From = tablePairs, Where = wher} =>
+structure FlattenQuery = struct
+
+    datatype substitution = RenameTable of string | SubstituteExp of Sql.sqexp SM.map
+
+    fun applySubst substTable =
+        let
+            fun substitute (table, field) =
+                case SM.find (substTable, table) of
+                    NONE => Sql.Field (table, field)
+                  | SOME (RenameTable realTable) => Sql.Field (realTable, field)
+                  | SOME (SubstituteExp substField) =>
+                    case SM.find (substField, field) of
+                        NONE => raise Fail "Sqlcache: applySubst"
+                      | SOME se => se
+        in
+            mapSqexpFields substitute
+        end
+
+    fun addToSubst (substTable, table, substField) =
+        SM.insert (substTable,
+                   table,
+                   case substField of
+                       RenameTable _ => substField
+                     | SubstituteExp subst => SubstituteExp (SM.map (applySubst substTable) subst))
+
+    fun newSubst (t, s) = addToSubst (SM.empty, t, s)
+
+    datatype sitem' = Named of Sql.sqexp * string | Unnamed of Sql.sqexp
+
+    type queryFlat = {Select : sitem' list, Where : Sql.sqexp}
+
+    val sitemsToSubst =
+        List.foldl (fn (Named (se, s), acc) => SM.insert (acc, s, se)
+                     | (Unnamed _, _) => raise Fail "Sqlcache: sitemsToSubst")
+                   SM.empty
+
+    fun unionSubst (s1, s2) = SM.unionWith (fn _ => raise Fail "Sqlcache: unionSubst") (s1, s2)
+
+    fun sqlAnd (se1, se2) = Sql.Binop (Sql.RLop Sql.And, se1, se2)
+
+    val rec flattenFitem : Sql.fitem -> (Sql.sqexp * substitution SM.map) list =
+     fn Sql.Table (real, alias) => [(Sql.SqTrue, newSubst (alias, RenameTable real))]
+      | Sql.Nested (q, s) =>
+        let
+            val qfs = flattenQuery q
+        in
+            map (fn (qf, subst) =>
+                    (#Where qf, addToSubst (subst, s, SubstituteExp (sitemsToSubst (#Select qf)))))
+                qfs
+        end
+      | Sql.Join (jt, fi1, fi2, se) =>
+        concatMap (fn ((wher1, subst1)) =>
+                      map (fn (wher2, subst2) =>
+                              (sqlAnd (wher1, wher2),
+                               (* There should be no name conflicts... Ziv hopes? *)
+                               unionSubst (subst1, subst2)))
+                          (flattenFitem fi2))
+                  (flattenFitem fi1)
+
+    and flattenQuery : Sql.query -> (queryFlat * substitution SM.map) list =
+     fn Sql.Query1 q =>
+        let
+            val fifss = cartesianProduct (map flattenFitem (#From q))
+        in
+            map (fn fifs =>
+                    let
+                        val subst = List.foldl (fn ((_, subst), acc) => unionSubst (acc, subst))
+                                               SM.empty
+                                               fifs
+                        val wher = List.foldr (fn ((wher, _), acc) => sqlAnd (wher, acc))
+                                              (case #Where q of
+                                                   NONE => Sql.SqTrue
+                                                 | SOME wher => wher)
+                                              fifs
+                    in
+                        (* ASK: do we actually need to pass the substitution through here? *)
+                        (* We use the substitution later, but it's not clear we
+                       need any of its currently present fields again. *)
+                        ({Select = map (fn Sql.SqExp (se, s) => Named (applySubst subst se, s)
+                                         | Sql.SqField tf =>
+                                           Unnamed (applySubst subst (Sql.Field tf)))
+                                       (#Select q),
+                          Where = applySubst subst wher},
+                         subst)
+                    end)
+                fifss
+        end
+      | Sql.Union (q1, q2) => (flattenQuery q1) @ (flattenQuery q2)
+
+end
+
+val flattenQuery = map #1 o FlattenQuery.flattenQuery
+
+fun queryFlatToFormula marker {Select = sitems, Where = wher} =
     let
-        val fWhere = case wher of
-                         NONE => Combo (Conj, [])
-                       | SOME e => sqexpToFormula (renameTables tablePairs e)
+        val fWhere = sqexpToFormula wher
     in
         case marker of
              NONE => fWhere
@@ -757,10 +873,10 @@
              let
                  val fWhereMarked = mapFormulaExps markFields fWhere
                  val toSqexp =
-                  fn Sql.SqField tf => Sql.Field tf
-                   | Sql.SqExp (se, _) => se
+                  fn FlattenQuery.Named (se, _) => se
+                   | FlattenQuery.Unnamed se => se
                  fun ineq se = Atom (Sql.Ne, se, markFields se)
-                 val fIneqs = Combo (Disj, map (ineq o renameTables tablePairs o toSqexp) sitems)
+                 val fIneqs = Combo (Disj, map (ineq o toSqexp) sitems)
              in
                  (Combo (Conj,
                          [fWhere,
@@ -769,7 +885,8 @@
                                   Combo (Conj, [fWhereMarked, fIneqs])])]))
              end
     end
-  | Sql.Union (q1, q2) => Combo (Disj, [queryToFormula marker q1, queryToFormula marker q2])
+
+fun queryToFormula marker q = Combo (Disj, map (queryFlatToFormula marker) (flattenQuery q))
 
 fun valsToFormula (markLeft, markRight) (table, vals) =
     Combo (Conj,
@@ -828,7 +945,7 @@
               (* If we don't know one side of the comparision, not a contradiction. *)
               | _ => false
         in
-            not (List.exists contradiction atoms) <\oguard\> SOME (UF.classes uf)
+            not (List.exists contradiction atoms) <\oguard\> (fn _ => SOME (UF.classes uf))
         end
 
     fun addToEqs (eqs, n, e) =
@@ -906,10 +1023,11 @@
         mapFormula (toAtomExps DmlRel)
 
     (* No eqs should have key conflicts because no variable is in two
-       equivalence classes, so the [#1] could be [#2]. *)
+       equivalence classes. *)
     val mergeEqs : (atomExp IntBinaryMap.map option list
                     -> atomExp IntBinaryMap.map option) =
-        List.foldr (omap2 (IM.unionWith #1)) (SOME IM.empty)
+        List.foldr (omap2 (IM.unionWith (fn _ => raise Fail "Sqlcache: ConflictMaps.mergeEqs")))
+                   (SOME IM.empty)
 
     val simplify =
         map TS.listItems
@@ -1008,12 +1126,16 @@
 fun fileMap doExp file = #1 (fileAllMapfoldB (fn _ => fn e => fn _ => (doExp e, ())) file ())
 
 (* TODO: make this a bit prettier.... *)
+(* TODO: factour out identical subexpressions to the same variable.... *)
 val simplifySql =
     let
         fun factorOutNontrivial text =
             let
                 val loc = dummyLoc
-                fun strcat (e1, e2) = (EStrcat (e1, e2), loc)
+                val strcat =
+                 fn (e1, (EPrim (Prim.String (Prim.Normal, "")), _)) => e1
+                  | ((EPrim (Prim.String (Prim.Normal, "")), _), e2) => e2
+                  | (e1, e2) => (EStrcat (e1, e2), loc)
                 val chunks = Sql.chunkify text
                 val (newText, newVariables) =
                     (* Important that this is foldr (to oppose foldl below). *)
@@ -1193,7 +1315,7 @@
     end
 
 fun cacheExp (env, exp', invalInfo, state : state) =
-    case worthCaching exp' <\oguard\> typOfExp' env exp' of
+    case worthCaching exp' <\oguard\> (fn _ => typOfExp' env exp') of
         NONE => NONE
       | SOME (TFun _, _) => NONE
       | SOME typ =>
@@ -1202,26 +1324,28 @@
         in
             shouldConsolidate args
             <\oguard\>
-             List.foldr (fn (arg, acc) =>
-                            acc
-                            <\obind\>
-                             (fn args' =>
-                                 (case arg of
-                                      AsIs exp => SOME exp
-                                    | Urlify exp =>
-                                      typOfExp env exp
-                                      <\obind\>
-                                       (fn typ => (MonoFooify.urlify env (exp, typ))))
-                                 <\obind\>
-                                  (fn arg' => SOME (arg' :: args'))))
-                        (SOME [])
-                        args
-            <\obind\>
-             (fn args' =>
-                 cacheWrap (env, (exp', dummyLoc), typ, args', #index state)
+             (fn _ =>
+                 List.foldr (fn (arg, acc) =>
+                                acc
+                                <\obind\>
+                                 (fn args' =>
+                                     (case arg of
+                                          AsIs exp => SOME exp
+                                        | Urlify exp =>
+                                          typOfExp env exp
+                                          <\obind\>
+                                           (fn typ => (MonoFooify.urlify env (exp, typ))))
+                                     <\obind\>
+                                      (fn arg' => SOME (arg' :: args'))))
+                            (SOME [])
+                            args
                  <\obind\>
-                  (fn cachedExp =>
-                      SOME (cachedExp, InvalInfo.updateState (invalInfo, length args', state))))
+                  (fn args' =>
+                      cacheWrap (env, (exp', dummyLoc), typ, args', #index state)
+                      <\obind\>
+                       (fn cachedExp =>
+                           SOME (cachedExp,
+                                 InvalInfo.updateState (invalInfo, length args', state)))))
         end
 
 fun cacheQuery (effs, env, q) : subexp =
@@ -1238,20 +1362,22 @@
         val {query = queryText, initial, body, ...} = q
         val attempt =
             (* Ziv misses Haskell's do notation.... *)
-            (safe 0 queryText andalso safe 0 initial andalso safe 2 body)
+            (safe 0 (printExp "attempt" queryText) andalso safe 0 initial andalso safe 2 body)
             <\oguard\>
-            Sql.parse Sql.query queryText
-            <\obind\>
-            (fn queryParsed =>
-                let
-                    val invalInfo = InvalInfo.singleton queryParsed
-                    fun mkExp state =
-                        case cacheExp (env, EQuery q, invalInfo, state) of
-                            NONE => ((EQuery q, dummyLoc), state)
-                          | SOME (cachedExp, state) => ((cachedExp, dummyLoc), state)
-                in
-                    SOME (Cachable (invalInfo, mkExp))
-                end)
+             (fn _ =>
+                 Sql.parse Sql.query (printExp "safe" queryText)
+                 <\obind\>
+                  (fn queryParsed =>
+                      let
+                          val _ = (printExp "parsed" queryText)
+                          val invalInfo = InvalInfo.singleton queryParsed
+                          fun mkExp state =
+                              case cacheExp (env, EQuery q, invalInfo, state) of
+                                  NONE => ((EQuery q, dummyLoc), state)
+                                | SOME (cachedExp, state) => ((cachedExp, dummyLoc), state)
+                      in
+                          SOME (Cachable (invalInfo, mkExp))
+                      end))
     in
         case attempt of
             NONE => Impure (EQuery q, dummyLoc)
@@ -1279,16 +1405,16 @@
                                               InvalInfo.unbind (invalInfoOfSubexp subexp, unbinds))
                                           (subexps, args)))
                              <\obind\>
-                             (fn invalInfo =>
-                                 SOME (Cachable (invalInfo,
-                                                 fn state =>
-                                                    case cacheExp (env,
-                                                                   f (map (#2 o #1) args),
-                                                                   invalInfo,
-                                                                   state) of
-                                                        NONE => mkExp state
-                                                      | SOME (e', state) => ((e', loc), state)),
-                                       state))
+                              (fn invalInfo =>
+                                  SOME (Cachable (invalInfo,
+                                                  fn state =>
+                                                     case cacheExp (env,
+                                                                    f (map (#2 o #1) args),
+                                                                    invalInfo,
+                                                                    state) of
+                                                         NONE => mkExp state
+                                                       | SOME (e', state) => ((e', loc), state)),
+                                        state))
             in
                 case attempt of
                     SOME (subexp, state) => (subexp, state)
@@ -1384,7 +1510,7 @@
                                DmlRel n => ERel n
                              | Prim p => EPrim p
                              (* TODO: make new type containing only these two. *)
-                             | _ => raise Fail "Sqlcache: optionAtomExpToExp",
+                             | _ => raise Fail "Sqlcache: Invalidations.optionAtomExpToExp",
                            loc)),
                    loc)
 
@@ -1506,8 +1632,8 @@
                 ListMergeSort.sort (fn ((i, _), (j, _)) => i > j) ls
             end
         fun locksOfName n =
-            lockList {store = IIMM.findSet (#flush lockMap, n),
-                      flush =IIMM.findSet (#store lockMap, n)}
+            lockList {flush = IIMM.findSet (#flush lockMap, n),
+                      store = IIMM.findSet (#store lockMap, n)}
         val locksOfExp = lockList o locksNeeded lockMap
         val expts = exports file
         fun doVal (v as (x, n, t, exp, s)) =