changeset 1276:5b5c0b552f59

Another run of Specialize, using ReduceLocal on datatype parameters
author Adam Chlipala <adamc@hcoop.net>
date Sat, 05 Jun 2010 09:42:37 -0400
parents 74150edf1134
children 1e6a4f9d3e4a
files src/compiler.sig src/compiler.sml src/reduce_local.sig src/reduce_local.sml src/sources src/specialize.sml src/unpoly.sml
diffstat 7 files changed, 150 insertions(+), 136 deletions(-) [+]
line wrap: on
line diff
--- a/src/compiler.sig	Thu Jun 03 14:44:08 2010 -0400
+++ b/src/compiler.sig	Sat Jun 05 09:42:37 2010 -0400
@@ -134,6 +134,7 @@
     val toShake4 : (string, Core.file) transform
     val toEspecialize2 : (string, Core.file) transform
     val toShake4' : (string, Core.file) transform
+    val toSpecialize2 : (string, Core.file) transform
     val toUnpoly2 : (string, Core.file) transform
     val toShake4'' : (string, Core.file) transform
     val toEspecialize3 : (string, Core.file) transform
--- a/src/compiler.sml	Thu Jun 03 14:44:08 2010 -0400
+++ b/src/compiler.sml	Sat Jun 05 09:42:37 2010 -0400
@@ -1015,7 +1015,8 @@
 val toEspecialize2 = transform especialize "especialize2" o toShake4
 val toShake4' = transform shake "shake4'" o toEspecialize2
 val toUnpoly2 = transform unpoly "unpoly2" o toShake4'
-val toShake4'' = transform shake "shake4'" o toUnpoly2
+val toSpecialize2 = transform specialize "specialize2" o toUnpoly2
+val toShake4'' = transform shake "shake4'" o toSpecialize2
 val toEspecialize3 = transform especialize "especialize3" o toShake4''
 
 val toReduce2 = transform reduce "reduce2" o toEspecialize3
