comparison src/especialize.sml @ 1675:13dad713da35

New, more principled heuristic for Especialize: only specialize uniform function arguments; that is, arguments that don't change across recursive calls
author Adam Chlipala <adam@chlipala.net>
date Wed, 11 Jan 2012 13:53:35 -0500
parents 4cacced4a6da
children 266814b15dd6
comparison
equal deleted inserted replaced
1674:4cacced4a6da 1675:13dad713da35
107 type func = { 107 type func = {
108 name : string, 108 name : string,
109 args : int KM.map, 109 args : int KM.map,
110 body : exp, 110 body : exp,
111 typ : con, 111 typ : con,
112 tag : string 112 tag : string,
113 constArgs : int (* Length of the longest prefix of this function's arguments that every recursive call passes through unchanged *)
113 } 114 }
114 115
115 type state = { 116 type state = {
116 maxName : int, 117 maxName : int,
117 funcs : func IM.map, 118 funcs : func IM.map,
131 | CFfi ("Basis", "read") => true 132 | CFfi ("Basis", "read") => true
132 | CFfi ("Basis", "sql_injectable_prim") => true 133 | CFfi ("Basis", "sql_injectable_prim") => true
133 | CFfi ("Basis", "sql_injectable") => true 134 | CFfi ("Basis", "sql_injectable") => true
134 | _ => false} 135 | _ => false}
135 136
(* getApp e

   If e is a curried application of a named function to at least one
   argument, i.e. ENamed f a1 ... an with n >= 1, returns
   SOME (f, [a1, ..., an]) with the arguments in application order.
   Otherwise (including a bare, unapplied ENamed) returns NONE. *)
fun getApp e =
    let
        (* Descend the left spine of EApp nodes, consing each argument onto
           acc.  The outermost EApp holds the last argument, so by the time
           the head is reached acc is already in application order. *)
        fun spine ((EApp (fnPart, arg), _), acc) = spine (fnPart, arg :: acc)
          | spine ((ENamed f, _), acc) = SOME (f, acc)
          | spine (_, _) = NONE
    in
        case spine (e, []) of
            SOME (_, []) => NONE
          | result => result
    end
149
(* Sentinel meaning "no bound" in the Int.min folds of calcConstArgs.
   It is the largest Int when the implementation has one; otherwise
   (arbitrary-precision Int) it falls back to 9999. *)
val maxInt =
    case Int.maxInt of
        SOME largest => largest
      | NONE => 9999
151
(* calcConstArgs enclosingFunction e

   Here e is the definition body of the function named enclosingFunction.
   The result is how many of its leading parameters are passed through
   unchanged by every recursive call inside e.

   At a recursive call site, the i-th argument (counting from 0) counts as
   "passed through" only if both of the following hold:
   - it is exactly the variable bound by the i-th leading EAbs of e;
   - every earlier argument also qualified.

   Other mentions of enclosingFunction also bound the result:
   - an unapplied reference (a bare ENamed) bounds it to 0;
   - a call whose arguments stop matching bounds it to the number that
     matched before the mismatch.

   If e never mentions enclosingFunction, the result is 0.

   NOTE(review): with constArgs = 0 the specializer selects no arguments,
   so functions without recursive calls are never specialized by this
   heuristic -- confirm that is intended. *)
fun calcConstArgs enclosingFunction e =
    let
        (* Strip the leading lambdas of the definition, counting them.
           numAbs bounds how many parameters can ever count as uniform. *)
        fun peel (n, (EAbs (_, _, _, e1), _)) = peel (n + 1, e1)
          | peel (n, body) = (n, body)

        val (numAbs, body) = peel (0, e)

        (* ca depth e returns the bound that e imposes, or maxInt for
           "no constraint".  depth is the number of variables bound between
           the top of the definition and e, so the i-th leading parameter
           (counting from 0) is ERel (depth - 1 - i) here. *)
        fun ca depth e =
            case #1 e of
                EPrim _ => maxInt
              | ERel _ => maxInt
              | ENamed n => if n = enclosingFunction then 0 else maxInt
              | ECon (_, _, _, NONE) => maxInt
              | ECon (_, _, _, SOME e) => ca depth e
              | EFfi _ => maxInt
              | EFfiApp (_, _, ecs) => foldl (fn ((e, _), d) => Int.min (ca depth e, d)) maxInt ecs
              | EApp (e1, e2) =>
                let
                    fun default () = Int.min (ca depth e1, ca depth e2)
                in
                    case getApp e of
                        NONE => default ()
                      | SOME (f, args) =>
                        if f <> enclosingFunction then
                            default ()
                        else
                            let
                                (* count is the number of leading arguments
                                   already matched to their own parameters. *)
                                fun visitArgs (count, args) =
                                    case args of
                                        [] => count
                                      | arg :: args' =>
                                        let
                                            (* Matching stops here.  The
                                               remaining arguments may still
                                               contain further recursive
                                               calls, so they are scanned. *)
                                            fun default () = foldl (fn (e, d) => Int.min (ca depth e, d)) count args
                                        in
                                            case #1 arg of
                                                ERel n =>
                                                (* Parameter number count is
                                                   ERel (depth - 1 - count)
                                                   here.  Positions beyond the
                                                   leading lambdas are not
                                                   parameters, so they never
                                                   match. *)
                                                if count < numAbs andalso n = depth - 1 - count then
                                                    visitArgs (count + 1, args')
                                                else
                                                    default ()
                                              | _ => default ()
                                        end
                            in
                                visitArgs (0, args)
                            end
                end
              | EAbs (_, _, _, e1) => ca (depth + 1) e1
              | ECApp (e1, _) => ca depth e1
              | ECAbs (_, _, e1) => ca depth e1
              | EKAbs (_, e1) => ca depth e1
              | EKApp (e1, _) => ca depth e1
              | ERecord xets => foldl (fn ((_, e, _), d) => Int.min (ca depth e, d)) maxInt xets
              | EField (e1, _, _) => ca depth e1
              | EConcat (e1, _, e2, _) => Int.min (ca depth e1, ca depth e2)
              | ECut (e1, _, _) => ca depth e1
              | ECutMulti (e1, _, _) => ca depth e1
              (* Each case arm sees the variables its pattern binds. *)
              | ECase (e1, pes, _) => foldl (fn ((p, e), d) => Int.min (ca (depth + E.patBindsN p) e, d)) (ca depth e1) pes
              | EWrite e1 => ca depth e1
              | EClosure (_, es) => foldl (fn (e, d) => Int.min (ca depth e, d)) maxInt es
              | ELet (_, _, e1, e2) => Int.min (ca depth e1, ca (depth + 1) e2)
              | EServerCall (_, es, _) => foldl (fn (e, d) => Int.min (ca depth e, d)) maxInt es

        val n = ca numAbs body
    in
        (* maxInt means no reference to enclosingFunction was found. *)
        if n = maxInt then
            0
        else
            n
    end
221
222
136 fun specialize' (funcs, specialized) file = 223 fun specialize' (funcs, specialized) file =
137 let 224 let
138 fun bind (env, b) = 225 fun bind (env, b) =
139 case b of 226 case b of
140 U.Decl.RelE xt => xt :: env 227 U.Decl.RelE xt => xt :: env
142 229
143 fun exp (env, e as (_, loc), st : state) = 230 fun exp (env, e as (_, loc), st : state) =
144 let 231 let
145 (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty 232 (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty
146 (e, ErrorMsg.dummySpan))]*) 233 (e, ErrorMsg.dummySpan))]*)
147
148 fun getApp (e, _) =
149 case e of
150 ENamed f => SOME (f, [])
151 | EApp (e1, e2) =>
152 (case getApp e1 of
153 NONE => NONE
154 | SOME (f, xs) => SOME (f, xs @ [e2]))
155 | _ => NONE
156
157 val getApp = fn e => case getApp e of
158 v as SOME (_, _ :: _) => v
159 | _ => NONE
160 234
161 fun default () = 235 fun default () =
162 case #1 e of 236 case #1 e of
163 EPrim _ => (e, st) 237 EPrim _ => (e, st)
164 | ERel _ => (e, st) 238 | ERel _ => (e, st)
288 case getApp e of 362 case getApp e of
289 NONE => default () 363 NONE => default ()
290 | SOME (f, xs) => 364 | SOME (f, xs) =>
291 case IM.find (#funcs st, f) of 365 case IM.find (#funcs st, f) of
292 NONE => ((*print ("No find: " ^ Int.toString f ^ "\n");*) default ()) 366 NONE => ((*print ("No find: " ^ Int.toString f ^ "\n");*) default ())
293 | SOME {name, args, body, typ, tag} => 367 | SOME {name, args, body, typ, tag, constArgs} =>
294 let 368 let
295 val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs 369 val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs
296 370
297 (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty 371 (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty
298 (e, ErrorMsg.dummySpan))]*) 372 (e, ErrorMsg.dummySpan))]*)
299 373
300 val loc = ErrorMsg.dummySpan 374 val loc = ErrorMsg.dummySpan
301 375
302 fun findSplit av (xs, typ, fxs, fvs, fin) = 376 fun findSplit av (constArgs, xs, typ, fxs, fvs) =
303 case (#1 typ, xs) of 377 case (#1 typ, xs) of
304 (TFun (dom, ran), e :: xs') => 378 (TFun (dom, ran), e :: xs') =>
305 let 379 if constArgs > 0 then
306 val av = case #1 e of 380 findSplit av (constArgs - 1,
307 ERel _ => av 381 xs',
308 | _ => false 382 ran,
309 in 383 e :: fxs,
310 if functionInside dom orelse (av andalso case #1 e of 384 IS.union (fvs, freeVars e))
311 ERel _ => true 385 else
312 | _ => false) then 386 (rev fxs, xs, fvs)
313 findSplit av (xs', 387 | _ => (rev fxs, xs, fvs)
314 ran, 388
315 e :: fxs, 389 val (fxs, xs, fvs) = findSplit true (constArgs, xs, typ, [], IS.empty)
316 IS.union (fvs, freeVars e),
317 fin orelse functionInside dom)
318 else
319 (rev fxs, xs, fvs, fin)
320 end
321 | _ => (rev fxs, xs, fvs, fin)
322
323 val (fxs, xs, fvs, fin) = findSplit true (xs, typ, [], IS.empty, false)
324
325 fun valueish (all as (e, _)) =
326 case e of
327 EPrim _ => true
328 | ERel _ => true
329 | ENamed _ => true
330 | ECon (_, _, _, NONE) => true
331 | ECon (_, _, _, SOME e) => valueish e
332 | EFfi (_, _) => true
333 | EAbs _ => true
334 | ECAbs _ => true
335 | EKAbs _ => true
336 | ECApp (e, _) => valueish e
337 | EKApp (e, _) => valueish e
338 | EApp (e1, e2) => valueish e1 andalso valueish e2
339 | ERecord xes => List.all (valueish o #2) xes
340 | EField (e, _, _) => valueish e
341 | _ => false
342 390
343 val vts = map (fn n => #2 (List.nth (env, n))) (IS.listItems fvs) 391 val vts = map (fn n => #2 (List.nth (env, n))) (IS.listItems fvs)
344 val fxs' = map (squish (IS.listItems fvs)) fxs 392 val fxs' = map (squish (IS.listItems fvs)) fxs
345 393
346 val p_bool = Print.PD.string o Bool.toString 394 val p_bool = Print.PD.string o Bool.toString
347
348 fun bumpCount n =
349 if IS.member (#specialized st, f) then
350 n
351 else
352 5 + n
353 in 395 in
354 (*Print.prefaces "Func" [("name", Print.PD.string name), 396 (*Print.prefaces "Func" [("name", Print.PD.string name),
355 ("e", CorePrint.p_exp CoreEnv.empty e), 397 ("e", CorePrint.p_exp CoreEnv.empty e),
356 ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')];*) 398 ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')];*)
357 if not fin 399 if List.all (fn (ERel _, _) => true
358 orelse List.all (fn (ERel _, _) => true 400 | _ => false) fxs' then
359 | _ => false) fxs' 401 default ()
360 orelse List.exists (not o valueish) fxs'
361 orelse IS.numItems fvs >= bumpCount (length fxs) then
362 ((*Print.prefaces "No" [("name", Print.PD.string name),
363 ("f", Print.PD.string (Int.toString f)),
364 ("fxs'",
365 Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs'),
366 ("b1", p_bool (not fin)),
367 ("b2", p_bool (List.all (fn (ERel _, _) => true
368 | _ => false) fxs')),
369 ("b3", p_bool (List.exists (not o valueish) fxs')),
370 ("b4", p_bool (IS.numItems fvs >= length fxs
371 andalso IS.exists (fn n => functionInside (#2 (List.nth (env, n)))) fvs))];*)
372 default ())
373 else 402 else
374 case KM.find (args, (vts, fxs')) of 403 case KM.find (args, (vts, fxs')) of
375 SOME f' => 404 SOME f' =>
376 let 405 let
377 val e = (ENamed f', loc) 406 val e = (ENamed f', loc)
393 ("|fxs|", Print.PD.string (Int.toString (length fxs))), 422 ("|fxs|", Print.PD.string (Int.toString (length fxs))),
394 ("spec", Print.PD.string (Bool.toString (IS.member (#specialized st, f))))]*) 423 ("spec", Print.PD.string (Bool.toString (IS.member (#specialized st, f))))]*)
395 424
396 (*val () = Print.prefaces ("Yes(" ^ name ^ ")") 425 (*val () = Print.prefaces ("Yes(" ^ name ^ ")")
397 [("fxs'", 426 [("fxs'",
427 Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*)
428
429 (*val () = Print.prefaces name
430 [("Available", Print.PD.string (Int.toString constArgs)),
431 ("Used", Print.PD.string (Int.toString (length fxs'))),
432 ("fxs'",
398 Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*) 433 Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*)
399 434
400 fun subBody (body, typ, fxs') = 435 fun subBody (body, typ, fxs') =
401 case (#1 body, #1 typ, fxs') of 436 case (#1 body, #1 typ, fxs') of
402 (_, _, []) => SOME (body, typ) 437 (_, _, []) => SOME (body, typ)
418 val args = KM.insert (args, (vts, fxs'), f') 453 val args = KM.insert (args, (vts, fxs'), f')
419 val funcs = IM.insert (#funcs st, f, {name = name, 454 val funcs = IM.insert (#funcs st, f, {name = name,
420 args = args, 455 args = args,
421 body = body, 456 body = body,
422 typ = typ, 457 typ = typ,
423 tag = tag}) 458 tag = tag,
459 constArgs = calcConstArgs f body})
424 460
425 val st = { 461 val st = {
426 maxName = f' + 1, 462 maxName = f' + 1,
427 funcs = funcs, 463 funcs = funcs,
428 decls = #decls st, 464 decls = #decls st,
482 foldl (fn ((x, n, c, e, tag), funcs) => 518 foldl (fn ((x, n, c, e, tag), funcs) =>
483 IM.insert (funcs, n, {name = x, 519 IM.insert (funcs, n, {name = x,
484 args = KM.empty, 520 args = KM.empty,
485 body = e, 521 body = e,
486 typ = c, 522 typ = c,
487 tag = tag})) 523 tag = tag,
524 constArgs = calcConstArgs n e}))
488 funcs vis 525 funcs vis
489 | _ => funcs 526 | _ => funcs
490 527
491 val st = {maxName = #maxName st, 528 val st = {maxName = #maxName st,
492 funcs = funcs, 529 funcs = funcs,
563 DVal (x, n, c, e as (EAbs _, _), tag) => 600 DVal (x, n, c, e as (EAbs _, _), tag) =>
564 IM.insert (funcs, n, {name = x, 601 IM.insert (funcs, n, {name = x,
565 args = KM.empty, 602 args = KM.empty,
566 body = e, 603 body = e,
567 typ = c, 604 typ = c,
568 tag = tag}) 605 tag = tag,
606 constArgs = calcConstArgs n e})
569 | DVal (_, n, _, (ENamed n', _), _) => 607 | DVal (_, n, _, (ENamed n', _), _) =>
570 (case IM.find (funcs, n') of 608 (case IM.find (funcs, n') of
571 NONE => funcs 609 NONE => funcs
572 | SOME v => IM.insert (funcs, n, v)) 610 | SOME v => IM.insert (funcs, n, v))
573 | _ => funcs 611 | _ => funcs