changeset 479:ffa18975e661

Broaden set of possible especializations
author Adam Chlipala <adamc@hcoop.net>
date Sat, 08 Nov 2008 14:42:52 -0500
parents 6ee1c761818f
children 40c737913075
files src/core_util.sig src/core_util.sml src/especialize.sml src/order.sig src/order.sml src/prim.sig src/prim.sml
diffstat 7 files changed, 237 insertions(+), 26 deletions(-) [+]
line wrap: on
line diff
--- a/src/core_util.sig	Sat Nov 08 13:15:00 2008 -0500
+++ b/src/core_util.sig	Sat Nov 08 14:42:52 2008 -0500
@@ -73,6 +73,8 @@
 end
 
 structure Exp : sig
+    val compare : Core.exp * Core.exp -> order
+
     datatype binder =
              RelC of string * Core.kind
            | NamedC of string * int * Core.kind * Core.con option
@@ -108,6 +110,12 @@
                   con : Core.con' -> bool,
                   exp : Core.exp' -> bool} -> Core.exp -> bool
 
+    val existsB : {kind : Core.kind' -> bool,
+                   con : 'context * Core.con' -> bool,
+                   exp : 'context * Core.exp' -> bool,
+                   bind : 'context * binder -> 'context}
+                  -> 'context -> Core.exp -> bool
+
     val foldMap : {kind : Core.kind' * 'state -> Core.kind' * 'state,
                    con : Core.con' * 'state -> Core.con' * 'state,
                    exp : Core.exp' * 'state -> Core.exp' * 'state}
--- a/src/core_util.sml	Sat Nov 08 13:15:00 2008 -0500
+++ b/src/core_util.sml	Sat Nov 08 14:42:52 2008 -0500
@@ -331,6 +331,149 @@
 
 structure Exp = struct
 
