diff src/especialize.sml @ 1080:a4979e31e4bf

Another try at reasonable Especialize, this time with a custom traversal
author Adam Chlipala <adamc@hcoop.net>
date Sun, 20 Dec 2009 15:17:43 -0500
parents d069b193ed6b
children 2eb585274501
line wrap: on
line diff
--- a/src/especialize.sml	Tue Dec 15 19:26:52 2009 -0500
+++ b/src/especialize.sml	Sun Dec 20 15:17:43 2009 -0500
@@ -1,4 +1,4 @@
-(* Copyright (c) 2008, Adam Chlipala
+(* Copyright (c) 2008-2009, Adam Chlipala
  * All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
@@ -62,6 +62,7 @@
 val isPoly = U.Decl.exists {kind = fn _ => false,
                             con = fn _ => false,
                             exp = fn ECAbs _ => true
+                                   | EKAbs _ => true
                                    | _ => false,
                             decl = fn _ => false}
 
@@ -108,7 +109,7 @@
      maxName : int,
      funcs : func IM.map,
      decls : (string * int * con * exp * string) list,
-     specialized : bool IM.map
+     specialized : IS.set
 }
 
 fun default (_, x, st) = (x, st)
@@ -120,36 +121,162 @@
 
 val mayNotSpec = ref SS.empty
 
-fun specialize' specialized file =
+fun specialize' (funcs, specialized) file =
     let
         fun bind (env, b) =
             case b of
                 U.Decl.RelE xt => xt :: env
               | _ => env
 
-        fun exp (env, e, st : state) =
+        fun exp (env, e as (_, loc), st : state) =
             let
                 (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty
                                                                      (e, ErrorMsg.dummySpan))]*)
 
-                fun getApp e =
+                fun getApp (e, _) =
                     case e of
                         ENamed f => SOME (f, [])
                       | EApp (e1, e2) =>
