diff src/sqlcache.sml @ 2268:bc1ef958d801

Thread state through addCaching more carefully.
author Ziv Scully <ziv@mit.edu>
date Wed, 14 Oct 2015 23:10:10 -0400
parents e5b7b066bf1b
children f7bc7c11a656
line wrap: on
line diff
--- a/src/sqlcache.sml	Wed Oct 14 20:40:57 2015 -0400
+++ b/src/sqlcache.sml	Wed Oct 14 23:10:10 2015 -0400
@@ -15,12 +15,12 @@
                     then x
                     else iterate f (n-1) (f x)
 
-(* Filled in by [cacheWrap]. *)
-val ffiInfo : {index : int, params : int} list ref = ref []
+(* Filled in by [addFlushing]. *)
+val ffiInfoRef : {index : int, params : int} list ref = ref []
 
-fun resetFfiInfo () = ffiInfo := []
+fun resetFfiInfo () = ffiInfoRef := []
 
-fun getFfiInfo () = !ffiInfo
+fun getFfiInfo () = !ffiInfoRef
 
 (* Some FFIs have writing as their only effect, which the caching records. *)
 val ffiEffectful =
@@ -61,8 +61,6 @@
 (***********************)
 
 (* From the MLton wiki. *)
-infix  3 <\     fun x <\ f = fn y => f (x, y)     (* Left section      *)
-infix  3 \>     fun f \> y = f y                  (* Left application  *)
 infixr 3 />     fun f /> y = fn x => f (x, y)     (* Right section     *)
 infixr 3 </     fun x </ f = f x                  (* Right application *)
 
@@ -70,6 +68,9 @@
 fun obind (x, f) = Option.mapPartial f x
 fun oguard (b, x) = if b then x else NONE
 
+fun mapFst f (x, y) = (f x, y)
+
+
 (*******************)
 (* Effect Analysis *)
 (*******************)
@@ -699,7 +700,7 @@
 
 *)
 
-fun cacheWrap (env, exp, resultTyp, args, i) =
+fun cacheWrap (env, exp, resultTyp, args, state as (_, _, ffiInfo, index)) =
     let
         val loc = dummyLoc
         val rel0 = (ERel 0, loc)
@@ -708,21 +709,24 @@
             NONE => NONE
           | SOME urlified =>
             let
-                val () = ffiInfo := {index = i, params = length args} :: !ffiInfo
                 (* We ensure before this step that all arguments aren't effectful.
                    by turning them into local variables as needed. *)
                 val argsInc = map (incRels 1) args
-                val check = (check (i, args), loc)
-                val store = (store (i, argsInc, urlified), loc)
+                val check = (check (index, args), loc)
+                val store = (store (index, argsInc, urlified), loc)
             in
