diff src/unpoly.sml @ 794:dc3fc3f3b834

Improving/reordering Unpoly and Especialize; pathmaps
author Adam Chlipala <adamc@hcoop.net>
date Thu, 14 May 2009 08:13:54 -0400
parents 2d64457eedb1
children 6271f0e3c272
line wrap: on
line diff
--- a/src/unpoly.sml	Tue May 12 20:15:11 2009 -0400
+++ b/src/unpoly.sml	Thu May 14 08:13:54 2009 -0400
@@ -72,8 +72,19 @@
                             end
                           | _ => e}
 
+structure M = BinaryMapFn(struct
+                          type ord_key = con list
+                          val compare = Order.joinL U.Con.compare
+                          end)
+
+type func = {
+     kinds : kind list,
+     defs : (string * int * con * exp * string) list,
+     replacements : int M.map
+}
+
 type state = {
-     funcs : (kind list * (string * int * con * exp * string) list) IM.map,
+     funcs : func IM.map,
      decls : decl list,
      nextName : int
 }
@@ -86,8 +97,6 @@
     case e of
         ECApp _ =>
         let
-            (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty (e, ErrorMsg.dummySpan))]*)
-
             fun unravel (e, cargs) =
                 case e of
                     ECApp ((e, _), c) => unravel (e, c :: cargs)
