diff src/sqlcache.sml @ 2267:e5b7b066bf1b

Factor out SQL simplification.
author Ziv Scully <ziv@mit.edu>
date Wed, 14 Oct 2015 20:40:57 -0400
parents afd12c75e0d6
children bc1ef958d801
line wrap: on
line diff
--- a/src/sqlcache.sml	Wed Oct 14 15:45:04 2015 -0400
+++ b/src/sqlcache.sml	Wed Oct 14 20:40:57 2015 -0400
@@ -555,47 +555,71 @@
 
 fun fileMap doExp file = #1 (fileAllMapfoldB (fn _ => fn e => fn _ => (doExp e, ())) file ())
 
-(* Takes a text expression and returns
-     newText: a new expression with any subexpressions that do computation
-         replaced with variables,
-     wrapLets: a function that wraps its argument expression with lets binding
-         those variables to their corresponding computations, and
-     numArgs: the number of such bindings.
-   The De Bruijn indices work out for [wrapLets (incRels numArgs newText)], but
-   the intention is that newText might be augmented. *)
-fun factorOutNontrivial text =
+(* Let-binds the nontrivial subexpressions of each EQuery/EDml SQL text. TODO: make this a bit prettier.... *)
+val simplifySql =
     let
-        val loc = dummyLoc
-        fun strcat (e1, e2) = (EStrcat (e1, e2), loc)
-        val chunks = Sql.chunkify text
-        val (newText, newVariables) =
-            (* Important that this is foldr (to oppose foldl below). *)
-            List.foldr
-                (fn (chunk, (qText, newVars)) =>
-                    (* Variable bound to the head of newVars will have the lowest index. *)
-                    case chunk of
-                        (* EPrim should always be a string in this case. *)
-                        Sql.Exp (e as (EPrim _, _)) => (strcat (e, qText), newVars)
-                      | Sql.Exp e =>
-                        let
-                            val n = length newVars
-                        in
-                            (* This is the (n+1)th new variable, so there are
-                               already n new variables bound, so we increment
-                               indices by n. *)
-                            (strcat ((ERel (~(n+1)), loc), qText), incRels n e :: newVars)
-                        end
-                      | Sql.String s => (strcat (stringExp s, qText), newVars))
-                (stringExp "", [])
-                chunks
-        fun wrapLets e' =
-            (* Important that this is foldl (to oppose foldr above). *)
-            List.foldl (fn (v, e') => ELet (varName "sqlArg", stringTyp, v, (e', loc)))
-                       e'
-                       newVariables
-        val numArgs = length newVariables
+        fun factorOutNontrivial text =
+            let
+                val loc = dummyLoc
+                fun strcat (e1, e2) = (EStrcat (e1, e2), loc)
+                val chunks = Sql.chunkify text
+                val (newText, newVariables) =
+                    (* Important that this is foldr (to oppose foldl below). *)
+                    List.foldr
+                        (fn (chunk, (qText, newVars)) =>
+                            (* Variable bound to the head of newVars will have the lowest index. *)
+                            case chunk of
+                                (* EPrim should always be a string in this case. *)
+                                Sql.Exp (e as (EPrim _, _)) => (strcat (e, qText), newVars)
+                              | Sql.Exp e =>
+                                let
+                                    val n = length newVars
+                                in
+                                    (* This is the (n+1)th new variable, so there are
+                                       already n new variables bound, so we increment
+                                       indices by n. *)
+                                    (strcat ((ERel (~(n+1)), loc), qText), incRels n e :: newVars)
+                                end
+                              | Sql.String s => (strcat (stringExp s, qText), newVars))
+                        (stringExp "", [])
+                        chunks
+                fun wrapLets e' =
+                    (* Important that this is foldl (to oppose foldr above). *)
+                    List.foldl (fn (v, e') => ELet (varName "sqlArg", stringTyp, v, (e', loc)))
+                               e'
+                               newVariables
+                val numArgs = length newVariables
+            in
+                (newText, wrapLets, numArgs)
+            end
+        fun doExp exp' =
+            let
+                val text = case exp' of
+                               EQuery {query = text, ...} => text
+                             | EDml (text, _) => text
+                             | _ => raise Match
+                val (newText, wrapLets, numArgs) = factorOutNontrivial text
+                val newExp' = case exp' of
+                                 EQuery q => EQuery {query = newText,
+                                                     exps = #exps q,
+                                                     tables = #tables q,
+                                                     state = #state q,
+                                                     body = #body q,
+                                                     initial = #initial q}
+                               | EDml (_, failureMode) => EDml (newText, failureMode)
+                               | _ => raise Match
+            in
+                (* Shift variable indices by numArgs, once for each let binding
+                   that wrapLets adds. factorOutNontrivial refers to those bindings
+                   through negative placeholder indices (ERel (~(n+1))); this shift
+                   is what is meant to make them valid. TODO: replace that hack. *)
+                wrapLets (#1 (incRels numArgs (newExp', dummyLoc)))
+            end
     in
-        (newText, wrapLets, numArgs)
+        fileMap (fn exp' => case exp' of
+                                EQuery _ => doExp exp'
+                              | EDml _ => doExp exp'
+                              | _ => exp')
     end
 
 
@@ -659,6 +683,22 @@
 (* Caching *)
 (***********)
 
+(*
+
+To get the invalidations for a dml, we need (each <- is list-monad-y):
+  * table <- dml
+  * cache <- table
+  * query <- cache
+  * inval <- (query, dml),
+where inval is a list of query argument indices, so we also need
+  * a way to translate the query args in inval into cache args.
+For now, that translation is just
+  * a map from query arg number to the corresponding free variable (per query)
+  * a map from free variable to cache arg number (per cache).
+Both queries and caches should have IDs.
+
+*)
+
 fun cacheWrap (env, exp, resultTyp, args, i) =
     let
         val loc = dummyLoc
@@ -686,6 +726,14 @@
             end
     end
 
+val maxFreeVar = (* Largest index of an ERel that is free in an expression; used by cacheQuery as numArgs - 1. *)
+    MonoUtil.Exp.foldB
+        {typ = #2,
+         exp = fn (bound, ERel n, v) => Int.max (v, n - bound) | (_, _, v) => v, (* n - bound: the index as seen from outside the expression *)
+         bind = fn (bound, MonoUtil.Exp.RelE _) => bound + 1 | (bound, _) => bound} (* count enclosing RelE binders *)
+        0 (* presumably the initial binder count -- confirm against foldB's argument order *)
+        ~1 (* presumably the initial maximum, so an expression with no free ERel gives ~1 -- confirm *)
+
 val freeVars =
     IS.listItems
     o MonoUtil.Exp.foldB
@@ -700,14 +748,14 @@
 
 val expSize = MonoUtil.Exp.fold {typ = #2, exp = fn (_, n) => n+1} 0
 
-datatype subexp = Pure of unit -> exp | Impure of exp
+datatype subexp = Cachable of unit -> exp | Impure of exp
 
 val isImpure =
- fn Pure _ => false
+ fn Cachable _ => false
   | Impure _ => true
 
 val expOfSubexp =
- fn Pure f => f ()
+ fn Cachable f => f ()
   | Impure e => e
 
 (* TODO: pick a number. *)
@@ -718,23 +766,12 @@
 fun incIndex (x, y, index) = (x, y, index+1)
 
 fun cacheQuery effs env (state as (tableToIndices, indexToQueryNumArgs, index)) =
-    fn q as {query = origQueryText,
-             state = resultTyp,
-             initial, body, tables, exps} =>
+ fn q as {query = queryText,
+          state = resultTyp,
+          initial, body, tables, exps} =>
     let
-        val (newQueryText, wrapLets, numArgs) = factorOutNontrivial origQueryText
-        (* Increment once for each new variable just made. This is where we
-           use the negative De Bruijn indices hack. *)
-        (* TODO: please don't use that hack. As anyone could have predicted, it
-           was incomprehensible a year later.... *)
-        val queryExp = incRels numArgs
-                               (EQuery {query = newQueryText,
-                                        state = resultTyp,
-                                        initial = initial,
-                                        body = body,
-                                        tables = tables,
-                                        exps = exps},
-                                dummyLoc)
+        val numArgs = maxFreeVar queryText + 1
+        val queryExp = (EQuery q, dummyLoc)
         (* DEBUG *)
         (* val () = Print.preface ("sqlcache> ", MonoPrint.p_exp MonoEnv.empty queryText) *)
         val args = List.tabulate (numArgs, fn n => (ERel n, dummyLoc))
@@ -747,26 +784,22 @@
                         (iterate (fn env => MonoEnv.pushERel env "_" dummyTyp NONE)
                                  bound
                                  env)
-        val textOfQuery = fn (EQuery {query, ...}, _) => SOME query | _ => NONE
         val attempt =
             (* Ziv misses Haskell's do notation.... *)
-            textOfQuery queryExp
+            (safe 0 queryText andalso safe 0 initial andalso safe 2 body)
+            <\oguard\>
+             Sql.parse Sql.query queryText
             <\obind\>
-             (fn queryText =>
-                 (safe 0 queryText andalso safe 0 initial andalso safe 2 body)
-                 <\oguard\>
-                  Sql.parse Sql.query queryText
+             (fn queryParsed =>
+                 (cacheWrap (env, queryExp, resultTyp, args, index))
                  <\obind\>
-                  (fn queryParsed =>
-                      (cacheWrap (env, queryExp, resultTyp, args, index))
-                      <\obind\>
-                       (fn cachedExp =>
-                           SOME (wrapLets cachedExp,
-                                 (SS.foldr (fn (tab, qi) => SIMM.insert (qi, tab, index))
-                                           tableToIndices
-                                           (tablesQuery queryParsed),
-                                  IM.insert (indexToQueryNumArgs, index, (queryParsed, numArgs)),
-                                  index + 1)))))
+                  (fn cachedExp =>
+                      SOME (cachedExp,
+                            (SS.foldr (fn (tab, qi) => SIMM.insert (qi, tab, index))
+                                      tableToIndices
+                                      (tablesQuery queryParsed),
+                             IM.insert (indexToQueryNumArgs, index, (queryParsed, numArgs)),
+                             index + 1))))
     in
         case attempt of
             SOME pair => pair
@@ -777,20 +810,20 @@
     end
 
 fun cachePure (env, exp', (_, _, index)) =
-    case typOfExp' env exp' of
+    case (expSize (exp', dummyLoc) > sizeWorthCaching)
+             </oguard/>
+             typOfExp' env exp' of
         NONE => NONE
       | SOME (TFun _, _) => NONE
       | SOME typ =>
-        (expSize (exp', dummyLoc) < sizeWorthCaching)
-            </oguard/>
-            (List.foldr (fn (_, NONE) => NONE
-                          | ((n, typ), SOME args) =>
-                            (MonoFooify.urlify env ((ERel n, dummyLoc), typ))
-                                </obind/>
-                                (fn arg => SOME (arg :: args)))
-                        (SOME [])
-                        (map (fn n => (n, #2 (MonoEnv.lookupERel env n)))
-                             (freeVars (exp', dummyLoc))))
+        (List.foldr (fn (_, NONE) => NONE
+                      | ((n, typ), SOME args) =>
+                        (MonoFooify.urlify env ((ERel n, dummyLoc), typ))
+                            </obind/>
+                            (fn arg => SOME (arg :: args)))
+                    (SOME [])
+                    (map (fn n => (n, #2 (MonoEnv.lookupERel env n)))
+                         (freeVars (exp', dummyLoc))))
             </obind/>
             (fn args => cacheWrap (env, (exp', dummyLoc), typ, args, index))
 
@@ -803,9 +836,9 @@
             in
                 if List.exists isImpure subexps
                 then (Impure (mkExp ()), state)
-                else (Pure (fn () => case cachePure (env, f (map #2 args), state) of
-                                         NONE => mkExp ()
-                                       | SOME e' => (e', loc)),
+                else (Cachable (fn () => case cachePure (env, f (map #2 args), state) of
+                                             NONE => mkExp ()
+                                           | SOME e' => (e', loc)),
                       (* Conservatively increment index. *)
                       incIndex state)
             end
@@ -860,10 +893,10 @@
             end
           | _ => if effectful effs env exp
                  then (Impure exp, state)
-                 else (Pure (fn () => (case cachePure (env, exp', state) of
-                                           NONE => exp'
-                                         | SOME e' => e',
-                                       loc)),
+                 else (Cachable (fn () => (case cachePure (env, exp', state) of
+                                               NONE => exp'
+                                             | SOME e' => e',
+                                           loc)),
                        incIndex state)
     end
 
@@ -939,14 +972,10 @@
         val flushes = List.concat
                       o map (fn (i, argss) => map (fn args => flush (i, args)) argss)
         val doExp =
-         fn EDml (origDmlText, failureMode) =>
+         fn dmlExp as EDml (dmlText, failureMode) =>
             let
                 (* DEBUG *)
                 (* val () = gunk' := origDmlText :: !gunk' *)
-                val (newDmlText, wrapLets, numArgs) = factorOutNontrivial origDmlText
-                val dmlText = incRels numArgs newDmlText
-                val dmlExp = EDml (dmlText, failureMode)
-                (* DEBUG *)
                 (* val () = Print.preface ("SQLCACHE: ", (MonoPrint.p_exp MonoEnv.empty origDmlText)) *)
                 val inval =
                     case Sql.parse Sql.dml dmlText of
@@ -964,7 +993,7 @@
                 case inval of
                     (* TODO: fail more gracefully. *)
                     NONE => raise Match
-                  | SOME invs => wrapLets (sequence (flushes invs @ [dmlExp]))
+                  | SOME invs => sequence (flushes invs @ [dmlExp])
             end
           | e' => e'
     in
@@ -1001,7 +1030,7 @@
         (datatypes @ newDecls @ others, sideInfo)
     end
 
-val go' = addFlushing o addCaching o inlineSql
+val go' = addFlushing o addCaching o simplifySql o inlineSql
 
 fun go file =
     let