# HG changeset patch # User Adam Chlipala # Date 1326308015 18000 # Node ID 13dad713da354c8739f0cffb34b7965acaa921c3 # Parent 4cacced4a6da36e5460f7bc873f3e4d97464589c New, more principled heuristic for Especialize: only specialize uniform function arguments; that is, arguments that don't change across recursive calls diff -r 4cacced4a6da -r 13dad713da35 src/especialize.sml --- a/src/especialize.sml Wed Jan 11 11:08:48 2012 -0500 +++ b/src/especialize.sml Wed Jan 11 13:53:35 2012 -0500 @@ -109,7 +109,8 @@ args : int KM.map, body : exp, typ : con, - tag : string + tag : string, + constArgs : int (* What length prefix of the arguments never vary across recursive calls? *) } type state = { @@ -133,6 +134,92 @@ | CFfi ("Basis", "sql_injectable") => true | _ => false} +fun getApp (e, _) = + case e of + ENamed f => SOME (f, []) + | EApp (e1, e2) => + (case getApp e1 of + NONE => NONE + | SOME (f, xs) => SOME (f, xs @ [e2])) + | _ => NONE + +val getApp = fn e => case getApp e of + v as SOME (_, _ :: _) => v + | _ => NONE + +val maxInt = Option.getOpt (Int.maxInt, 9999) + +fun calcConstArgs enclosingFunction e = + let + fun ca depth e = + case #1 e of + EPrim _ => maxInt + | ERel _ => maxInt + | ENamed n => if n = enclosingFunction then 0 else maxInt + | ECon (_, _, _, NONE) => maxInt + | ECon (_, _, _, SOME e) => ca depth e + | EFfi _ => maxInt + | EFfiApp (_, _, ecs) => foldl (fn ((e, _), d) => Int.min (ca depth e, d)) maxInt ecs + | EApp (e1, e2) => + let + fun default () = Int.min (ca depth e1, ca depth e2) + in + case getApp e of + NONE => default () + | SOME (f, args) => + if f <> enclosingFunction then + default () + else + let + fun visitArgs (count, args) = + case args of + [] => count + | arg :: args' => + let + fun default () = foldl (fn (e, d) => Int.min (ca depth e, d)) count args + in + case #1 arg of + ERel n => + if n = depth - 1 then + visitArgs (count + 1, args') + else + default () + | _ => default () + end + in + visitArgs (0, args) + end + end + | EAbs (_, _, _, e1) => ca (depth + 1) e1 + | ECApp (e1, _) => ca depth e1 + | ECAbs (_, _, e1) => ca depth e1 + | EKAbs (_, e1) => ca depth e1 + | EKApp (e1, _) => ca depth e1 + | ERecord xets => foldl (fn ((_, e, _), d) => Int.min (ca depth e, d)) maxInt xets + | EField (e1, _, _) => ca depth e1 + | EConcat (e1, _, e2, _) => Int.min (ca depth e1, ca depth e2) + | ECut (e1, _, _) => ca depth e1 + | ECutMulti (e1, _, _) => ca depth e1 + | ECase (e1, pes, _) => foldl (fn ((p, e), d) => Int.min (ca (depth + E.patBindsN p) e, d)) (ca depth e1) pes + | EWrite e1 => ca depth e1 + | EClosure (_, es) => foldl (fn (e, d) => Int.min (ca depth e, d)) maxInt es + | ELet (_, _, e1, e2) => Int.min (ca depth e1, ca (depth + 1) e2) + | EServerCall (_, es, _) => foldl (fn (e, d) => Int.min (ca depth e, d)) maxInt es + + fun enterAbs depth e = + case #1 e of + EAbs (_, _, _, e1) => enterAbs (depth + 1) e1 + | _ => ca depth e + + val n = enterAbs 0 e + in + if n = maxInt then + 0 + else + n + end + + fun specialize' (funcs, specialized) file = let fun bind (env, b) = @@ -145,19 +232,6 @@ (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty (e, ErrorMsg.dummySpan))]*) - fun getApp (e, _) = - case e of - ENamed f => SOME (f, []) - | EApp (e1, e2) => - (case getApp e1 of - NONE => NONE - | SOME (f, xs) => SOME (f, xs @ [e2])) - | _ => NONE - - val getApp = fn e => case getApp e of - v as SOME (_, _ :: _) => v - | _ => NONE - fun default () = case #1 e of EPrim _ => (e, st) @@ -290,7 +364,7 @@ | SOME (f, xs) => case IM.find (#funcs st, f) of NONE => ((*print ("No find: " ^ Int.toString f ^ "\n");*) default ()) - | SOME {name, args, body, typ, tag} => + | SOME {name, args, body, typ, tag, constArgs} => let val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs @@ -299,77 +373,32 @@ val loc = ErrorMsg.dummySpan - fun findSplit av (xs, typ, fxs, fvs, fin) = + fun findSplit av (constArgs, xs, typ, fxs, fvs) = case (#1 typ, xs) of (TFun (dom, ran), e :: xs') => - let - val av = case #1 e of - ERel _ => av - | _ => false - in - if functionInside dom orelse (av andalso case #1 e of - ERel _ => true - | _ => false) then - findSplit av (xs', - ran, - e :: fxs, - IS.union (fvs, freeVars e), - fin orelse functionInside dom) - else - (rev fxs, xs, fvs, fin) - end - | _ => (rev fxs, xs, fvs, fin) + if constArgs > 0 then + findSplit av (constArgs - 1, + xs', + ran, + e :: fxs, + IS.union (fvs, freeVars e)) + else + (rev fxs, xs, fvs) + | _ => (rev fxs, xs, fvs) - val (fxs, xs, fvs, fin) = findSplit true (xs, typ, [], IS.empty, false) - - fun valueish (all as (e, _)) = - case e of - EPrim _ => true - | ERel _ => true - | ENamed _ => true - | ECon (_, _, _, NONE) => true - | ECon (_, _, _, SOME e) => valueish e - | EFfi (_, _) => true - | EAbs _ => true - | ECAbs _ => true - | EKAbs _ => true - | ECApp (e, _) => valueish e - | EKApp (e, _) => valueish e - | EApp (e1, e2) => valueish e1 andalso valueish e2 - | ERecord xes => List.all (valueish o #2) xes - | EField (e, _, _) => valueish e - | _ => false + val (fxs, xs, fvs) = findSplit true (constArgs, xs, typ, [], IS.empty) val vts = map (fn n => #2 (List.nth (env, n))) (IS.listItems fvs) val fxs' = map (squish (IS.listItems fvs)) fxs val p_bool = Print.PD.string o Bool.toString - - fun bumpCount n = - if IS.member (#specialized st, f) then - n - else - 5 + n in (*Print.prefaces "Func" [("name", Print.PD.string name), ("e", CorePrint.p_exp CoreEnv.empty e), ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')];*) - if not fin - orelse List.all (fn (ERel _, _) => true - | _ => false) fxs' - orelse List.exists (not o valueish) fxs' - orelse IS.numItems fvs >= bumpCount (length fxs) then - ((*Print.prefaces "No" [("name", Print.PD.string name), - ("f", Print.PD.string (Int.toString f)), - ("fxs'", - Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs'), - ("b1", p_bool (not fin)), - ("b2", p_bool (List.all (fn (ERel _, _) => true - | _ => false) fxs')), - ("b3", p_bool (List.exists (not o valueish) fxs')), - ("b4", p_bool (IS.numItems fvs >= length fxs - andalso IS.exists (fn n => functionInside (#2 (List.nth (env, n)))) fvs))];*) - default ()) + if List.all (fn (ERel _, _) => true + | _ => false) fxs' then + default () else case KM.find (args, (vts, fxs')) of SOME f' => @@ -397,6 +426,12 @@ [("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*) + (*val () = Print.prefaces name + [("Available", Print.PD.string (Int.toString constArgs)), + ("Used", Print.PD.string (Int.toString (length fxs'))), + ("fxs'", + Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*) + fun subBody (body, typ, fxs') = case (#1 body, #1 typ, fxs') of (_, _, []) => SOME (body, typ) @@ -420,7 +455,8 @@ args = args, body = body, typ = typ, - tag = tag}) + tag = tag, + constArgs = calcConstArgs f body}) val st = { maxName = f' + 1, @@ -484,7 +520,8 @@ args = KM.empty, body = e, typ = c, - tag = tag})) + tag = tag, + constArgs = calcConstArgs n e})) funcs vis | _ => funcs @@ -565,7 +602,8 @@ args = KM.empty, body = e, typ = c, - tag = tag}) + tag = tag, + constArgs = calcConstArgs n e}) | DVal (_, n, _, (ENamed n', _), _) => (case IM.find (funcs, n') of NONE => funcs