comparison src/especialize.sml @ 1861:52043ad66ce7

Extend Especialize rule: find maximal argument prefixes that end in one or more arguments with functional types
author Adam Chlipala <adam@chlipala.net>
date Fri, 09 Aug 2013 16:04:16 -0400
parents e15234fbb163
children 32784d27b5bc
comparison
equal deleted inserted replaced
1860:d54984564bcd 1861:52043ad66ce7
362 NONE => ((*print ("No find: " ^ Int.toString f ^ "\n");*) default ()) 362 NONE => ((*print ("No find: " ^ Int.toString f ^ "\n");*) default ())
363 | SOME {name, args, body, typ, tag, constArgs} => 363 | SOME {name, args, body, typ, tag, constArgs} =>
364 let 364 let
365 val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs 365 val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs
366 366
367 (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty 367 (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty e)]*)
368 (e, ErrorMsg.dummySpan))]*)
369 368
370 val loc = ErrorMsg.dummySpan 369 val loc = ErrorMsg.dummySpan
371 370
372 val oldXs = xs 371 val oldXs = xs
373 372
374 fun findSplit av (constArgs, xs, typ, fxs, fvs) = 373 fun findSplit av (initialPart, constArgs, xs, typ, fxs, fvs) =
375 case (#1 typ, xs) of 374 let
376 (TFun (dom, ran), e :: xs') => 375 fun default () =
377 if constArgs > 0 then 376 if initialPart then
378 if functionInside dom then 377 ([], oldXs, IS.empty)
379 (rev (e :: fxs), xs', IS.union (fvs, freeVars e))
380 else 378 else
381 findSplit av (constArgs - 1, 379 (rev fxs, xs, fvs)
382 xs', 380 in
383 ran, 381 case (#1 typ, xs) of
384 e :: fxs, 382 (TFun (dom, ran), e :: xs') =>
385 IS.union (fvs, freeVars e)) 383 if constArgs > 0 then
386 else 384 let
387 ([], oldXs, IS.empty) 385 val fi = functionInside dom
388 | _ => ([], oldXs, IS.empty) 386 in
389 387 if initialPart orelse fi then
390 val (fxs, xs, fvs) = findSplit true (constArgs, xs, typ, [], IS.empty) 388 findSplit av (not fi andalso initialPart,
389 constArgs - 1,
390 xs',
391 ran,
392 e :: fxs,
393 IS.union (fvs, freeVars e))
394 else
395 default ()
396 end
397 else
398 default ()
399 | _ => default ()
400 end
401
402 val (fxs, xs, fvs) = findSplit true (true, constArgs, xs, typ, [], IS.empty)
391 403
392 val vts = map (fn n => #2 (List.nth (env, n))) (IS.listItems fvs) 404 val vts = map (fn n => #2 (List.nth (env, n))) (IS.listItems fvs)
393 val fxs' = map (squish (IS.listItems fvs)) fxs 405 val fxs' = map (squish (IS.listItems fvs)) fxs
394 406
395 val p_bool = Print.PD.string o Bool.toString 407 val p_bool = Print.PD.string o Bool.toString
481 ((EAbs (x, xt, typ', body'), 493 ((EAbs (x, xt, typ', body'),
482 loc), 494 loc),
483 (TFun (xt, typ'), loc)) 495 (TFun (xt, typ'), loc))
484 end) 496 end)
485 (body', typ') fvs 497 (body', typ') fvs
486 (*val () = print ("NEW: " ^ name ^ "__" ^ Int.toString f' ^ "\n");*) 498 (*val () = print ("NEW: " ^ name ^ "__" ^ Int.toString f' ^ "\n")*)
487 val body' = ReduceLocal.reduceExp body' 499 val body' = ReduceLocal.reduceExp body'
488 (*val () = Print.preface ("PRE", CorePrint.p_exp CoreEnv.empty body')*) 500 (*val () = Print.preface ("PRE", CorePrint.p_exp CoreEnv.empty body')*)
489 val (body', st) = exp (env, body', st) 501 val (body', st) = exp (env, body', st)
490 502
491 val e' = (ENamed f', loc) 503 val e' = (ENamed f', loc)
521 val fs = foldl (fn ((_, n, _, _, _), fs) => IS.add (fs, n)) IS.empty vis 533 val fs = foldl (fn ((_, n, _, _, _), fs) => IS.add (fs, n)) IS.empty vis
522 val constArgs = foldl (fn ((_, _, _, e, _), constArgs) => 534 val constArgs = foldl (fn ((_, _, _, e, _), constArgs) =>
523 Int.min (constArgs, calcConstArgs fs e)) 535 Int.min (constArgs, calcConstArgs fs e))
524 maxInt vis 536 maxInt vis
525 in 537 in
538 (*Print.prefaces "ConstArgs" [("d", CorePrint.p_decl CoreEnv.empty d),
539 ("ca", Print.PD.string (Int.toString constArgs))];*)
526 foldl (fn ((x, n, c, e, tag), funcs) => 540 foldl (fn ((x, n, c, e, tag), funcs) =>
527 IM.insert (funcs, n, {name = x, 541 IM.insert (funcs, n, {name = x,
528 args = KM.empty, 542 args = KM.empty,
529 body = e, 543 body = e,
530 typ = c, 544 typ = c,
605 619
606 val funcs = #funcs st 620 val funcs = #funcs st
607 val funcs = 621 val funcs =
608 case #1 d of 622 case #1 d of
609 DVal (x, n, c, e as (EAbs _, _), tag) => 623 DVal (x, n, c, e as (EAbs _, _), tag) =>
624 ((*Print.prefaces "ConstArgs[2]" [("d", CorePrint.p_decl CoreEnv.empty d),
625 ("ca", Print.PD.string (Int.toString (calcConstArgs (IS.singleton n) e)))];*)
610 IM.insert (funcs, n, {name = x, 626 IM.insert (funcs, n, {name = x,
611 args = KM.empty, 627 args = KM.empty,
612 body = e, 628 body = e,
613 typ = c, 629 typ = c,
614 tag = tag, 630 tag = tag,
615 constArgs = calcConstArgs (IS.singleton n) e}) 631 constArgs = calcConstArgs (IS.singleton n) e}))
616 | DVal (_, n, _, (ENamed n', _), _) => 632 | DVal (_, n, _, (ENamed n', _), _) =>
617 (case IM.find (funcs, n') of 633 (case IM.find (funcs, n') of
618 NONE => funcs 634 NONE => funcs
619 | SOME v => IM.insert (funcs, n, v)) 635 | SOME v => IM.insert (funcs, n, v))
620 | _ => funcs 636 | _ => funcs