diff src/especialize.sml @ 1077:a3273bee05a9

Initial generalization of Especialize, with security bug known
author Adam Chlipala <adamc@hcoop.net>
date Tue, 15 Dec 2009 12:26:00 -0500
parents 066493f7f008
children b9321bcefb42
line wrap: on
line diff
--- a/src/especialize.sml	Tue Dec 15 11:11:49 2009 -0500
+++ b/src/especialize.sml	Tue Dec 15 12:26:00 2009 -0500
@@ -79,14 +79,14 @@
         pof (0, ls)
     end
 
-fun squish fvs =
+fun squish (untouched, fvs) =
     U.Exp.mapB {kind = fn _ => fn k => k,
                 con = fn _ => fn c => c,
                 exp = fn bound => fn e =>
                                      case e of
                                          ERel x =>
                                          if x >= bound then
-                                             ERel (positionOf (x - bound, fvs) + bound)
+                                             ERel (positionOf (x - bound, fvs) + bound + untouched)
                                          else
                                              e
                                        | _ => e,
@@ -165,31 +165,29 @@
                                                                       | _ => false}
                             val loc = ErrorMsg.dummySpan
 
-                            fun findSplit (xs, typ, fxs, fvs) =
+                            fun findSplit (xs, typ, fxs, fvs, ts) =
                                 case (#1 typ, xs) of
                                     (TFun (dom, ran), e :: xs') =>
                                     if functionInside dom then
                                         findSplit (xs',
                                                    ran,
-                                                   e :: fxs,
-                                                   IS.union (fvs, freeVars e))
+                                                   (true, e) :: fxs,
+                                                   IS.union (fvs, freeVars e),
+                                                   ts)
                                     else
-                                        (rev fxs, xs, fvs)
-                                  | _ => (rev fxs, xs, fvs)
+                                        findSplit (xs', ran, (false, e) :: fxs, fvs, dom :: ts)
+                                  | _ => (List.revAppend (fxs, map (fn e => (false, e)) xs), fvs, rev ts)
 
-                            val (fxs, xs, fvs) = findSplit (xs, typ, [], IS.empty)
-
-                            val fxs' = map (squish (IS.listItems fvs)) fxs
-
-                            fun firstRel () =
-                                case fxs' of
-                                    (ERel _, _) :: _ => true
-                                  | _ => false
+                            val (xs, fvs, ts) = findSplit (xs, typ, [], IS.empty, [])
+                            val fxs = List.mapPartial (fn (true, e) => SOME e | _ => NONE) xs
+                            val untouched = length (List.filter (fn (false, _) => true | _ => false) xs)
+                            val squish = squish (untouched, IS.listItems fvs)
+                            val fxs' = map squish fxs
                         in
                             (*Print.preface ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs');*)
-                            if firstRel ()
-                               orelse List.all (fn (ERel _, _) => true
-                                                 | _ => false) fxs' then
+                            if List.all (fn (false, _) => true
+                                          | (true, (ERel _, _)) => true
+                                          | _ => false) xs then
                                 (e, st)
                             else
                                 case (KM.find (args, fxs'), SS.member (!mayNotSpec, name)) of
@@ -198,7 +196,8 @@
                                         val e = (ENamed f', loc)
                                         val e = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
                                                          e fvs
-                                        val e = foldl (fn (arg, e) => (EApp (e, arg), loc))
+                                        val e = foldl (fn ((false, arg), e) => (EApp (e, arg), loc)
+                                                        | (_, e) => e)
                                                       e xs
                                     in
                                         (*Print.prefaces "Brand new (reuse)"
@@ -220,20 +219,24 @@
                                                                 [("fxs'",
                                                                   Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*)
 
-                                        fun subBody (body, typ, fxs') =
-                                            case (#1 body, #1 typ, fxs') of
+                                        fun subBody (body, typ, xs) =
+                                            case (#1 body, #1 typ, xs) of
                                                 (_, _, []) => SOME (body, typ)
-                                              | (EAbs (_, _, _, body'), TFun (_, typ'), x :: fxs'') =>
+                                              | (EAbs (_, _, _, body'), TFun (_, typ'), (b, x) :: xs) =>
                                                 let
-                                                    val body'' = E.subExpInExp (0, x) body'
+                                                    val body'' =
+                                                        if b then
+                                                            E.subExpInExp (0, squish x) body'
+                                                        else
+                                                            body'
                                                 in
                                                     subBody (body'',
                                                              typ',
-                                                             fxs'')
+                                                             xs)
                                                 end
                                               | _ => NONE
                                     in
-                                        case subBody (body, typ, fxs') of
+                                        case subBody (body, typ, xs) of
                                             NONE => (e, st)
                                           | SOME (body', typ') =>
                                             let
@@ -257,6 +260,12 @@
                                                                          ("fxs'", Print.p_list
                                                                                       (CorePrint.p_exp E.empty) fxs'),
                                                                          ("e", CorePrint.p_exp env (e, loc))]*)
+
+                                                val (body', typ') = foldr (fn (t, (body', typ')) =>
+                                                                              ((EAbs ("x", t, typ', body'), loc),
+                                                                               (TFun (t, typ'), loc)))
+                                                                          (body', typ') ts
+
                                                 val (body', typ') = IS.foldl (fn (n, (body', typ')) =>
                                                                                  let
                                                                                      val (x, xt) = List.nth (env, n)
@@ -275,7 +284,8 @@
                                                 val e' = (ENamed f', loc)
                                                 val e' = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
                                                                   e' fvs
-                                                val e' = foldl (fn (arg, e) => (EApp (e, arg), loc))
+                                                val e' = foldl (fn ((false, arg), e) => (EApp (e, arg), loc)
+                                                                 | (_, e) => e)
                                                                e' xs
                                                 (*val () = Print.prefaces "Brand new"
                                                                         [("e'", CorePrint.p_exp CoreEnv.empty e'),