@@ -102,72 +111,101 @@
                 else
                     case IM.find (#funcs st, n) of
                         NONE => (e, st)
-                      | SOME (ks, vis) =>
-                        let
-                            val (vis, nextName) = ListUtil.foldlMap
-                                                      (fn ((x, n, t, e, s), nextName) =>
-                                                          ((x, nextName, n, t, e, s), nextName + 1))
-                                                      (#nextName st) vis
+                      | SOME {kinds = ks, defs = vis, replacements} =>
+                        case M.find (replacements, cargs) of
+                            SOME n => (ENamed n, st)
+                          | NONE =>
+                            let
+                                val old_vis = vis
+                                val (vis, (thisName, nextName)) =
+                                    ListUtil.foldlMap
+                                        (fn ((x, n', t, e, s), (thisName, nextName)) =>
+                                            ((x, nextName, n', t, e, s),
+                                             (if n' = n then nextName else thisName,
+                                              nextName + 1)))
+                                        (0, #nextName st) vis
 
-                            fun specialize (x, n, n_old, t, e, s) =
-                                let
-                                    fun trim (t, e, cargs) =
-                                        case (t, e, cargs) of
-                                            ((TCFun (_, _, t), _),
-                                             (ECAbs (_, _, e), _),
-                                             carg :: cargs) =>
-                                            let
-                                                val t = subConInCon (length cargs, carg) t
-                                                val e = subConInExp (length cargs, carg) e
-                                            in
-                                                trim (t, e, cargs)
-                                            end
-                                          | (_, _, []) =>
-                                            let
-                                                val e = foldl (fn ((_, n, n_old, _, _, _), e) =>
-                                                                  unpolyNamed (n_old, ENamed n) e)
-                                                              e vis
-                                            in
-                                                SOME (t, e)
-                                            end
-                                          | _ => NONE
-                                in
-                                    (*Print.prefaces "specialize"
-                                                     [("t", CorePrint.p_con CoreEnv.empty t),
-                                                      ("e", CorePrint.p_exp CoreEnv.empty e),
-                                                      ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
-                                    Option.map (fn (t, e) => (x, n, n_old, t, e, s))
-                                               (trim (t, e, cargs))
-                                end
+                                fun specialize (x, n, n_old, t, e, s) =
+                                    let
+                                        fun trim (t, e, cargs) =
+                                            case (t, e, cargs) of
+                                                ((TCFun (_, _, t), _),
+                                                 (ECAbs (_, _, e), _),
+                                                 carg :: cargs) =>
+                                                let
+                                                    val t = subConInCon (length cargs, carg) t
+                                                    val e = subConInExp (length cargs, carg) e
+                                                in
+                                                    trim (t, e, cargs)
+                                                end
+                                              | (_, _, []) =>
+                                                (*let
+                                                    val e = foldl (fn ((_, n, n_old, _, _, _), e) =>
+                                                                      unpolyNamed (n_old, ENamed n) e)
+                                                                  e vis
+                                                in*)
+                                                    SOME (t, e)
+                                                (*end*)
+                                              | _ => NONE
+                                    in
+                                        (*Print.prefaces "specialize"
+                                                         [("t", CorePrint.p_con CoreEnv.empty t),
+                                                          ("e", CorePrint.p_exp CoreEnv.empty e),
+                                                          ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
+                                        Option.map (fn (t, e) => (x, n, n_old, t, e, s))
+                                                   (trim (t, e, cargs))
+                                    end
 
-                            val vis = List.map specialize vis
-                        in
-                            if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
-                                (e, st)
-                            else
-                                let
-                                    val vis = List.mapPartial (fn x => x) vis
-                                    val vis = map (fn (x, n, n_old, t, e, s) =>
-                                                      (x ^ "_unpoly", n, n_old, t, e, s)) vis
-                                    val vis' = map (fn (x, n, _, t, e, s) =>
-                                                       (x, n, t, e, s)) vis
+                                val vis = List.map specialize vis
+                            in
+                                if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
+                                    (e, st)
+                                else
+                                    let
+                                        val vis = List.mapPartial (fn x => x) vis
 
-                                    val ks' = List.drop (ks, length cargs)
-                                in
-                                    case List.find (fn (_, _, n_old, _, _, _) => n_old = n) vis of
-                                        NONE => raise Fail "Unpoly: Inconsistent 'val rec' record"
-                                      | SOME (_, n, _, _, _, _) =>
-                                        (ENamed n,
-                                         {funcs = foldl (fn (vi, funcs) =>
-                                                            IM.insert (funcs, #2 vi, (ks', vis')))
-                                                        (#funcs st) vis',
+                                        val vis = map (fn (x, n, n_old, t, e, s) =>
+                                                          (x ^ "_unpoly", n, n_old, t, e, s)) vis
+                                        val vis' = map (fn (x, n, _, t, e, s) =>
+                                                           (x, n, t, e, s)) vis
+
+                                        val funcs = IM.insert (#funcs st, n,
+                                                               {kinds = ks,
+                                                                defs = old_vis,
+                                                                replacements = M.insert (replacements,
+                                                                                         cargs,
+                                                                                         thisName)})
+
+                                        val ks' = List.drop (ks, length cargs)
+
+                                        val st = {funcs = foldl (fn (vi, funcs) =>
+                                                                    IM.insert (funcs, #2 vi,
+                                                                               {kinds = ks',
+                                                                                defs = vis',
+                                                                                replacements = M.empty}))
+                                                                funcs vis',
+                                                  decls = #decls st,
+                                                  nextName = nextName}
+
+                                        val (vis', st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
+                                                                               let
+                                                                                   val (e, st) = polyExp (e, st)
+                                                                               in
+                                                                                   ((x, n, t, e, s), st)
+                                                                               end)
+                                                                           st vis'
+                                    in
+                                        (ENamed thisName,
+                                         {funcs = #funcs st,
                                           decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
-                                          nextName = nextName})
-                                end
-                        end
+                                          nextName = #nextName st})
+                                    end
+                            end
         end
       | _ => (e, st)
 
+and polyExp (x, st) = U.Exp.foldMap {kind = kind, con = con, exp = exp} st x
+
 fun decl (d, st : state) =
     case d of
         DValRec (vis as ((x, n, t, e, s) :: rest)) =>
@@ -232,7 +270,9 @@
                         (d, st)
                     else
                         (d, {funcs = foldl (fn (vi, funcs) =>
-                                               IM.insert (funcs, #2 vi, (cargs, vis)))
+                                               IM.insert (funcs, #2 vi, {kinds = cargs,
+                                                                         defs = vis,
+                                                                         replacements = M.empty}))
                                            (#funcs st) vis,
                              decls = #decls st,
                              nextName = #nextName st})