changeset 1675:13dad713da35

New, more principled heuristic for Especialize: only specialize uniform function arguments; that is, arguments that don't change across recursive calls
author Adam Chlipala <adam@chlipala.net>
date Wed, 11 Jan 2012 13:53:35 -0500 (2012-01-11)
parents 4cacced4a6da
children 266814b15dd6
files src/especialize.sml
diffstat 1 files changed, 115 insertions(+), 77 deletions(-) [+]
line wrap: on
line diff
--- a/src/especialize.sml	Wed Jan 11 11:08:48 2012 -0500
+++ b/src/especialize.sml	Wed Jan 11 13:53:35 2012 -0500
@@ -109,7 +109,8 @@
      args : int KM.map,
      body : exp,
      typ : con,
-     tag : string
+     tag : string,
+     constArgs : int (* What length prefix of the arguments never vary across recursive calls? *)
 }
 
 type state = {
@@ -133,6 +134,92 @@
                                           | CFfi ("Basis", "sql_injectable") => true
                                           | _ => false}
 
+fun getApp (e, _) =
+    case e of
+        ENamed f => SOME (f, [])
+      | EApp (e1, e2) =>
+        (case getApp e1 of
+             NONE => NONE
+           | SOME (f, xs) => SOME (f, xs @ [e2]))
+      | _ => NONE
+
+val getApp = fn e => case getApp e of
+                         v as SOME (_, _ :: _) => v
+                       | _ => NONE
+
+val maxInt = Option.getOpt (Int.maxInt, 9999)
+
+fun calcConstArgs enclosingFunction e =
+    let
+        fun ca depth e =
+            case #1 e of
+                EPrim _ => maxInt
+              | ERel _ => maxInt
+              | ENamed n => if n = enclosingFunction then 0 else maxInt
+              | ECon (_, _, _, NONE) => maxInt
+              | ECon (_, _, _, SOME e) => ca depth e
+              | EFfi _ => maxInt
+              | EFfiApp (_, _, ecs) => foldl (fn ((e, _), d) => Int.min (ca depth e, d)) maxInt ecs
+              | EApp (e1, e2) =>
+                let
+                    fun default () = Int.min (ca depth e1, ca depth e2)
+                in
+                    case getApp e of
+                        NONE => default ()
+                      | SOME (f, args) =>
+                        if f <> enclosingFunction then
+                            default ()
+                        else
+                            let
+                                fun visitArgs (count, args) =
+                                    case args of
+                                        [] => count
+                                      | arg :: args' =>
+                                        let
+                                            fun default () = foldl (fn (e, d) => Int.min (ca depth e, d)) count args
+                                        in
+                                            case #1 arg of
+                                                ERel n =>
+                                                if n = depth - 1 then
+                                                    visitArgs (count + 1, args')
+                                                else
+                                                    default ()
+                                              | _ => default ()
+                                        end
+                            in
+                                visitArgs (0, args)
+                            end
+                end
+              | EAbs (_, _, _, e1) => ca (depth + 1) e1
+              | ECApp (e1, _) => ca depth e1
+              | ECAbs (_, _, e1) => ca depth e1
+              | EKAbs (_, e1) => ca depth e1
+              | EKApp (e1, _) => ca depth e1
+              | ERecord xets => foldl (fn ((_, e, _), d) => Int.min (ca depth e, d)) maxInt xets
+              | EField (e1, _, _) => ca depth e1
+              | EConcat (e1, _, e2, _) => Int.min (ca depth e1, ca depth e2)
+              | ECut (e1, _, _) => ca depth e1
+              | ECutMulti (e1, _, _) => ca depth e1
+              | ECase (e1, pes, _) => foldl (fn ((p, e), d) => Int.min (ca (depth + E.patBindsN p) e, d)) (ca depth e1) pes
+              | EWrite e1 => ca depth e1
+              | EClosure (_, es) => foldl (fn (e, d) => Int.min (ca depth e, d)) maxInt es
+              | ELet (_, _, e1, e2) => Int.min (ca depth e1, ca (depth + 1) e2)
+              | EServerCall (_, es, _) => foldl (fn (e, d) => Int.min (ca depth e, d)) maxInt es
+
+        fun enterAbs depth e =
+            case #1 e of
+                EAbs (_, _, _, e1) => enterAbs (depth + 1) e1
+              | _ => ca depth e
+
+        val n = enterAbs 0 e
+    in
+        if n = maxInt then
+            0
+        else
+            n
+    end
+
+
 fun specialize' (funcs, specialized) file =
     let
         fun bind (env, b) =
@@ -145,19 +232,6 @@
                 (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty
                                                                      (e, ErrorMsg.dummySpan))]*)
 
-                fun getApp (e, _) =
-                    case e of
-                        ENamed f => SOME (f, [])
-                      | EApp (e1, e2) =>
-                        (case getApp e1 of
-                             NONE => NONE
-                           | SOME (f, xs) => SOME (f, xs @ [e2]))
-                      | _ => NONE
-
-                val getApp = fn e => case getApp e of
-                                         v as SOME (_, _ :: _) => v
-                                       | _ => NONE
-
                 fun default () =
                     case #1 e of
                         EPrim _ => (e, st)
