# HG changeset patch # User Adam Chlipala # Date 1376078656 14400 # Node ID 52043ad66ce7204bb8fd000d998ac6cbc9fd55e4 # Parent d54984564bcd7540854e678dc5e58d069dcd7e09 Extend Especialize rule: find maximal argument prefixes that end in 1 or more arguments with functional types diff -r d54984564bcd -r 52043ad66ce7 src/especialize.sml --- a/src/especialize.sml Wed Jul 17 10:48:31 2013 -0400 +++ b/src/especialize.sml Fri Aug 09 16:04:16 2013 -0400 @@ -364,30 +364,42 @@ let val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs - (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty - (e, ErrorMsg.dummySpan))]*) + (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty e)]*) val loc = ErrorMsg.dummySpan val oldXs = xs - fun findSplit av (constArgs, xs, typ, fxs, fvs) = - case (#1 typ, xs) of - (TFun (dom, ran), e :: xs') => - if constArgs > 0 then - if functionInside dom then - (rev (e :: fxs), xs', IS.union (fvs, freeVars e)) + fun findSplit av (initialPart, constArgs, xs, typ, fxs, fvs) = + let + fun default () = + if initialPart then + ([], oldXs, IS.empty) else - findSplit av (constArgs - 1, - xs', - ran, - e :: fxs, - IS.union (fvs, freeVars e)) - else - ([], oldXs, IS.empty) - | _ => ([], oldXs, IS.empty) + (rev fxs, xs, fvs) + in + case (#1 typ, xs) of + (TFun (dom, ran), e :: xs') => + if constArgs > 0 then + let + val fi = functionInside dom + in + if initialPart orelse fi then + findSplit av (not fi andalso initialPart, + constArgs - 1, + xs', + ran, + e :: fxs, + IS.union (fvs, freeVars e)) + else + default () + end + else + default () + | _ => default () + end - val (fxs, xs, fvs) = findSplit true (constArgs, xs, typ, [], IS.empty) + val (fxs, xs, fvs) = findSplit true (true, constArgs, xs, typ, [], IS.empty) val vts = map (fn n => #2 (List.nth (env, n))) (IS.listItems fvs) val fxs' = map (squish (IS.listItems fvs)) fxs @@ -483,7 +495,7 @@ (TFun (xt, typ'), loc)) end) (body', typ') fvs - (*val () = print ("NEW: " ^ name ^ "__" ^ Int.toString f' ^ "\n");*) + (*val () = print ("NEW: " ^ name ^ "__" ^ Int.toString f' ^ "\n")*) val body' = ReduceLocal.reduceExp body' (*val () = Print.preface ("PRE", CorePrint.p_exp CoreEnv.empty body')*) val (body', st) = exp (env, body', st) @@ -523,6 +535,8 @@ Int.min (constArgs, calcConstArgs fs e)) maxInt vis in + (*Print.prefaces "ConstArgs" [("d", CorePrint.p_decl CoreEnv.empty d), + ("ca", Print.PD.string (Int.toString constArgs))];*) foldl (fn ((x, n, c, e, tag), funcs) => IM.insert (funcs, n, {name = x, args = KM.empty, @@ -607,12 +621,14 @@ val funcs = case #1 d of DVal (x, n, c, e as (EAbs _, _), tag) => + ((*Print.prefaces "ConstArgs[2]" [("d", CorePrint.p_decl CoreEnv.empty d), + ("ca", Print.PD.string (Int.toString (calcConstArgs (IS.singleton n) e)))];*) IM.insert (funcs, n, {name = x, args = KM.empty, body = e, typ = c, tag = tag, - constArgs = calcConstArgs (IS.singleton n) e}) + constArgs = calcConstArgs (IS.singleton n) e})) | DVal (_, n, _, (ENamed n', _), _) => (case IM.find (funcs, n') of NONE => funcs