-                        (case getApp (#1 e1) of
+                        (case getApp e1 of
                              NONE => NONE
                            | SOME (f, xs) => SOME (f, xs @ [e2]))
                       | _ => NONE
+
+                val getApp = fn e => case getApp e of
+                                         v as SOME (_, _ :: _) => v
+                                       | _ => NONE
+
+                fun default () =
+                    case #1 e of
+                        EPrim _ => (e, st)
+                      | ERel _ => (e, st)
+                      | ENamed _ => (e, st)
+                      | ECon (_, _, _, NONE) => (e, st)
+                      | ECon (dk, pc, cs, SOME e) =>
+                        let
+                            val (e, st) = exp (env, e, st)
+                        in
+                            ((ECon (dk, pc, cs, SOME e), loc), st)
+                        end
+                      | EFfi _ => (e, st)
+                      | EFfiApp (m, x, es) =>
+                        let
+                            val (es, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st es
+                        in
+                            ((EFfiApp (m, x, es), loc), st)
+                        end
+                      | EApp (e1, e2) =>
+                        let
+                            val (e1, st) = exp (env, e1, st)
+                            val (e2, st) = exp (env, e2, st)
+                        in
+                            ((EApp (e1, e2), loc), st)
+                        end
+                      | EAbs (x, d, r, e) =>
+                        let
+                            val (e, st) = exp ((x, d) :: env, e, st)
+                        in
+                            ((EAbs (x, d, r, e), loc), st)
+                        end
+                      | ECApp (e, c) =>
+                        let
+                            val (e, st) = exp (env, e, st)
+                        in
+                            ((ECApp (e, c), loc), st)
+                        end
+                      | ECAbs _ => raise Fail "Especialize: Impossible ECAbs"
+                      | EKAbs _ => raise Fail "Especialize: Impossible EKAbs"
+                      | EKApp (e, k) =>
+                        let
+                            val (e, st) = exp (env, e, st)
+                        in
+                            ((EKApp (e, k), loc), st)
+                        end
+                      | ERecord fs =>
+                        let
+                            val (fs, st) = ListUtil.foldlMap (fn ((c1, e, c2), st) =>
+                                                                 let
+                                                                     val (e, st) = exp (env, e, st)
+                                                                 in
+                                                                     ((c1, e, c2), st)
+                                                                 end) st fs
+                        in
+                            ((ERecord fs, loc), st)
+                        end
+                      | EField (e, c, cs) =>
+                        let
+                            val (e, st) = exp (env, e, st)
+                        in
+                            ((EField (e, c, cs), loc), st)
+                        end
+                      | EConcat (e1, c1, e2, c2) =>
+                        let
+                            val (e1, st) = exp (env, e1, st)
+                            val (e2, st) = exp (env, e2, st)
+                        in
+                            ((EConcat (e1, c1, e2, c2), loc), st)
+                        end
+                      | ECut (e, c, cs) =>
+                        let
+                            val (e, st) = exp (env, e, st)
+                        in
+                            ((ECut (e, c, cs), loc), st)
+                        end
+                      | ECutMulti (e, c, cs) =>
+                        let
+                            val (e, st) = exp (env, e, st)
+                        in
+                            ((ECutMulti (e, c, cs), loc), st)
+                        end
+
+                      | ECase (e, pes, cs) =>
+                        let
+                            val (e, st) = exp (env, e, st)
+                            val (pes, st) = ListUtil.foldlMap (fn ((p, e), st) =>
+                                                                  let
+                                                                      val (e, st) = exp (E.patBindsL p @ env, e, st)
+                                                                  in
+                                                                      ((p, e), st)
+                                                                  end) st pes
+                        in
+                            ((ECase (e, pes, cs), loc), st)
+                        end
+
+                      | EWrite e =>
+                        let
+                            val (e, st) = exp (env, e, st)
+                        in
+                            ((EWrite e, loc), st)
+                        end
+                      | EClosure (n, es) =>
+                        let
+                            val (es, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st es
+                        in
+                            ((EClosure (n, es), loc), st)
+                        end
+                      | ELet (x, t, e1, e2) =>
+                        let
+                            val (e1, st) = exp (env, e1, st)
+                            val (e2, st) = exp ((x, t) :: env, e2, st)
+                        in
+                            ((ELet (x, t, e1, e2), loc), st)
+                        end
+                      | EServerCall (n, es, t) =>
+                        let
+                            val (es, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st es
+                        in
+                            ((EServerCall (n, es, t), loc), st)
+                        end
             in
                 case getApp e of
-                    NONE => ((*Print.prefaces "No" [("e", CorePrint.p_exp CoreEnv.empty
-                                                                        (e, ErrorMsg.dummySpan))];*)
-                             (e, st))
+                    NONE => default ()
                   | SOME (f, xs) =>
                     case IM.find (#funcs st, f) of
-                        NONE => (e, st)
+                        NONE => default ()
                       | SOME {name, args, body, typ, tag} =>
                         let
+                            val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs
+
                             (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty
                                                                                       (e, ErrorMsg.dummySpan))]*)
 
@@ -166,7 +293,7 @@
                                                                       | _ => false}
                             val loc = ErrorMsg.dummySpan
 
-                            fun findSplit av (xs, typ, fxs, fvs) =
+                            fun findSplit av (xs, typ, fxs, fvs, fin) =
                                 case (#1 typ, xs) of
                                     (TFun (dom, ran), e :: xs') =>
                                     let
@@ -180,25 +307,27 @@
                                             findSplit av (xs',
                                                           ran,
                                                           e :: fxs,
-                                                          IS.union (fvs, freeVars e))
+                                                          IS.union (fvs, freeVars e),
+                                                          fin orelse functionInside dom)
                                         else
-                                            (rev fxs, xs, fvs)
+                                            (rev fxs, xs, fvs, fin)
                                     end
-                                  | _ => (rev fxs, xs, fvs)
+                                  | _ => (rev fxs, xs, fvs, fin)
 
-                            val (fxs, xs, fvs) = findSplit true (xs, typ, [], IS.empty)
+                            val (fxs, xs, fvs, fin) = findSplit true (xs, typ, [], IS.empty, false)
 
                             val fxs' = map (squish (IS.listItems fvs)) fxs
                         in
                             (*Print.preface ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs');*)
-                            if List.all (fn (ERel _, _) => true
-                                          | _ => false) fxs'
+                            if not fin
+                               orelse List.all (fn (ERel _, _) => true
+                                                 | _ => false) fxs'
                                orelse (IS.numItems fvs >= length fxs
                                        andalso IS.exists (fn n => functionInside (#2 (List.nth (env, n)))) fvs) then
-                                (e, st)
+                                default ()
                             else
                                 case (KM.find (args, fxs'),
-                                      SS.member (!mayNotSpec, name) orelse IM.find (#specialized st, f) = SOME true) of
+                                      SS.member (!mayNotSpec, name) orelse IS.member (#specialized st, f)) of
                                     (SOME f', _) =>
                                     let
                                         val e = (ENamed f', loc)
@@ -209,12 +338,12 @@
                                     in
                                         (*Print.prefaces "Brand new (reuse)"
                                                        [("e'", CorePrint.p_exp CoreEnv.empty e)];*)
-                                        (#1 e, st)
+                                        (e, st)
                                     end
                                   | (_, true) => ((*Print.prefaces ("No(" ^ name ^ ")")
                                                                  [("fxs'",
                                                                    Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')];*)
-                                                  (e, st))
+                                                  default ())
                                   | (NONE, false) =>
                                     let
                                         (*val () = Print.prefaces "New one"
@@ -240,7 +369,7 @@
                                               | _ => NONE
                                     in
                                         case subBody (body, typ, fxs') of
-                                            NONE => (e, st)
+                                            NONE => default ()
                                           | SOME (body', typ') =>
                                             let
                                                 val f' = #maxName st
@@ -251,16 +380,11 @@
                                                                                       typ = typ,
                                                                                       tag = tag})
 
-                                                val specialized = IM.insert (#specialized st, f', false)
-                                                val specialized = case IM.find (specialized, f) of
-                                                                      NONE => specialized
-                                                                    | SOME _ => IM.insert (specialized, f, true)
-
                                                 val st = {
                                                     maxName = f' + 1,
                                                     funcs = funcs,
                                                     decls = #decls st,
-                                                    specialized = specialized
+                                                    specialized = IS.add (#specialized st, f')
                                                 }
 
                                                 (*val () = Print.prefaces "specExp"
@@ -280,9 +404,9 @@
                                                                                  end)
                                                                              (body', typ') fvs
                                                 val mns = !mayNotSpec
-                                                val () = mayNotSpec := SS.add (mns, name)
-                                                (*val () = Print.preface ("body'", CorePrint.p_exp CoreEnv.empty body')*)
-                                                val (body', st) = specExp env st body'
+                                                (*val () = mayNotSpec := SS.add (mns, name)*)
+                                                (*val () = Print.preface ("PRE", CorePrint.p_exp CoreEnv.empty body')*)
+                                                val (body', st) = exp (env, body', st)
                                                 val () = mayNotSpec := mns
 
                                                 val e' = (ENamed f', loc)
@@ -292,10 +416,10 @@
                                                                e' xs
                                                 (*val () = Print.prefaces "Brand new"
                                                                         [("e'", CorePrint.p_exp CoreEnv.empty e'),
-                                                                         ("e", CorePrint.p_exp CoreEnv.empty (e, loc)),
+                                                                         ("e", CorePrint.p_exp CoreEnv.empty e),
                                                                          ("body'", CorePrint.p_exp CoreEnv.empty body')]*)
                                             in
-                                                (#1 e',
+                                                (e',
                                                  {maxName = #maxName st,
                                                   funcs = #funcs st,
                                                   decls = (name, f', typ', body', tag) :: #decls st,
@@ -305,10 +429,6 @@
                         end
             end
 
-        and specExp env = U.Exp.foldMapB {kind = default, con = default, exp = exp, bind = bind} env
-
-        val specDecl = U.Decl.foldMapB {kind = default, con = default, exp = exp, decl = default, bind = bind}
-
         fun doDecl (d, (st : state, changed)) =
             let
                 (*val befor = Time.now ()*)
@@ -333,17 +453,53 @@
 
                 (*val () = Print.prefaces "decl" [("d", CorePrint.p_decl CoreEnv.empty d)]*)
 
+                val () = mayNotSpec := SS.empty
+
                 val (d', st) =
                     if isPoly d then
                         (d, st)
                     else
-                        (mayNotSpec := SS.empty(*(case #1 d of
-                                            DValRec vis => foldl (fn ((x, _, _, _, _), mns) =>
-                                                                     SS.add (mns, x)) SS.empty vis
-                                          | DVal (x, _, _, _, _) => SS.singleton x
-                                          | _ => SS.empty)*);
-                         specDecl [] st d
-                         before mayNotSpec := SS.empty)
+                        case #1 d of
+                            DVal (x, n, t, e, s) =>
+                            let
+                                val (e, st) = exp ([], e, st)
+                            in
+                                ((DVal (x, n, t, e, s), #2 d), st)
+                            end
+                          | DValRec vis =>
+                            let
+                                val (vis, st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
+                                                                      let
+                                                                          val (e, st) = exp ([], e, st)
+                                                                      in
+                                                                          ((x, n, t, e, s), st)
+                                                                      end) st vis
+                            in
+                                ((DValRec vis, #2 d), st)
+                            end
+                          | DTable (s, n, t, s1, e1, t1, e2, t2) =>
+                            let
+                                val (e1, st) = exp ([], e1, st)
+                                val (e2, st) = exp ([], e2, st)
+                            in
+                                ((DTable (s, n, t, s1, e1, t2, e2, t2), #2 d), st)
+                            end
+                          | DView (x, n, s, e, t) =>
+                            let
+                                val (e, st) = exp ([], e, st)
+                            in
+                                ((DView (x, n, s, e, t), #2 d), st)
+                            end
+                          | DTask (e1, e2) =>
+                            let
+                                val (e1, st) = exp ([], e1, st)
+                                val (e2, st) = exp ([], e2, st)
+                            in
+                                ((DTask (e1, e2), #2 d), st)
+                            end
+                          | _ => (d, st)
+
+                val () = mayNotSpec := SS.empty
 
                 (*val () = print "/decl\n"*)
 
@@ -380,21 +536,20 @@
 
         val (ds, (st, changed)) = ListUtil.foldlMapConcat doDecl
                                                             ({maxName = U.File.maxName file + 1,
-                                                              funcs = IM.empty,
+                                                              funcs = funcs,
                                                               decls = [],
                                                               specialized = specialized},
                                                              false)
                                                             file
     in
-        (changed, ds, #specialized st)
+        (changed, ds, #funcs st, #specialized st)
     end
 
-fun specializeL specialized file =
+fun specializeL (funcs, specialized) file =
     let
         val file = ReduceLocal.reduce file
-        (*val () = Print.prefaces "Intermediate" [("file", CorePrint.p_file CoreEnv.empty file)]*)
         (*val file = ReduceLocal.reduce file*)
-        val (changed, file, specialized) = specialize' specialized file
+        val (changed, file, funcs, specialized) = specialize' (funcs, specialized) file
         (*val file = ReduceLocal.reduce file
         val file = CoreUntangle.untangle file
         val file = Shake.shake file*)
@@ -409,12 +564,13 @@
                 val file = Shake.shake file
             in
                 (*print "Again!\n";*)
-                specializeL specialized file
+                (*Print.prefaces "Again" [("file", CorePrint.p_file CoreEnv.empty file)];*)
+                specializeL (funcs, specialized) file
             end
         else
             file
     end
 
-val specialize = specializeL IM.empty
+val specialize = specializeL (IM.empty, IS.empty)
 
 end