+open Order
+
+fun pcCompare (pc1, pc2) =
+    case (pc1, pc2) of
+        (PConVar n1, PConVar n2) => Int.compare (n1, n2)
+      | (PConVar _, _) => LESS
+      | (_, PConVar _) => GREATER
+
+      | (PConFfi {mod = m1, datatyp = d1, con = c1, ...},
+         PConFfi {mod = m2, datatyp = d2, con = c2, ...}) =>
+        join (String.compare (m1, m2),
+              fn () => join (String.compare (d1, d2),
+                             fn () => String.compare (c1, c2)))
+
+fun pCompare ((p1, _), (p2, _)) =
+    case (p1, p2) of
+        (PWild, PWild) => EQUAL
+      | (PWild, _) => LESS
+      | (_, PWild) => GREATER
+
+      | (PVar _, PVar _) => EQUAL
+      | (PVar _, _) => LESS
+      | (_, PVar _) => GREATER
+
+      | (PPrim p1, PPrim p2) => Prim.compare (p1, p2)
+      | (PPrim _, _) => LESS
+      | (_, PPrim _) => GREATER
+
+      | (PCon (_, pc1, _, po1), PCon (_, pc2, _, po2)) =>
+        join (pcCompare (pc1, pc2),
+              fn () => joinO pCompare (po1, po2))
+      | (PCon _, _) => LESS
+      | (_, PCon _) => GREATER
+
+      | (PRecord xps1, PRecord xps2) =>
+        joinL (fn ((x1, p1, _), (x2, p2, _)) =>
+                  join (String.compare (x1, x2),
+                        fn () => pCompare (p1, p2))) (xps1, xps2)
+
+fun compare ((e1, _), (e2, _)) =
+    case (e1, e2) of
+        (EPrim p1, EPrim p2) => Prim.compare (p1, p2)
+      | (EPrim _, _) => LESS
+      | (_, EPrim _) => GREATER
+
+      | (ERel n1, ERel n2) => Int.compare (n1, n2)
+      | (ERel _, _) => LESS
+      | (_, ERel _) => GREATER
+
+      | (ENamed n1, ENamed n2) => Int.compare (n1, n2)
+      | (ENamed _, _) => LESS
+      | (_, ENamed _) => GREATER
+
+      | (ECon (_, pc1, _, eo1), ECon (_, pc2, _, eo2)) =>
+        join (pcCompare (pc1, pc2),
+              fn () => joinO compare (eo1, eo2))
+      | (ECon _, _) => LESS
+      | (_, ECon _) => GREATER
+
+      | (EFfi (f1, x1), EFfi (f2, x2)) =>
+        join (String.compare (f1, f2),
+              fn () => String.compare (x1, x2))
+      | (EFfi _, _) => LESS
+      | (_, EFfi _) => GREATER
+
+      | (EFfiApp (f1, x1, es1), EFfiApp (f2, x2, es2)) =>
+        join (String.compare (f1, f2),
+           fn () => join (String.compare (x1, x2),
+                       fn () => joinL compare (es1, es2)))
+      | (EFfiApp _, _) => LESS
+      | (_, EFfiApp _) => GREATER
+
+      | (EApp (f1, x1), EApp (f2, x2)) =>
+        join (compare (f1, f2),
+              fn () => compare (x1, x2))
+      | (EApp _, _) => LESS
+      | (_, EApp _) => GREATER
+
+      | (EAbs (_, _, _, e1), EAbs (_, _, _, e2)) => compare (e1, e2)
+      | (EAbs _, _) => LESS
+      | (_, EAbs _) => GREATER
+
+      | (ECApp (f1, x1), ECApp (f2, x2)) =>
+        join (compare (f1, f2),
+           fn () => Con.compare (x1, x2))
+      | (ECApp _, _) => LESS
+      | (_, ECApp _) => GREATER
+
+      | (ECAbs (_, _, e1), ECAbs (_, _, e2)) => compare (e1, e2)
+      | (ECAbs _, _) => LESS
+      | (_, ECAbs _) => GREATER
+
+      | (ERecord xes1, ERecord xes2) =>
+        joinL (fn ((x1, e1, _), (x2, e2, _)) =>
+                  join (Con.compare (x1, x2),
+                        fn () => compare (e1, e2))) (xes1, xes2)
+      | (ERecord _, _) => LESS
+      | (_, ERecord _) => GREATER
+
+      | (EField (e1, c1, _), EField (e2, c2, _)) =>
+        join (compare (e1, e2),
+              fn () => Con.compare (c1, c2))
+      | (EField _, _) => LESS
+      | (_, EField _) => GREATER
+
+      | (EConcat (x1, _, y1, _), EConcat (x2, _, y2, _)) =>
+        join (compare (x1, x2),
+              fn () => compare (y1, y2))
+      | (EConcat _, _) => LESS
+      | (_, EConcat _) => GREATER
+
+      | (ECut (e1, c1, _), ECut (e2, c2, _)) =>
+        join (compare (e1, e2),
+              fn () => Con.compare (c1, c2))
+      | (ECut _, _) => LESS
+      | (_, ECut _) => GREATER
+
+      | (EFold _, EFold _) => EQUAL
+      | (EFold _, _) => LESS
+      | (_, EFold _) => GREATER
+
+      | (ECase (e1, pes1, _), ECase (e2, pes2, _)) =>
+        join (compare (e1, e2),
+              fn () => joinL (fn ((p1, e1), (p2, e2)) =>
+                                 join (pCompare (p1, p2),
+                                       fn () => compare (e1, e2))) (pes1, pes2))
+      | (ECase _, _) => LESS
+      | (_, ECase _) => GREATER
+
+      | (EWrite e1, EWrite e2) => compare (e1, e2)
+      | (EWrite _, _) => LESS
+      | (_, EWrite _) => GREATER
+
+      | (EClosure (n1, es1), EClosure (n2, es2)) =>
+        join (Int.compare (n1, n2),
+              fn () => joinL compare (es1, es2))
+      | (EClosure _, _) => LESS
+      | (_, EClosure _) => GREATER
+
+      | (ELet (_, _, x1, e1), ELet (_, _, x2, e2)) =>
+        join (compare (x1, x2),
+              fn () => compare (e1, e2))
+
 datatype binder =
          RelC of string * kind
        | NamedC of string * int * kind * con option
@@ -585,6 +728,26 @@
         S.Return _ => true
       | S.Continue _ => false
 
