changeset 1077:a3273bee05a9

Initial generalization of Especialize, with security bug known
author Adam Chlipala <adamc@hcoop.net>
date Tue, 15 Dec 2009 12:26:00 -0500
parents dcf98ae3c48d
children b9321bcefb42
files CHANGELOG src/especialize.sml tests/espec.ur tests/espec.urp tests/espec.urs
diffstat 5 files changed, 97 insertions(+), 26 deletions(-) [+]
line wrap: on
line diff
--- a/CHANGELOG	Tue Dec 15 11:11:49 2009 -0500
+++ b/CHANGELOG	Tue Dec 15 12:26:00 2009 -0500
@@ -7,6 +7,7 @@
 - Typing of SQL queries no longer exposes which tables were used in joins but
   had none of their fields projected
 - Tasks
+- Optimization improvements
 
 ========
 20091203
--- a/src/especialize.sml	Tue Dec 15 11:11:49 2009 -0500
+++ b/src/especialize.sml	Tue Dec 15 12:26:00 2009 -0500
@@ -79,14 +79,14 @@
         pof (0, ls)
     end
 
-fun squish fvs =
+fun squish (untouched, fvs) =
     U.Exp.mapB {kind = fn _ => fn k => k,
                 con = fn _ => fn c => c,
                 exp = fn bound => fn e =>
                                      case e of
                                          ERel x =>
                                          if x >= bound then
-                                             ERel (positionOf (x - bound, fvs) + bound)
+                                             ERel (positionOf (x - bound, fvs) + bound + untouched)
                                          else
                                              e
                                        | _ => e,
@@ -165,31 +165,29 @@
                                                                       | _ => false}
                             val loc = ErrorMsg.dummySpan
 
