# HG changeset patch # User Adam Chlipala # Date 1226173372 18000 # Node ID ffa18975e661e3309b918aae03f9ad3c2b4a90fa # Parent 6ee1c761818f489915b4773e288b0a915d58cae1 Broaden set of possible especializations diff -r 6ee1c761818f -r ffa18975e661 src/core_util.sig --- a/src/core_util.sig Sat Nov 08 13:15:00 2008 -0500 +++ b/src/core_util.sig Sat Nov 08 14:42:52 2008 -0500 @@ -73,6 +73,8 @@ end structure Exp : sig + val compare : Core.exp * Core.exp -> order + datatype binder = RelC of string * Core.kind | NamedC of string * int * Core.kind * Core.con option @@ -108,6 +110,12 @@ con : Core.con' -> bool, exp : Core.exp' -> bool} -> Core.exp -> bool + val existsB : {kind : Core.kind' -> bool, + con : 'context * Core.con' -> bool, + exp : 'context * Core.exp' -> bool, + bind : 'context * binder -> 'context} + -> 'context -> Core.exp -> bool + val foldMap : {kind : Core.kind' * 'state -> Core.kind' * 'state, con : Core.con' * 'state -> Core.con' * 'state, exp : Core.exp' * 'state -> Core.exp' * 'state} diff -r 6ee1c761818f -r ffa18975e661 src/core_util.sml --- a/src/core_util.sml Sat Nov 08 13:15:00 2008 -0500 +++ b/src/core_util.sml Sat Nov 08 14:42:52 2008 -0500 @@ -331,6 +331,149 @@ structure Exp = struct +open Order + +fun pcCompare (pc1, pc2) = + case (pc1, pc2) of + (PConVar n1, PConVar n2) => Int.compare (n1, n2) + | (PConVar _, _) => LESS + | (_, PConVar _) => GREATER + + | (PConFfi {mod = m1, datatyp = d1, con = c1, ...}, + PConFfi {mod = m2, datatyp = d2, con = c2, ...}) => + join (String.compare (m1, m2), + fn () => join (String.compare (d1, d2), + fn () => String.compare (c1, c2))) + +fun pCompare ((p1, _), (p2, _)) = + case (p1, p2) of + (PWild, PWild) => EQUAL + | (PWild, _) => LESS + | (_, PWild) => GREATER + + | (PVar _, PVar _) => EQUAL + | (PVar _, _) => LESS + | (_, PVar _) => GREATER + + | (PPrim p1, PPrim p2) => Prim.compare (p1, p2) + | (PPrim _, _) => LESS + | (_, PPrim _) => GREATER + + | (PCon (_, pc1, _, po1), PCon (_, pc2, _, po2)) => + join (pcCompare (pc1, pc2), + fn () => joinO pCompare (po1, po2)) + | (PCon _, _) => LESS + | (_, PCon _) => GREATER + + | (PRecord xps1, PRecord xps2) => + joinL (fn ((x1, p1, _), (x2, p2, _)) => + join (String.compare (x1, x2), + fn () => pCompare (p1, p2))) (xps1, xps2) + +fun compare ((e1, _), (e2, _)) = + case (e1, e2) of + (EPrim p1, EPrim p2) => Prim.compare (p1, p2) + | (EPrim _, _) => LESS + | (_, EPrim _) => GREATER + + | (ERel n1, ERel n2) => Int.compare (n1, n2) + | (ERel _, _) => LESS + | (_, ERel _) => GREATER + + | (ENamed n1, ENamed n2) => Int.compare (n1, n2) + | (ENamed _, _) => LESS + | (_, ENamed _) => GREATER + + | (ECon (_, pc1, _, eo1), ECon (_, pc2, _, eo2)) => + join (pcCompare (pc1, pc2), + fn () => joinO compare (eo1, eo2)) + | (ECon _, _) => LESS + | (_, ECon _) => GREATER + + | (EFfi (f1, x1), EFfi (f2, x2)) => + join (String.compare (f1, f2), + fn () => String.compare (x1, x2)) + | (EFfi _, _) => LESS + | (_, EFfi _) => GREATER + + | (EFfiApp (f1, x1, es1), EFfiApp (f2, x2, es2)) => + join (String.compare (f1, f2), + fn () => join (String.compare (x1, x2), + fn () => joinL compare (es1, es2))) + | (EFfiApp _, _) => LESS + | (_, EFfiApp _) => GREATER + + | (EApp (f1, x1), EApp (f2, x2)) => + join (compare (f1, f2), + fn () => compare (x1, x2)) + | (EApp _, _) => LESS + | (_, EApp _) => GREATER + + | (EAbs (_, _, _, e1), EAbs (_, _, _, e2)) => compare (e1, e2) + | (EAbs _, _) => LESS + | (_, EAbs _) => GREATER + + | (ECApp (f1, x1), ECApp (f2, x2)) => + join (compare (f1, f2), + fn () => Con.compare (x1, x2)) + | (ECApp _, _) => LESS + | (_, ECApp _) => GREATER + + | (ECAbs (_, _, e1), ECAbs (_, _, e2)) => compare (e1, e2) + | (ECAbs _, _) => LESS + | (_, ECAbs _) => GREATER + + | (ERecord xes1, ERecord xes2) => + joinL (fn ((x1, e1, _), (x2, e2, _)) => + join (Con.compare (x1, x2), + fn () => compare (e1, e2))) (xes1, xes2) + | (ERecord _, _) => LESS + | (_, ERecord _) => GREATER + + | (EField (e1, c1, _), EField (e2, c2, _)) => + join (compare (e1, e2), + fn () => Con.compare (c1, c2)) + | (EField _, _) => LESS + | (_, EField _) => GREATER + + | (EConcat (x1, _, y1, _), EConcat (x2, _, y2, _)) => + join (compare (x1, x2), + fn () => compare (y1, y2)) + | (EConcat _, _) => LESS + | (_, EConcat _) => GREATER + + | (ECut (e1, c1, _), ECut (e2, c2, _)) => + join (compare (e1, e2), + fn () => Con.compare (c1, c2)) + | (ECut _, _) => LESS + | (_, ECut _) => GREATER + + | (EFold _, EFold _) => EQUAL + | (EFold _, _) => LESS + | (_, EFold _) => GREATER + + | (ECase (e1, pes1, _), ECase (e2, pes2, _)) => + join (compare (e1, e2), + fn () => joinL (fn ((p1, e1), (p2, e2)) => + join (pCompare (p1, p2), + fn () => compare (e1, e2))) (pes1, pes2)) + | (ECase _, _) => LESS + | (_, ECase _) => GREATER + + | (EWrite e1, EWrite e2) => compare (e1, e2) + | (EWrite _, _) => LESS + | (_, EWrite _) => GREATER + + | (EClosure (n1, es1), EClosure (n2, es2)) => + join (Int.compare (n1, n2), + fn () => joinL compare (es1, es2)) + | (EClosure _, _) => LESS + | (_, EClosure _) => GREATER + + | (ELet (_, _, x1, e1), ELet (_, _, x2, e2)) => + join (compare (x1, x2), + fn () => compare (e1, e2)) + datatype binder = RelC of string * kind | NamedC of string * int * kind * con option @@ -585,6 +728,26 @@ S.Return _ => true | S.Continue _ => false +fun existsB {kind, con, exp, bind} ctx k = + case mapfoldB {kind = fn k => fn () => + if kind k then + S.Return () + else + S.Continue (k, ()), + con = fn ctx => fn c => fn () => + if con (ctx, c) then + S.Return () + else + S.Continue (c, ()), + exp = fn ctx => fn e => fn () => + if exp (ctx, e) then + S.Return () + else + S.Continue (e, ()), + bind = bind} ctx k () of + S.Return _ => true + | S.Continue _ => false + fun foldMap {kind, con, exp} s e = case mapfold {kind = fn k => fn s => S.Continue (kind (k, s)), con = fn c => fn s => S.Continue (con (c, s)), diff -r 6ee1c761818f -r ffa18975e661 src/especialize.sml --- a/src/especialize.sml Sat Nov 08 13:15:00 2008 -0500 +++ b/src/especialize.sml Sat Nov 08 14:42:52 2008 -0500 @@ -32,39 +32,57 @@ structure E = CoreEnv structure U = CoreUtil -datatype skey = - Named of int - | App of skey * skey +type skey = exp structure K = struct -type ord_key = skey list -fun compare' (k1, k2) = - case (k1, k2) of - (Named n1, Named n2) => Int.compare (n1, n2) - | (Named _, _) => LESS - | (_, Named _) => GREATER - - | (App (x1, y1), App (x2, y2)) => Order.join (compare' (x1, x2), fn () => compare' (y1, y2)) - -val compare = Order.joinL compare' +type ord_key = exp list +val compare = Order.joinL U.Exp.compare end structure KM = BinaryMapFn(K) structure IM = IntBinaryMap -fun skeyIn (e, _) = +val sizeOf = U.Exp.fold {kind = fn (_, n) => n, + con = fn (_, n) => n, + exp = fn (_, n) => n + 1} + 0 + +val isOpen = U.Exp.existsB {kind = fn _ => false, + con = fn ((nc, _), c) => + case c of + CRel n => n >= nc + | _ => false, + exp = fn ((_, ne), e) => + case e of + ERel n => n >= ne + | _ => false, + bind = fn ((nc, ne), b) => + case b of + U.Exp.RelC _ => (nc + 1, ne) + | U.Exp.RelE _ => (nc, ne + 1) + | _ => (nc, ne)} + (0, 0) + +fun baseBad (e, _) = case e of - ENamed n => SOME (Named n) - | EApp (e1, e2) => - (case (skeyIn e1, skeyIn e2) of - (SOME k1, SOME k2) => SOME (App (k1, k2)) - | _ => NONE) - | _ => NONE + EAbs (_, _, _, e) => sizeOf e > 20 + | ENamed _ => false + | _ => true -fun skeyOut (k, loc) = - case k of - Named n => (ENamed n, loc) - | App (k1, k2) => (EApp (skeyOut (k1, loc), skeyOut (k2, loc)), loc) +fun isBad e = + case e of + (ERecord xes, _) => + length xes > 10 + orelse List.exists (fn (_, e, _) => baseBad e) xes + | _ => baseBad e + +fun skeyIn e = + if isBad e orelse isOpen e then + NONE + else + SOME e + +fun skeyOut e = e type func = { name : string, @@ -126,7 +144,7 @@ (_, _, []) => SOME (body, typ) | (EAbs (_, _, _, body'), TFun (_, typ'), x :: xs) => let - val body'' = E.subExpInExp (0, skeyOut (x, #2 body)) body' + val body'' = E.subExpInExp (0, skeyOut x) body' in (*Print.prefaces "espec" [("body'", CorePrint.p_exp CoreEnv.empty body'), ("body''", CorePrint.p_exp CoreEnv.empty body'')];*) diff -r 6ee1c761818f -r ffa18975e661 src/order.sig --- a/src/order.sig Sat Nov 08 13:15:00 2008 -0500 +++ b/src/order.sig Sat Nov 08 14:42:52 2008 -0500 @@ -31,5 +31,6 @@ val join : order * (unit -> order) -> order val joinL : ('a * 'b -> order) -> 'a list * 'b list -> order - + val joinO : ('a * 'b -> order) -> 'a option * 'b option -> order + end diff -r 6ee1c761818f -r ffa18975e661 src/order.sml --- a/src/order.sml Sat Nov 08 13:15:00 2008 -0500 +++ b/src/order.sml Sat Nov 08 14:42:52 2008 -0500 @@ -42,4 +42,12 @@ join (f (h1, h2), fn () => joinL f (t1, t2)) | (_ :: _, nil) => GREATER +fun joinO f (v1, v2) = + case (v1, v2) of + (NONE, NONE) => EQUAL + | (NONE, _) => LESS + | (_, NONE) => GREATER + + | (SOME v1, SOME v2) => f (v1, v2) + end diff -r 6ee1c761818f -r ffa18975e661 src/prim.sig --- a/src/prim.sig Sat Nov 08 13:15:00 2008 -0500 +++ b/src/prim.sig Sat Nov 08 14:42:52 2008 -0500 @@ -36,5 +36,6 @@ val p_t_GCC : t Print.printer val equal : t * t -> bool + val compare : t * t -> order end diff -r 6ee1c761818f -r ffa18975e661 src/prim.sml --- a/src/prim.sml Sat Nov 08 13:15:00 2008 -0500 +++ b/src/prim.sml Sat Nov 08 14:42:52 2008 -0500 @@ -67,4 +67,16 @@ | _ => false +fun compare (p1, p2) = + case (p1, p2) of + (Int n1, Int n2) => Int64.compare (n1, n2) + | (Int _, _) => LESS + | (_, Int _) => GREATER + + | (Float n1, Float n2) => Real64.compare (n1, n2) + | (Float _, _) => LESS + | (_, Float _) => GREATER + + | (String n1, String n2) => String.compare (n1, n2) + end