+fun existsB {kind, con, exp, bind} ctx k =
+    case mapfoldB {kind = fn k => fn () =>
+                                     if kind k then
+                                         S.Return ()
+                                     else
+                                         S.Continue (k, ()),
+                   con = fn ctx => fn c => fn () =>
+                                              if con (ctx, c) then
+                                                  S.Return ()
+                                              else
+                                                  S.Continue (c, ()),
+                   exp = fn ctx => fn e => fn () =>
+                                              if exp (ctx, e) then
+                                                  S.Return ()
+                                              else
+                                                  S.Continue (e, ()),
+                   bind = bind} ctx k () of
+        S.Return _ => true
+      | S.Continue _ => false
+
 fun foldMap {kind, con, exp} s e =
     case mapfold {kind = fn k => fn s => S.Continue (kind (k, s)),
                   con = fn c => fn s => S.Continue (con (c, s)),
--- a/src/especialize.sml	Sat Nov 08 13:15:00 2008 -0500
+++ b/src/especialize.sml	Sat Nov 08 14:42:52 2008 -0500
@@ -32,39 +32,57 @@
 structure E = CoreEnv
 structure U = CoreUtil
 
-datatype skey =
-         Named of int
-       | App of skey * skey
+type skey = exp
 
 structure K = struct
-type ord_key = skey list
-fun compare' (k1, k2) =
-    case (k1, k2) of
-        (Named n1, Named n2) => Int.compare (n1, n2)
-      | (Named _, _) => LESS
-      | (_, Named _) => GREATER
-
-      | (App (x1, y1), App (x2, y2)) => Order.join (compare' (x1, x2), fn () => compare' (y1, y2))
-
-val compare = Order.joinL compare'
+type ord_key = exp list
+val compare = Order.joinL U.Exp.compare
 end
 
 structure KM = BinaryMapFn(K)
 structure IM = IntBinaryMap
 
-fun skeyIn (e, _) =
+val sizeOf = U.Exp.fold {kind = fn (_, n) => n,
+                         con = fn (_, n) => n,
+                         exp = fn (_, n) => n + 1}
+                        0
+
+val isOpen = U.Exp.existsB {kind = fn _ => false,
+                            con = fn ((nc, _), c) =>
+                                    case c of
+                                        CRel n => n >= nc
+                                      | _ => false,
+                            exp = fn ((_, ne), e) =>
+                                     case e of
+                                         ERel n => n >= ne
+                                       | _ => false,
+                            bind = fn ((nc, ne), b) =>
+                                      case b of
+                                          U.Exp.RelC _ => (nc + 1, ne)
+                                        | U.Exp.RelE _ => (nc, ne + 1)
+                                        | _ => (nc, ne)}
+             (0, 0)
+
+fun baseBad (e, _) =
     case e of
-        ENamed n => SOME (Named n)
-      | EApp (e1, e2) =>
-        (case (skeyIn e1, skeyIn e2) of
-             (SOME k1, SOME k2) => SOME (App (k1, k2))
-           | _ => NONE)
-      | _ => NONE
+        EAbs (_, _, _, e) => sizeOf e > 20
+      | ENamed _ => false
+      | _ => true
 
-fun skeyOut (k, loc) =
-    case k of
-        Named n => (ENamed n, loc)
-      | App (k1, k2) => (EApp (skeyOut (k1, loc), skeyOut (k2, loc)), loc)
+fun isBad e =
+    case e of
+        (ERecord xes, _) =>
+        length xes > 10
+        orelse List.exists (fn (_, e, _) => baseBad e) xes
+      | _ => baseBad e
+
+fun skeyIn e =
+    if isBad e orelse isOpen e then
+        NONE
+    else
+        SOME e
+
+fun skeyOut e = e
 
 type func = {
      name : string,
@@ -126,7 +144,7 @@
                                 (_, _, []) => SOME (body, typ)
                               | (EAbs (_, _, _, body'), TFun (_, typ'), x :: xs) =>
                                 let
-                                    val body'' = E.subExpInExp (0, skeyOut (x, #2 body)) body'
+                                    val body'' = E.subExpInExp (0, skeyOut x) body'
                                 in
                                     (*Print.prefaces "espec" [("body'", CorePrint.p_exp CoreEnv.empty body'),
                                                             ("body''", CorePrint.p_exp CoreEnv.empty body'')];*)
--- a/src/order.sig	Sat Nov 08 13:15:00 2008 -0500
+++ b/src/order.sig	Sat Nov 08 14:42:52 2008 -0500
@@ -31,5 +31,6 @@
     
     val join : order * (unit -> order) -> order
     val joinL : ('a * 'b -> order) -> 'a list * 'b list -> order
-                                                       
+    val joinO : ('a * 'b -> order) -> 'a option * 'b option -> order
+
 end
--- a/src/order.sml	Sat Nov 08 13:15:00 2008 -0500
+++ b/src/order.sml	Sat Nov 08 14:42:52 2008 -0500
@@ -42,4 +42,12 @@
         join (f (h1, h2), fn () => joinL f (t1, t2))
       | (_ :: _, nil) => GREATER
 
+fun joinO f (v1, v2) =
+    case (v1, v2) of
+        (NONE, NONE) => EQUAL
+      | (NONE, _) => LESS
+      | (_, NONE) => GREATER
+
+      | (SOME v1, SOME v2) => f (v1, v2)
+
 end
--- a/src/prim.sig	Sat Nov 08 13:15:00 2008 -0500
+++ b/src/prim.sig	Sat Nov 08 14:42:52 2008 -0500
@@ -36,5 +36,6 @@
     val p_t_GCC : t Print.printer
 
     val equal : t * t -> bool
+    val compare : t * t -> order
 
 end
--- a/src/prim.sml	Sat Nov 08 13:15:00 2008 -0500
+++ b/src/prim.sml	Sat Nov 08 14:42:52 2008 -0500
@@ -67,4 +67,16 @@
 
       | _ => false
 
+fun compare (p1, p2) =
+    case (p1, p2) of
+        (Int n1, Int n2) => Int64.compare (n1, n2)
+      | (Int _, _) => LESS
+      | (_, Int _) => GREATER
+
+      | (Float n1, Float n2) => Real64.compare (n1, n2)
+      | (Float _, _) => LESS
+      | (_, Float _) => GREATER
+
+      | (String n1, String n2) => String.compare (n1, n2)
+
 end