diff src/especialize.sml @ 1079:d069b193ed6b

Especialize uses a termination measure based on number of arguments introduced
author Adam Chlipala <adamc@hcoop.net>
date Tue, 15 Dec 2009 19:26:52 -0500
parents b9321bcefb42
children a4979e31e4bf
line wrap: on
line diff
--- a/src/especialize.sml	Tue Dec 15 13:20:13 2009 -0500
+++ b/src/especialize.sml	Tue Dec 15 19:26:52 2009 -0500
@@ -79,14 +79,14 @@
         pof (0, ls)
     end
 
-fun squish (untouched, fvs) =
+fun squish fvs =
     U.Exp.mapB {kind = fn _ => fn k => k,
                 con = fn _ => fn c => c,
                 exp = fn bound => fn e =>
                                      case e of
                                          ERel x =>
                                          if x >= bound then
-                                             ERel (positionOf (x - bound, fvs) + bound + untouched)
+                                             ERel (positionOf (x - bound, fvs) + bound)
                                          else
                                              e
                                        | _ => e,
@@ -107,7 +107,8 @@
 type state = {
      maxName : int,
      funcs : func IM.map,
-     decls : (string * int * con * exp * string) list
+     decls : (string * int * con * exp * string) list,
+     specialized : bool IM.map
 }
 
 fun default (_, x, st) = (x, st)
@@ -119,7 +120,7 @@
 
 val mayNotSpec = ref SS.empty
 
-fun specialize' file =
+fun specialize' specialized file =
     let
         fun bind (env, b) =
             case b of
@@ -165,51 +166,45 @@
                                                                       | _ => false}
                             val loc = ErrorMsg.dummySpan
 