@@ -290,7 +364,7 @@
                   | SOME (f, xs) =>
                     case IM.find (#funcs st, f) of
                         NONE => ((*print ("No find: " ^ Int.toString f ^ "\n");*) default ())
-                      | SOME {name, args, body, typ, tag} =>
+                      | SOME {name, args, body, typ, tag, constArgs} =>
                         let
                             val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs
 
@@ -299,77 +373,32 @@
 
                             val loc = ErrorMsg.dummySpan
 
-                            fun findSplit av (xs, typ, fxs, fvs, fin) =
+                            fun findSplit av (constArgs, xs, typ, fxs, fvs) =
                                 case (#1 typ, xs) of
                                     (TFun (dom, ran), e :: xs') =>
-                                    let
-                                        val av = case #1 e of
-                                                     ERel _ => av
-                                                   | _ => false
-                                    in
-                                        if functionInside dom orelse (av andalso case #1 e of
-                                                                                     ERel _ => true
-                                                                                   | _ => false) then
-                                            findSplit av (xs',
-                                                          ran,
-                                                          e :: fxs,
-                                                          IS.union (fvs, freeVars e),
-                                                          fin orelse functionInside dom)
-                                        else
-                                            (rev fxs, xs, fvs, fin)
-                                    end
-                                  | _ => (rev fxs, xs, fvs, fin)
+                                    if constArgs > 0 then
+                                        findSplit av (constArgs - 1,
+                                                      xs',
+                                                      ran,
+                                                      e :: fxs,
+                                                      IS.union (fvs, freeVars e))
+                                    else
+                                        (rev fxs, xs, fvs)
+                                  | _ => (rev fxs, xs, fvs)
 
-                            val (fxs, xs, fvs, fin) = findSplit true (xs, typ, [], IS.empty, false)
-
-                            fun valueish (all as (e, _)) =
-                                case e of
-                                    EPrim _ => true
-                                  | ERel _ => true
-                                  | ENamed _ => true
-                                  | ECon (_, _, _, NONE) => true
-                                  | ECon (_, _, _, SOME e) => valueish e
-                                  | EFfi (_, _) => true
-                                  | EAbs _ => true
-                                  | ECAbs _ => true
-                                  | EKAbs _ => true
-                                  | ECApp (e, _) => valueish e
-                                  | EKApp (e, _) => valueish e
-                                  | EApp (e1, e2) => valueish e1 andalso valueish e2
-                                  | ERecord xes => List.all (valueish o #2) xes
-                                  | EField (e, _, _) => valueish e
-                                  | _ => false
+                            val (fxs, xs, fvs) = findSplit true (constArgs, xs, typ, [], IS.empty)
 
                             val vts = map (fn n => #2 (List.nth (env, n))) (IS.listItems fvs)
                             val fxs' = map (squish (IS.listItems fvs)) fxs
 
                             val p_bool = Print.PD.string o Bool.toString
-
-                            fun bumpCount n =
-                                if IS.member (#specialized st, f) then
-                                    n
-                                else
-                                    5 + n
                         in
                             (*Print.prefaces "Func" [("name", Print.PD.string name),
                                                    ("e", CorePrint.p_exp CoreEnv.empty e),
                                                    ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')];*)
-                            if not fin
-                               orelse List.all (fn (ERel _, _) => true
-                                                 | _ => false) fxs'
-                               orelse List.exists (not o valueish) fxs'
-                               orelse IS.numItems fvs >= bumpCount (length fxs) then
-                                ((*Print.prefaces "No" [("name", Print.PD.string name),
-                                                      ("f", Print.PD.string (Int.toString f)),
-                                                      ("fxs'",
-                                                       Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs'),
-                                                      ("b1", p_bool (not fin)),
-                                                      ("b2", p_bool (List.all (fn (ERel _, _) => true
-                                                                                | _ => false) fxs')),
-                                                      ("b3", p_bool (List.exists (not o valueish) fxs')),
-                                                      ("b4", p_bool (IS.numItems fvs >= length fxs
-                                                                     andalso IS.exists (fn n => functionInside (#2 (List.nth (env, n)))) fvs))];*)
-                                 default ())
+                            if List.all (fn (ERel _, _) => true
+                                          | _ => false) fxs' then
+                                default ()
                             else
                                 case KM.find (args, (vts, fxs')) of
                                     SOME f' =>
@@ -397,6 +426,12 @@
                                                                 [("fxs'",
                                                                   Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*)
 
+                                        (*val () = Print.prefaces name
+                                                                [("Available", Print.PD.string (Int.toString constArgs)),
+                                                                 ("Used", Print.PD.string (Int.toString (length fxs'))),
+                                                                 ("fxs'",
+                                                                  Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*)
+
                                         fun subBody (body, typ, fxs') =
                                             case (#1 body, #1 typ, fxs') of
                                                 (_, _, []) => SOME (body, typ)
@@ -420,7 +455,8 @@
                                                                                       args = args,
                                                                                       body = body,
                                                                                       typ = typ,
-                                                                                      tag = tag})
+                                                                                      tag = tag,
+                                                                                      constArgs = calcConstArgs f body})
 
                                                 val st = {
                                                     maxName = f' + 1,
@@ -484,7 +520,8 @@
                                                         args = KM.empty,
                                                         body = e,
                                                         typ = c,
-                                                        tag = tag}))
+                                                        tag = tag,
+                                                        constArgs = calcConstArgs n e}))
                               funcs vis
                       | _ => funcs
 
@@ -565,7 +602,8 @@
                                               args = KM.empty,
                                               body = e,
                                               typ = c,
-                                              tag = tag})
+                                              tag = tag,
+                                              constArgs = calcConstArgs n e})
                       | DVal (_, n, _, (ENamed n', _), _) =>
                         (case IM.find (funcs, n') of
                              NONE => funcs