diff src/unpoly.sml @ 1276:5b5c0b552f59

Another run of Specialize, using ReduceLocal on datatype parameters
author Adam Chlipala <adamc@hcoop.net>
date Sat, 05 Jun 2010 09:42:37 -0400
parents 338be96f8533
children
line wrap: on
line diff
--- a/src/unpoly.sml	Thu Jun 03 14:44:08 2010 -0400
+++ b/src/unpoly.sml	Sat Jun 05 09:42:37 2010 -0400
@@ -116,97 +116,102 @@
                     case IM.find (#funcs st, n) of
                         NONE => (e, st)
                       | SOME {kinds = ks, defs = vis, replacements} =>
-                        case M.find (replacements, cargs) of
-                            SOME n => (ENamed n, st)
-                          | NONE =>
-                            let
-                                val old_vis = vis
-                                val (vis, (thisName, nextName)) =
-                                    ListUtil.foldlMap
-                                        (fn ((x, n', t, e, s), (thisName, nextName)) =>
-                                            ((x, nextName, n', t, e, s),
-                                             (if n' = n then nextName else thisName,
-                                              nextName + 1)))
-                                        (0, #nextName st) vis
+                        let
+                            val cargs = map ReduceLocal.reduceCon cargs
+                        in
+                            case M.find (replacements, cargs) of
+                                SOME n => (ENamed n, st)
+                              | NONE =>
+                                let
+                                    val old_vis = vis
+                                    val (vis, (thisName, nextName)) =
+                                        ListUtil.foldlMap
+                                            (fn ((x, n', t, e, s), (thisName, nextName)) =>
+                                                ((x, nextName, n', t, e, s),
+                                                 (if n' = n then nextName else thisName,
+                                                  nextName + 1)))
+                                            (0, #nextName st) vis
 
-                                fun specialize (x, n, n_old, t, e, s) =
-                                    let
-                                        fun trim (t, e, cargs) =
-                                            case (t, e, cargs) of
-                                                ((TCFun (_, _, t), _),
-                                                 (ECAbs (_, _, e), _),
-                                                 carg :: cargs) =>
-                                                let
-                                                    val t = subConInCon (length cargs, carg) t
-                                                    val e = subConInExp (length cargs, carg) e
-                                                in
-                                                    trim (t, e, cargs)
-                                                end
-                                              | (_, _, []) => SOME (t, e)
-                                              | _ => NONE
-                                    in
-                                        (*Print.prefaces "specialize"
-                                                       [("n", Print.PD.string (Int.toString n)),
-                                                        ("nold", Print.PD.string (Int.toString n_old)),
-                                                        ("t", CorePrint.p_con CoreEnv.empty t),
-                                                        ("e", CorePrint.p_exp CoreEnv.empty e),
-                                                        ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
-                                        Option.map (fn (t, e) => (x, n, n_old, t, e, s))
-                                                   (trim (t, e, cargs))
-                                    end
+                                    fun specialize (x, n, n_old, t, e, s) =
+                                        let
+                                            fun trim (t, e, cargs) =
+                                                case (t, e, cargs) of
+                                                    ((TCFun (_, _, t), _),
+                                                     (ECAbs (_, _, e), _),
+                                                     carg :: cargs) =>
+                                                    let
+                                                        val t = subConInCon (length cargs, carg) t
+                                                        val e = subConInExp (length cargs, carg) e
+                                                    in
+                                                        trim (t, e, cargs)
+                                                    end
+                                                  | (_, _, []) => SOME (t, e)
+                                                  | _ => NONE
+                                        in
+                                            (*Print.prefaces "specialize"
+                                                             [("n", Print.PD.string (Int.toString n)),
+                                                              ("nold", Print.PD.string (Int.toString n_old)),
+                                                              ("t", CorePrint.p_con CoreEnv.empty t),
+                                                              ("e", CorePrint.p_exp CoreEnv.empty e),
+                                                              ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
+                                            Option.map (fn (t, e) => (x, n, n_old, t, e, s))
+                                                       (trim (t, e, cargs))
+                                        end
 
-                                val vis = List.map specialize vis
-                            in
-                                if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
-                                    (e, st)
-                                else
-                                    let
-                                        val vis = List.mapPartial (fn x => x) vis
+                                    val vis = List.map specialize vis
+                                in
+                                    if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
+                                        (e, st)
+                                    else
+                                        let
+                                            val vis = List.mapPartial (fn x => x) vis
 
-                                        val vis = map (fn (x, n, n_old, t, e, s) =>
-                                                          (x ^ "_unpoly", n, n_old, t, e, s)) vis
-                                        val vis' = map (fn (x, n, _, t, e, s) =>
-                                                           (x, n, t, e, s)) vis
+                                            val vis = map (fn (x, n, n_old, t, e, s) =>
+                                                              (x ^ "_unpoly", n, n_old, t, e, s)) vis
+                                            val vis' = map (fn (x, n, _, t, e, s) =>
+                                                               (x, n, t, e, s)) vis
 
-                                        val funcs = foldl (fn ((_, n, n_old, _, _, _), funcs) =>
-                                                              let
-                                                                  val replacements = case IM.find (funcs, n_old) of
-                                                                                         NONE => M.empty
-                                                                                       | SOME {replacements = r, ...} => r
-                                                              in
-                                                                  IM.insert (funcs, n_old,
-                                                                             {kinds = ks,
-                                                                              defs = old_vis,
-                                                                              replacements = M.insert (replacements,
-                                                                                                       cargs,
-                                                                                                       n)})
-                                                              end) (#funcs st) vis
+                                            val funcs = foldl (fn ((_, n, n_old, _, _, _), funcs) =>
+                                                                  let
+                                                                      val replacements = case IM.find (funcs, n_old) of
+                                                                                             NONE => M.empty
+                                                                                           | SOME {replacements = r,
+                                                                                                   ...} => r
+                                                                  in
+                                                                      IM.insert (funcs, n_old,
+                                                                                 {kinds = ks,
+                                                                                  defs = old_vis,
+                                                                                  replacements = M.insert (replacements,
+                                                                                                           cargs,
+                                                                                                           n)})
+                                                                  end) (#funcs st) vis
 
-                                        val ks' = List.drop (ks, length cargs)
+                                            val ks' = List.drop (ks, length cargs)
 
-                                        val st = {funcs = foldl (fn (vi, funcs) =>
-                                                                    IM.insert (funcs, #2 vi,
-                                                                               {kinds = ks',
-                                                                                defs = vis',
-                                                                                replacements = M.empty}))
-                                                                funcs vis',
-                                                  decls = #decls st,
-                                                  nextName = nextName}
+                                            val st = {funcs = foldl (fn (vi, funcs) =>
+                                                                        IM.insert (funcs, #2 vi,
+                                                                                   {kinds = ks',
+                                                                                    defs = vis',
+                                                                                    replacements = M.empty}))
+                                                                    funcs vis',
+                                                      decls = #decls st,
+                                                      nextName = nextName}
 
-                                        val (vis', st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
-                                                                               let
-                                                                                   val (e, st) = polyExp (e, st)
-                                                                               in
-                                                                                   ((x, n, t, e, s), st)
-                                                                               end)
-                                                                           st vis'
-                                    in
-                                        (ENamed thisName,
-                                         {funcs = #funcs st,
-                                          decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
-                                          nextName = #nextName st})
-                                    end
-                            end
+                                            val (vis', st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
+                                                                                   let
+                                                                                       val (e, st) = polyExp (e, st)
+                                                                                   in
+                                                                                       ((x, n, t, e, s), st)
+                                                                                   end)
+                                                                               st vis'
+                                        in
+                                            (ENamed thisName,
+                                             {funcs = #funcs st,
+                                              decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
+                                              nextName = #nextName st})
+                                        end
+                                end
+                        end
         end
       | _ => (e, st)