-                            fun hasFuncArg t =
-                                case #1 t of
-                                    TFun (dom, ran) => functionInside dom orelse hasFuncArg ran
-                                  | _ => false
-
-                            fun findSplit hfa (xs, typ, fxs, fvs, ts) =
+                            fun findSplit av (xs, typ, fxs, fvs) =
                                 case (#1 typ, xs) of
                                     (TFun (dom, ran), e :: xs') =>
                                     let
-                                        val isVar = case #1 e of
-                                                        ERel _ => true
-                                                      | _ => false
-                                        val hfa = hfa andalso isVar
+                                        val av = case #1 e of
+                                                     ERel _ => av
+                                                   | _ => false
                                     in
-                                        if hfa orelse functionInside dom then
-                                            findSplit hfa (xs',
-                                                           ran,
-                                                           (true, e) :: fxs,
-                                                           IS.union (fvs, freeVars e),
-                                                           ts)
+                                        if functionInside dom orelse (av andalso case #1 e of
+                                                                                     ERel _ => true
+                                                                                   | _ => false) then
+                                            findSplit av (xs',
+                                                          ran,
+                                                          e :: fxs,
+                                                          IS.union (fvs, freeVars e))
                                         else
-                                            findSplit hfa (xs', ran, (false, e) :: fxs, fvs, dom :: ts)
+                                            (rev fxs, xs, fvs)
                                     end
-                                  | _ => (List.revAppend (fxs, map (fn e => (false, e)) xs), fvs, rev ts)
+                                  | _ => (rev fxs, xs, fvs)
 
-                            val (xs, fvs, ts) = findSplit (hasFuncArg typ) (xs, typ, [], IS.empty, [])
-                            val fxs = List.mapPartial (fn (true, e) => SOME e | _ => NONE) xs
-                            val untouched = length (List.filter (fn (false, _) => true | _ => false) xs)
-                            val squish = squish (untouched, IS.listItems fvs)
-                            val fxs' = map squish fxs
+                            val (fxs, xs, fvs) = findSplit true (xs, typ, [], IS.empty)
+
+                            val fxs' = map (squish (IS.listItems fvs)) fxs
                         in
                             (*Print.preface ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs');*)
-                            if List.all (fn (false, _) => true
-                                          | (true, (ERel _, _)) => true
-                                          | _ => false) xs then
+                            if List.all (fn (ERel _, _) => true
+                                          | _ => false) fxs'
+                               orelse (IS.numItems fvs >= length fxs
+                                       andalso IS.exists (fn n => functionInside (#2 (List.nth (env, n)))) fvs) then
                                 (e, st)
                             else
-                                case (KM.find (args, fxs'), SS.member (!mayNotSpec, name)) of
+                                case (KM.find (args, fxs'),
+                                      SS.member (!mayNotSpec, name) orelse IM.find (#specialized st, f) = SOME true) of
                                     (SOME f', _) =>
                                     let
                                         val e = (ENamed f', loc)
                                         val e = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
                                                          e fvs
-                                        val e = foldl (fn ((false, arg), e) => (EApp (e, arg), loc)
-                                                        | (_, e) => e)
+                                        val e = foldl (fn (arg, e) => (EApp (e, arg), loc))
                                                       e xs
                                     in
                                         (*Print.prefaces "Brand new (reuse)"
@@ -231,24 +226,20 @@
                                                                 [("fxs'",
                                                                   Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*)
 
-                                        fun subBody (body, typ, xs) =
-                                            case (#1 body, #1 typ, xs) of
+                                        fun subBody (body, typ, fxs') =
+                                            case (#1 body, #1 typ, fxs') of
                                                 (_, _, []) => SOME (body, typ)
-                                              | (EAbs (_, _, _, body'), TFun (_, typ'), (b, x) :: xs) =>
+                                              | (EAbs (_, _, _, body'), TFun (_, typ'), x :: fxs'') =>
                                                 let
-                                                    val body'' =
-                                                        if b then
-                                                            E.subExpInExp (0, squish x) body'
-                                                        else
-                                                            body'
+                                                    val body'' = E.subExpInExp (0, x) body'
                                                 in
                                                     subBody (body'',
                                                              typ',
-                                                             xs)
+                                                             fxs'')
                                                 end
                                               | _ => NONE
                                     in
-                                        case subBody (body, typ, xs) of
+                                        case subBody (body, typ, fxs') of
                                             NONE => (e, st)
                                           | SOME (body', typ') =>
                                             let
@@ -259,10 +250,17 @@
                                                                                       body = body,
                                                                                       typ = typ,
                                                                                       tag = tag})
+
+                                                val specialized = IM.insert (#specialized st, f', false)
+                                                val specialized = case IM.find (specialized, f) of
+                                                                      NONE => specialized
+                                                                    | SOME _ => IM.insert (specialized, f, true)
+
                                                 val st = {
                                                     maxName = f' + 1,
                                                     funcs = funcs,
-                                                    decls = #decls st
+                                                    decls = #decls st,
+                                                    specialized = specialized
                                                 }
 
                                                 (*val () = Print.prefaces "specExp"
@@ -272,12 +270,6 @@
                                                                          ("fxs'", Print.p_list
                                                                                       (CorePrint.p_exp E.empty) fxs'),
                                                                          ("e", CorePrint.p_exp env (e, loc))]*)
-
-                                                val (body', typ') = foldr (fn (t, (body', typ')) =>
-                                                                              ((EAbs ("x", t, typ', body'), loc),
-                                                                               (TFun (t, typ'), loc)))
-                                                                          (body', typ') ts
-
                                                 val (body', typ') = IS.foldl (fn (n, (body', typ')) =>
                                                                                  let
                                                                                      val (x, xt) = List.nth (env, n)
@@ -296,8 +288,7 @@
                                                 val e' = (ENamed f', loc)
                                                 val e' = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
                                                                   e' fvs
-                                                val e' = foldl (fn ((false, arg), e) => (EApp (e, arg), loc)
-                                                                 | (_, e) => e)
+                                                val e' = foldl (fn (arg, e) => (EApp (e, arg), loc))
                                                                e' xs
                                                 (*val () = Print.prefaces "Brand new"
                                                                         [("e'", CorePrint.p_exp CoreEnv.empty e'),
@@ -307,7 +298,8 @@
                                                 (#1 e',
                                                  {maxName = #maxName st,
                                                   funcs = #funcs st,
-                                                  decls = (name, f', typ', body', tag) :: #decls st})
+                                                  decls = (name, f', typ', body', tag) :: #decls st,
+                                                  specialized = #specialized st})
                                             end
                                     end
                         end
@@ -336,7 +328,8 @@
 
                 val st = {maxName = #maxName st,
                           funcs = funcs,
-                          decls = []}
+                          decls = [],
+                          specialized = #specialized st}
 
                 (*val () = Print.prefaces "decl" [("d", CorePrint.p_decl CoreEnv.empty d)]*)
 
@@ -381,25 +374,27 @@
                                          ("d'", CorePrint.p_decl E.empty d')];*)
                 (ds, ({maxName = #maxName st,
                        funcs = funcs,
-                       decls = []}, changed))
+                       decls = [],
+                       specialized = #specialized st}, changed))
             end
 
-        val (ds, (_, changed)) = ListUtil.foldlMapConcat doDecl
+        val (ds, (st, changed)) = ListUtil.foldlMapConcat doDecl
                                                             ({maxName = U.File.maxName file + 1,
                                                               funcs = IM.empty,
-                                                              decls = []},
+                                                              decls = [],
+                                                              specialized = specialized},
                                                              false)
                                                             file
     in
-        (changed, ds)
+        (changed, ds, #specialized st)
     end
 
-fun specialize file =
+fun specializeL specialized file =
     let
         val file = ReduceLocal.reduce file
         (*val () = Print.prefaces "Intermediate" [("file", CorePrint.p_file CoreEnv.empty file)]*)
         (*val file = ReduceLocal.reduce file*)
-        val (changed, file) = specialize' file
+        val (changed, file, specialized) = specialize' specialized file
         (*val file = ReduceLocal.reduce file
         val file = CoreUntangle.untangle file
         val file = Shake.shake file*)
@@ -414,10 +409,12 @@
                 val file = Shake.shake file
             in
                 (*print "Again!\n";*)
-                specialize file
+                specializeL specialized file
             end
         else
             file
     end
 
+val specialize = specializeL IM.empty
+
 end