-                            fun findSplit (xs, typ, fxs, fvs) =
+                            fun findSplit (xs, typ, fxs, fvs, ts) =
                                 case (#1 typ, xs) of
                                     (TFun (dom, ran), e :: xs') =>
                                     if functionInside dom then
                                         findSplit (xs',
                                                    ran,
-                                                   e :: fxs,
-                                                   IS.union (fvs, freeVars e))
+                                                   (true, e) :: fxs,
+                                                   IS.union (fvs, freeVars e),
+                                                   ts)
                                     else
-                                        (rev fxs, xs, fvs)
-                                  | _ => (rev fxs, xs, fvs)
+                                        findSplit (xs', ran, (false, e) :: fxs, fvs, dom :: ts)
+                                  | _ => (List.revAppend (fxs, map (fn e => (false, e)) xs), fvs, rev ts)
 
-                            val (fxs, xs, fvs) = findSplit (xs, typ, [], IS.empty)
-
-                            val fxs' = map (squish (IS.listItems fvs)) fxs
-
-                            fun firstRel () =
-                                case fxs' of
-                                    (ERel _, _) :: _ => true
-                                  | _ => false
+                            val (xs, fvs, ts) = findSplit (xs, typ, [], IS.empty, [])
+                            val fxs = List.mapPartial (fn (true, e) => SOME e | _ => NONE) xs
+                            val untouched = length (List.filter (fn (false, _) => true | _ => false) xs)
+                            val squish = squish (untouched, IS.listItems fvs)
+                            val fxs' = map squish fxs
                         in
                             (*Print.preface ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs');*)
-                            if firstRel ()
-                               orelse List.all (fn (ERel _, _) => true
-                                                 | _ => false) fxs' then
+                            if List.all (fn (false, _) => true
+                                          | (true, (ERel _, _)) => true
+                                          | _ => false) xs then
                                 (e, st)
                             else
                                 case (KM.find (args, fxs'), SS.member (!mayNotSpec, name)) of
@@ -198,7 +196,8 @@
                                         val e = (ENamed f', loc)
                                         val e = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
                                                          e fvs
-                                        val e = foldl (fn (arg, e) => (EApp (e, arg), loc))
+                                        val e = foldl (fn ((false, arg), e) => (EApp (e, arg), loc)
+                                                        | (_, e) => e)
                                                       e xs
                                     in
                                         (*Print.prefaces "Brand new (reuse)"
@@ -220,20 +219,24 @@
                                                                 [("fxs'",
                                                                   Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*)
 
-                                        fun subBody (body, typ, fxs') =
-                                            case (#1 body, #1 typ, fxs') of
+                                        fun subBody (body, typ, xs) =
+                                            case (#1 body, #1 typ, xs) of
                                                 (_, _, []) => SOME (body, typ)
-                                              | (EAbs (_, _, _, body'), TFun (_, typ'), x :: fxs'') =>
+                                              | (EAbs (_, _, _, body'), TFun (_, typ'), (b, x) :: xs) =>
                                                 let
-                                                    val body'' = E.subExpInExp (0, x) body'
+                                                    val body'' =
+                                                        if b then
+                                                            E.subExpInExp (0, squish x) body'
+                                                        else
+                                                            body'
                                                 in
                                                     subBody (body'',
                                                              typ',
-                                                             fxs'')
+                                                             xs)
                                                 end
                                               | _ => NONE
                                     in
-                                        case subBody (body, typ, fxs') of
+                                        case subBody (body, typ, xs) of
                                             NONE => (e, st)
                                           | SOME (body', typ') =>
                                             let
@@ -257,6 +260,12 @@
                                                                          ("fxs'", Print.p_list
                                                                                       (CorePrint.p_exp E.empty) fxs'),
                                                                          ("e", CorePrint.p_exp env (e, loc))]*)
+
+                                                val (body', typ') = foldr (fn (t, (body', typ')) =>
+                                                                              ((EAbs ("x", t, typ', body'), loc),
+                                                                               (TFun (t, typ'), loc)))
+                                                                          (body', typ') ts
+
                                                 val (body', typ') = IS.foldl (fn (n, (body', typ')) =>
                                                                                  let
                                                                                      val (x, xt) = List.nth (env, n)
@@ -275,7 +284,8 @@
                                                 val e' = (ENamed f', loc)
                                                 val e' = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
                                                                   e' fvs
-                                                val e' = foldl (fn (arg, e) => (EApp (e, arg), loc))
+                                                val e' = foldl (fn ((false, arg), e) => (EApp (e, arg), loc)
+                                                                 | (_, e) => e)
                                                                e' xs
                                                 (*val () = Print.prefaces "Brand new"
                                                                         [("e'", CorePrint.p_exp CoreEnv.empty e'),
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/espec.ur	Tue Dec 15 12:26:00 2009 -0500
@@ -0,0 +1,56 @@
+fun foo (wrap : xbody -> transaction page) = wrap <xml>
+  <a link={foo wrap}>Foo</a>
+</xml>
+
+fun bar (wrap : xbody -> transaction page) (n : int) = wrap <xml>
+  <a link={bar wrap n}>Bar</a>; {[n]}
+</xml>
+
+fun baz (n : int) (wrap : xbody -> transaction page) = wrap <xml>
+  <a link={baz n wrap}>Baz</a>; {[n]}
+</xml>
+
+fun middle (n : int) (wrap : xbody -> transaction page) (m : int) = wrap <xml>
+  <a link={middle n wrap m}>Middle</a>; {[n]}; {[m]}
+</xml>
+
+fun crazy (f : int -> int) (b : bool) (wrap : xbody -> transaction page) (m : int) = wrap <xml>
+  <a link={crazy f b wrap m}>Crazy</a>; {[b]}; {[f m]}
+</xml>
+
+fun wild (q : bool) (f : int -> int) (n : float) (wrap : xbody -> transaction page) (m : int) = wrap <xml>
+  <a link={wild q f n wrap m}>Wild</a>; {[n]}; {[f m]}; {[q]}
+</xml>
+
+fun wrap x = return <xml><body>{x}</body></xml>
+
+fun wrapN n x = return <xml><body>{[n]}; {x}</body></xml>
+
+fun foo2 (wrap : xbody -> transaction page) = wrap <xml>
+  <a link={foo2 wrap}>Foo</a>
+</xml>
+
+fun foo3 (n : int) = wrap <xml>
+  <a link={foo2 (wrapN n)}>Foo</a>
+</xml>
+
+fun bar2 (n : int) (wrap : xbody -> transaction page) = wrap <xml>
+  <a link={bar2 n wrap}>Bar</a>; n={[n]}
+</xml>
+
+fun bar3 (n : int) = wrap <xml>
+  <a link={bar2 88 (wrapN n)}>Bar</a>
+</xml>
+
+
+fun main () = return <xml><body>
+  <a link={foo wrap}>Foo</a>
+  <a link={bar wrap 32}>Bar</a>
+  <a link={baz 18 wrap}>Baz</a>
+  <a link={middle 1 wrap 2}>Middle</a>
+  <a link={crazy (fn n => 2 * n) False wrap 2}>Crazy</a>
+  <a link={wild True (fn n => 2 * n) 1.23 wrap 2}>Wild</a>
+  <hr/>
+  <a link={foo3 15}>Foo3</a>
+  <a link={bar3 44}>Bar3</a>
+</body></xml>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/espec.urp	Tue Dec 15 12:26:00 2009 -0500
@@ -0,0 +1,3 @@
+debug
+
+espec
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/espec.urs	Tue Dec 15 12:26:00 2009 -0500
@@ -0,0 +1,1 @@
+val main : unit -> transaction page