-                SOME (ECase
-                          (check,
-                           [((PNone stringTyp, loc),
-                             (ELet (varName "q", resultTyp, exp, (ESeq (store, rel0), loc)), loc)),
-                            ((PSome (stringTyp, (PVar (varName "hit", stringTyp), loc)), loc),
-                             (* Boolean is false because we're not unurlifying from a cookie. *)
-                             (EUnurlify (rel0, resultTyp, false), loc))],
-                           {disc = (TOption stringTyp, loc), result = resultTyp}))
+                SOME ((ECase
+                           (check,
+                            [((PNone stringTyp, loc),
+                              (ELet (varName "q", resultTyp, exp, (ESeq (store, rel0), loc)), loc)),
+                             ((PSome (stringTyp, (PVar (varName "hit", stringTyp), loc)), loc),
+                              (* Boolean is false because we're not unurlifying from a cookie. *)
+                              (EUnurlify (rel0, resultTyp, false), loc))],
+                            {disc = (TOption stringTyp, loc), result = resultTyp})),
+                      (#1 state,
+                       #2 state,
+                       {index = index, params = length args} :: ffiInfo,
+                       index + 1))
             end
     end
 
@@ -748,28 +752,30 @@
 
 val expSize = MonoUtil.Exp.fold {typ = #2, exp = fn (_, n) => n+1} 0
 
-datatype subexp = Cachable of unit -> exp | Impure of exp
+type state = (SIMM.multimap
+              * (Sql.query * int) IntBinaryMap.map
+              * {index : int, params : int} list
+              * int)
+
+datatype subexp = Cachable of state -> (exp * state) | Impure of exp
 
 val isImpure =
  fn Cachable _ => false
   | Impure _ => true
 
-val expOfSubexp =
- fn Cachable f => f ()
-  | Impure e => e
+val runSubexp : subexp * state -> exp * state =
+ fn (Cachable f, state) => f state
+  | (Impure e, state) => (e, state)
 
 (* TODO: pick a number. *)
 val sizeWorthCaching = 5
 
-type state = (SIMM.multimap * (Sql.query * int) IntBinaryMap.map * int)
-
-fun incIndex (x, y, index) = (x, y, index+1)
-
-fun cacheQuery effs env (state as (tableToIndices, indexToQueryNumArgs, index)) =
- fn q as {query = queryText,
-          state = resultTyp,
-          initial, body, tables, exps} =>
+fun cacheQuery (effs, env, state, q) : (exp' * state) =
     let
+        val (tableToIndices, indexToQueryNumArgs, ffiInfo, index) = state
+        val {query = queryText,
+             state = resultTyp,
+             initial, body, tables, exps} = q
         val numArgs = maxFreeVar queryText + 1
         val queryExp = (EQuery q, dummyLoc)
         (* DEBUG *)
@@ -787,29 +793,27 @@
         val attempt =
             (* Ziv misses Haskell's do notation.... *)
             (safe 0 queryText andalso safe 0 initial andalso safe 2 body)
-            <\oguard\>
-             Sql.parse Sql.query queryText
-            <\obind\>
-             (fn queryParsed =>
-                 (cacheWrap (env, queryExp, resultTyp, args, index))
-                 <\obind\>
-                  (fn cachedExp =>
-                      SOME (cachedExp,
-                            (SS.foldr (fn (tab, qi) => SIMM.insert (qi, tab, index))
-                                      tableToIndices
-                                      (tablesQuery queryParsed),
-                             IM.insert (indexToQueryNumArgs, index, (queryParsed, numArgs)),
-                             index + 1))))
+            </oguard/>
+            Sql.parse Sql.query queryText
+            </obind/>
+            (fn queryParsed =>
+                (cacheWrap (env, queryExp, resultTyp, args, state))
+                    </obind/>
+                    (fn (cachedExp, state) =>
+                        SOME (cachedExp,
+                              (SS.foldr (fn (tab, qi) => SIMM.insert (qi, tab, index))
+                                        tableToIndices
+                                        (tablesQuery queryParsed),
+                               IM.insert (indexToQueryNumArgs, index, (queryParsed, numArgs)),
+                               #3 state,
+                               #4 state))))
     in
         case attempt of
             SOME pair => pair
-          (* Even in this case, we have to increment index to avoid some bug,
-             but I forget exactly what it is or why this helps. *)
-          (* TODO: just use a reference for current index.... *)
-          | NONE => (EQuery q, incIndex state)
+          | NONE => (EQuery q, state)
     end
 
-fun cachePure (env, exp', (_, _, index)) =
+fun cachePure (env, exp', state as (_, _, _, index)) =
     case (expSize (exp', dummyLoc) > sizeWorthCaching)
              </oguard/>
              typOfExp' env exp' of
@@ -825,22 +829,23 @@
                     (map (fn n => (n, #2 (MonoEnv.lookupERel env n)))
                          (freeVars (exp', dummyLoc))))
             </obind/>
-            (fn args => cacheWrap (env, (exp', dummyLoc), typ, args, index))
+            (fn args => cacheWrap (env, (exp', dummyLoc), typ, args, state))
 
-fun cache (effs : IS.set) ((env, exp as (exp', loc)), state) : subexp * state =
+fun cache (effs : IS.set) ((env, exp as (exp', loc)), state) =
     let
-        fun wrapBindN f (args : (MonoEnv.env * exp) list) =
+        fun wrapBindN (f : exp list -> exp') (args : (MonoEnv.env * exp) list) =
             let
                 val (subexps, state) = ListUtil.foldlMap (cache effs) state args
-                fun mkExp () = (f (map expOfSubexp subexps), loc)
+                fun mkExp state = mapFst (fn exps => (f exps, loc))
+                                         (ListUtil.foldlMap runSubexp state subexps)
             in
                 if List.exists isImpure subexps
-                then (Impure (mkExp ()), state)
-                else (Cachable (fn () => case cachePure (env, f (map #2 args), state) of
-                                             NONE => mkExp ()
-                                           | SOME e' => (e', loc)),
-                      (* Conservatively increment index. *)
-                      incIndex state)
+                then mapFst Impure (mkExp state)
+                else (Cachable (fn state =>
+                                   case cachePure (env, f (map #2 args), state) of
+                                       NONE => mkExp state
+                                     | SOME (e', state) => ((e', loc), state)),
+                      state)
             end
         fun wrapBind1 f arg =
             wrapBindN (fn [arg] => f arg | _ => raise Match) [arg]
@@ -887,30 +892,25 @@
           | EUnurlify (e, t, b) => wrap1 (fn e => EUnurlify (e, t, b)) e
           | EQuery q =>
             let
-                val (exp', state) = cacheQuery effs env state q
+                val (exp', state) = cacheQuery (effs, env, state, q)
             in
                 (Impure (exp', loc), state)
             end
           | _ => if effectful effs env exp
                  then (Impure exp, state)
-                 else (Cachable (fn () => (case cachePure (env, exp', state) of
-                                               NONE => exp'
-                                             | SOME e' => e',
-                                           loc)),
-                       incIndex state)
+                 else (Cachable (fn state =>
+                                    case cachePure (env, exp', state) of
+                                         NONE => ((exp', loc), state)
+                                       | SOME (exp', state) => ((exp', loc), state)),
+                       state)
     end
 
 fun addCaching file =
     let
         val effs = effectfulDecls file
-        fun doTopLevelExp env exp state =
-            let
-                val (subexp, state) = cache effs ((env, exp), state)
-            in
-                (expOfSubexp subexp, state)
-            end
+        fun doTopLevelExp env exp state = runSubexp (cache effs ((env, exp), state))
     in
-        ((fileTopLevelMapfoldB doTopLevelExp file (SIMM.empty, IM.empty, 0)), effs)
+        ((fileTopLevelMapfoldB doTopLevelExp file (SIMM.empty, IM.empty, [], 0)), effs)
     end
 
 
@@ -967,7 +967,7 @@
 (* val gunk : ((Sql.query * int) * Sql.dml) list ref = ref [] *)
 (* val gunk' : exp list ref = ref [] *)
 
-fun addFlushing ((file, (tableToIndices, indexToQueryNumArgs, index)), effs) =
+fun addFlushing ((file, (tableToIndices, indexToQueryNumArgs, ffiInfo, index)), effs) =
     let
         val flushes = List.concat
                       o map (fn (i, argss) => map (fn args => flush (i, args)) argss)
@@ -999,13 +999,14 @@
     in
         (* DEBUG *)
         (* gunk := []; *)
+        ffiInfoRef := ffiInfo;
         fileMap doExp file
     end
 
 
-(***************)
-(* Entry point *)
-(***************)
+(************************)
+(* Compiler Entry Point *)
+(************************)
 
 val inlineSql =
     let