changeset 1861:52043ad66ce7

Extend Especialize rule: find maximal argument prefixes that end in 1 or more arguments with functional types
author Adam Chlipala <adam@chlipala.net>
date Fri, 09 Aug 2013 16:04:16 -0400
parents d54984564bcd
children a3d795fbecb9
files src/especialize.sml
diffstat 1 files changed, 35 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/src/especialize.sml	Wed Jul 17 10:48:31 2013 -0400
+++ b/src/especialize.sml	Fri Aug 09 16:04:16 2013 -0400
@@ -364,30 +364,42 @@
                         let
                             val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs
 
-                            (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty
-                                                                                      (e, ErrorMsg.dummySpan))]*)
+                            (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty e)]*)
 
                             val loc = ErrorMsg.dummySpan
 
                             val oldXs = xs
 
-                            fun findSplit av (constArgs, xs, typ, fxs, fvs) =
-                                case (#1 typ, xs) of
-                                    (TFun (dom, ran), e :: xs') =>
-                                    if constArgs > 0 then
-                                        if functionInside dom then
-                                            (rev (e :: fxs), xs', IS.union (fvs, freeVars e))
+                            fun findSplit av (initialPart, constArgs, xs, typ, fxs, fvs) =
+                                let
+                                    fun default () =
+                                        if initialPart then
+                                            ([], oldXs, IS.empty)
                                         else
-                                            findSplit av (constArgs - 1,
-                                                          xs',
-                                                          ran,
-                                                          e :: fxs,
-                                                          IS.union (fvs, freeVars e))
-                                    else
-                                        ([], oldXs, IS.empty)
-                                  | _ => ([], oldXs, IS.empty)
+                                            (rev fxs, xs, fvs)
+                                in
+                                    case (#1 typ, xs) of
+                                        (TFun (dom, ran), e :: xs') =>
+                                        if constArgs > 0 then
+                                            let
+                                                val fi = functionInside dom
+                                            in
+                                                if initialPart orelse fi then
+                                                    findSplit av (not fi andalso initialPart,
+                                                                  constArgs - 1,
+                                                                  xs',
+                                                                  ran,
+                                                                  e :: fxs,
+                                                                  IS.union (fvs, freeVars e))
+                                                else
+                                                    default ()
+                                            end
+                                        else
+                                            default ()
+                                      | _ => default ()
+                                end
 
-                            val (fxs, xs, fvs) = findSplit true (constArgs, xs, typ, [], IS.empty)
+                            val (fxs, xs, fvs) = findSplit true (true, constArgs, xs, typ, [], IS.empty)
 
                             val vts = map (fn n => #2 (List.nth (env, n))) (IS.listItems fvs)
                             val fxs' = map (squish (IS.listItems fvs)) fxs
@@ -483,7 +495,7 @@
                                                                                       (TFun (xt, typ'), loc))
                                                                                  end)
                                                                              (body', typ') fvs
-                                                (*val () = print ("NEW: " ^ name ^ "__" ^ Int.toString f' ^ "\n");*)
+                                                (*val () = print ("NEW: " ^ name ^ "__" ^ Int.toString f' ^ "\n")*)
                                                 val body' = ReduceLocal.reduceExp body'
                                                 (*val () = Print.preface ("PRE", CorePrint.p_exp CoreEnv.empty body')*)
                                                 val (body', st) = exp (env, body', st)
@@ -523,6 +535,8 @@
                                                       Int.min (constArgs, calcConstArgs fs e))
                                                   maxInt vis
                         in
+                            (*Print.prefaces "ConstArgs" [("d", CorePrint.p_decl CoreEnv.empty d),
+                                                        ("ca", Print.PD.string (Int.toString constArgs))];*)
                             foldl (fn ((x, n, c, e, tag), funcs) =>
                                       IM.insert (funcs, n, {name = x,
                                                             args = KM.empty,
@@ -607,12 +621,14 @@
                 val funcs =
                     case #1 d of
                         DVal (x, n, c, e as (EAbs _, _), tag) =>
+                        ((*Print.prefaces "ConstArgs[2]" [("d", CorePrint.p_decl CoreEnv.empty d),
+                                                     ("ca", Print.PD.string (Int.toString (calcConstArgs (IS.singleton n) e)))];*)
                         IM.insert (funcs, n, {name = x,
                                               args = KM.empty,
                                               body = e,
                                               typ = c,
                                               tag = tag,
-                                              constArgs = calcConstArgs (IS.singleton n) e})
+                                              constArgs = calcConstArgs (IS.singleton n) e}))
                       | DVal (_, n, _, (ENamed n', _), _) =>
                         (case IM.find (funcs, n') of
                              NONE => funcs