--- a/src/reduce_local.sig	Thu Jun 03 14:44:08 2010 -0400
+++ b/src/reduce_local.sig	Sat Jun 05 09:42:37 2010 -0400
@@ -1,4 +1,4 @@
-(* Copyright (c) 2008, Adam Chlipala
+(* Copyright (c) 2008-2010, Adam Chlipala
  * All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
@@ -31,5 +31,6 @@
 
     val reduce : Core.file -> Core.file
     val reduceExp : Core.exp -> Core.exp
+    val reduceCon : Core.con -> Core.con
     
 end
--- a/src/reduce_local.sml	Thu Jun 03 14:44:08 2010 -0400
+++ b/src/reduce_local.sml	Sat Jun 05 09:42:37 2010 -0400
@@ -383,5 +383,6 @@
     end
 
 val reduceExp = exp []
+val reduceCon = con []
 
 end
--- a/src/sources	Thu Jun 03 14:44:08 2010 -0400
+++ b/src/sources	Sat Jun 05 09:42:37 2010 -0400
@@ -113,15 +113,15 @@
 shake.sig
 shake.sml
 
+reduce_local.sig
+reduce_local.sml
+
 unpoly.sig
 unpoly.sml
 
 specialize.sig
 specialize.sml
 
-reduce_local.sig
-reduce_local.sml
-
 core_untangle.sig
 core_untangle.sml
 
--- a/src/specialize.sml	Thu Jun 03 14:44:08 2010 -0400
+++ b/src/specialize.sml	Sat Jun 05 09:42:37 2010 -0400
@@ -1,4 +1,4 @@
-(* Copyright (c) 2008, Adam Chlipala
+(* Copyright (c) 2008-2010, Adam Chlipala
  * All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
@@ -73,58 +73,63 @@
                                       | _ => false}
 
 fun considerSpecialization (st : state, n, args, dt : datatyp) =
-    case CM.find (#specializations dt, args) of
-        SOME dt' => (#name dt', #constructors dt', st)
-      | NONE =>
-        let
-            (*val () = Print.prefaces "Args" [("args", Print.p_list (CorePrint.p_con CoreEnv.empty) args)]*)
+    let
+        val args = map ReduceLocal.reduceCon args
+    in
+        case CM.find (#specializations dt, args) of
+            SOME dt' => (#name dt', #constructors dt', st)
+          | NONE =>
+            let
+                (*val () = Print.prefaces "Args" [("n", Print.PD.string (Int.toString n)),
+                                                ("args", Print.p_list (CorePrint.p_con CoreEnv.empty) args)]*)
 
-            val n' = #count st
+                val n' = #count st
 
-            val nxs = length args - 1
-            fun sub t = ListUtil.foldli (fn (i, arg, t) =>
-                                            subConInCon (nxs - i, arg) t) t args
+                val nxs = length args - 1
+                fun sub t = ListUtil.foldli (fn (i, arg, t) =>
+                                                subConInCon (nxs - i, arg) t) t args
 
-            val (cons, (count, cmap)) =
-                ListUtil.foldlMap (fn ((x, n, to), (count, cmap)) =>
-                                      let
-                                          val to = Option.map sub to
-                                      in
-                                          ((x, count, to),
-                                           (count + 1,
-                                            IM.insert (cmap, n, count)))
-                                      end) (n' + 1, IM.empty) (#constructors dt)
+                val (cons, (count, cmap)) =
+                    ListUtil.foldlMap (fn ((x, n, to), (count, cmap)) =>
+                                          let
+                                              val to = Option.map sub to
+                                          in
+                                              ((x, count, to),
+                                               (count + 1,
+                                                IM.insert (cmap, n, count)))
+                                          end) (n' + 1, IM.empty) (#constructors dt)
 
-            val st = {count = count,
-                      datatypes = IM.insert (#datatypes st, n,
-                                             {name = #name dt,
-                                              params = #params dt,
-                                              constructors = #constructors dt,
-                                              specializations = CM.insert (#specializations dt,
-                                                                           args,
-                                                                           {name = n',
-                                                                            constructors = cmap})}),
-                      constructors = #constructors st,
-                      decls = #decls st}
+                val st = {count = count,
+                          datatypes = IM.insert (#datatypes st, n,
+                                                 {name = #name dt,
+                                                  params = #params dt,
+                                                  constructors = #constructors dt,
+                                                  specializations = CM.insert (#specializations dt,
+                                                                               args,
+                                                                               {name = n',
+                                                                                constructors = cmap})}),
+                          constructors = #constructors st,
+                          decls = #decls st}
 
-            val (cons, st) = ListUtil.foldlMap (fn ((x, n, NONE), st) => ((x, n, NONE), st)
-                                                 | ((x, n, SOME t), st) =>
-                                                   let
-                                                       val (t, st) = specCon st t
-                                                   in
-                                                       ((x, n, SOME t), st)
-                                                   end) st cons
+                val (cons, st) = ListUtil.foldlMap (fn ((x, n, NONE), st) => ((x, n, NONE), st)
+                                                     | ((x, n, SOME t), st) =>
+                                                       let
+                                                           val (t, st) = specCon st t
+                                                       in
+                                                           ((x, n, SOME t), st)
+                                                       end) st cons
 
-            val dt = (#name dt ^ "_s",
-                      n',
-                      [],
-                      cons)
-        in
-            (n', cmap, {count = #count st,
-                        datatypes = #datatypes st,
-                        constructors = #constructors st,
-                        decls = dt :: #decls st})
-        end
+                val dt = (#name dt ^ "_s",
+                          n',
+                          [],
+                          cons)
+            in
+                (n', cmap, {count = #count st,
+                            datatypes = #datatypes st,
+                            constructors = #constructors st,
+                            decls = dt :: #decls st})
+            end
+    end
 
 and con (c, st : state) =
     let
--- a/src/unpoly.sml	Thu Jun 03 14:44:08 2010 -0400
+++ b/src/unpoly.sml	Sat Jun 05 09:42:37 2010 -0400
@@ -116,97 +116,102 @@
                     case IM.find (#funcs st, n) of
                         NONE => (e, st)
                       | SOME {kinds = ks, defs = vis, replacements} =>
-                        case M.find (replacements, cargs) of
-                            SOME n => (ENamed n, st)
-                          | NONE =>
-                            let
-                                val old_vis = vis
-                                val (vis, (thisName, nextName)) =
-                                    ListUtil.foldlMap
-                                        (fn ((x, n', t, e, s), (thisName, nextName)) =>
-                                            ((x, nextName, n', t, e, s),
-                                             (if n' = n then nextName else thisName,
-                                              nextName + 1)))
-                                        (0, #nextName st) vis
+                        let
+                            val cargs = map ReduceLocal.reduceCon cargs
+                        in
+                            case M.find (replacements, cargs) of
+                                SOME n => (ENamed n, st)
+                              | NONE =>
+                                let
+                                    val old_vis = vis
+                                    val (vis, (thisName, nextName)) =
+                                        ListUtil.foldlMap
+                                            (fn ((x, n', t, e, s), (thisName, nextName)) =>
+                                                ((x, nextName, n', t, e, s),
+                                                 (if n' = n then nextName else thisName,
+                                                  nextName + 1)))
+                                            (0, #nextName st) vis
 
-                                fun specialize (x, n, n_old, t, e, s) =
-                                    let
-                                        fun trim (t, e, cargs) =
-                                            case (t, e, cargs) of
-                                                ((TCFun (_, _, t), _),
-                                                 (ECAbs (_, _, e), _),
-                                                 carg :: cargs) =>
-                                                let
-                                                    val t = subConInCon (length cargs, carg) t
-                                                    val e = subConInExp (length cargs, carg) e
-                                                in
-                                                    trim (t, e, cargs)
-                                                end
-                                              | (_, _, []) => SOME (t, e)
-                                              | _ => NONE
-                                    in
-                                        (*Print.prefaces "specialize"
-                                                       [("n", Print.PD.string (Int.toString n)),
-                                                        ("nold", Print.PD.string (Int.toString n_old)),
-                                                        ("t", CorePrint.p_con CoreEnv.empty t),
-                                                        ("e", CorePrint.p_exp CoreEnv.empty e),
-                                                        ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
-                                        Option.map (fn (t, e) => (x, n, n_old, t, e, s))
-                                                   (trim (t, e, cargs))
-                                    end
+                                    fun specialize (x, n, n_old, t, e, s) =
+                                        let
+                                            fun trim (t, e, cargs) =
+                                                case (t, e, cargs) of
+                                                    ((TCFun (_, _, t), _),
+                                                     (ECAbs (_, _, e), _),
+                                                     carg :: cargs) =>
+                                                    let
+                                                        val t = subConInCon (length cargs, carg) t
+                                                        val e = subConInExp (length cargs, carg) e
+                                                    in
+                                                        trim (t, e, cargs)
+                                                    end
+                                                  | (_, _, []) => SOME (t, e)
+                                                  | _ => NONE
+                                        in
+                                            (*Print.prefaces "specialize"
+                                                             [("n", Print.PD.string (Int.toString n)),
+                                                              ("nold", Print.PD.string (Int.toString n_old)),
+                                                              ("t", CorePrint.p_con CoreEnv.empty t),
+                                                              ("e", CorePrint.p_exp CoreEnv.empty e),
+                                                              ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
+                                            Option.map (fn (t, e) => (x, n, n_old, t, e, s))
+                                                       (trim (t, e, cargs))
+                                        end
 
-                                val vis = List.map specialize vis
-                            in
-                                if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
-                                    (e, st)
-                                else
-                                    let
-                                        val vis = List.mapPartial (fn x => x) vis
+                                    val vis = List.map specialize vis
+                                in
+                                    if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
+                                        (e, st)
+                                    else
+                                        let
+                                            val vis = List.mapPartial (fn x => x) vis
 
-                                        val vis = map (fn (x, n, n_old, t, e, s) =>
-                                                          (x ^ "_unpoly", n, n_old, t, e, s)) vis
-                                        val vis' = map (fn (x, n, _, t, e, s) =>
-                                                           (x, n, t, e, s)) vis
+                                            val vis = map (fn (x, n, n_old, t, e, s) =>
+                                                              (x ^ "_unpoly", n, n_old, t, e, s)) vis
+                                            val vis' = map (fn (x, n, _, t, e, s) =>
+                                                               (x, n, t, e, s)) vis
 
-                                        val funcs = foldl (fn ((_, n, n_old, _, _, _), funcs) =>
-                                                              let
-                                                                  val replacements = case IM.find (funcs, n_old) of
-                                                                                         NONE => M.empty
-                                                                                       | SOME {replacements = r, ...} => r
-                                                              in
-                                                                  IM.insert (funcs, n_old,
-                                                                             {kinds = ks,
-                                                                              defs = old_vis,
-                                                                              replacements = M.insert (replacements,
-                                                                                                       cargs,
-                                                                                                       n)})
-                                                              end) (#funcs st) vis
+                                            val funcs = foldl (fn ((_, n, n_old, _, _, _), funcs) =>
+                                                                  let
+                                                                      val replacements = case IM.find (funcs, n_old) of
+                                                                                             NONE => M.empty
+                                                                                           | SOME {replacements = r,
+                                                                                                   ...} => r
+                                                                  in
+                                                                      IM.insert (funcs, n_old,
+                                                                                 {kinds = ks,
+                                                                                  defs = old_vis,
+                                                                                  replacements = M.insert (replacements,
+                                                                                                           cargs,
+                                                                                                           n)})
+                                                                  end) (#funcs st) vis
 
-                                        val ks' = List.drop (ks, length cargs)
+                                            val ks' = List.drop (ks, length cargs)
 
-                                        val st = {funcs = foldl (fn (vi, funcs) =>
-                                                                    IM.insert (funcs, #2 vi,
-                                                                               {kinds = ks',
-                                                                                defs = vis',
-                                                                                replacements = M.empty}))
-                                                                funcs vis',
-                                                  decls = #decls st,
-                                                  nextName = nextName}
+                                            val st = {funcs = foldl (fn (vi, funcs) =>
+                                                                        IM.insert (funcs, #2 vi,
+                                                                                   {kinds = ks',
+                                                                                    defs = vis',
+                                                                                    replacements = M.empty}))
+                                                                    funcs vis',
+                                                      decls = #decls st,
+                                                      nextName = nextName}
 
-                                        val (vis', st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
-                                                                               let
-                                                                                   val (e, st) = polyExp (e, st)
-                                                                               in
-                                                                                   ((x, n, t, e, s), st)
-                                                                               end)
-                                                                           st vis'
-                                    in
-                                        (ENamed thisName,
-                                         {funcs = #funcs st,
-                                          decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
-                                          nextName = #nextName st})
-                                    end
-                            end
+                                            val (vis', st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
+                                                                                   let
+                                                                                       val (e, st) = polyExp (e, st)
+                                                                                   in
+                                                                                       ((x, n, t, e, s), st)
+                                                                                   end)
+                                                                               st vis'
+                                        in
+                                            (ENamed thisName,
+                                             {funcs = #funcs st,
+                                              decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
+                                              nextName = #nextName st})
+                                        end
+                                end
+                        end
         end
       | _ => (e, st)