diff src/monoize.sml @ 1287:5137b0537c92

Polymorphic variants
author Adam Chlipala <adam@chlipala.net>
date Thu, 19 Aug 2010 17:28:52 -0400
parents a9a500d22ebc
children fc7ecf8883b1
line wrap: on
line diff
--- a/src/monoize.sml	Tue Aug 10 16:02:55 2010 -0400
+++ b/src/monoize.sml	Thu Aug 19 17:28:52 2010 -0400
@@ -36,11 +36,47 @@
 structure IM = IntBinaryMap
 structure IS = IntBinarySet
 
-structure SS = BinarySetFn(struct
-                           type ord_key = string
-                           val compare = String.compare
+structure SK = struct
+type ord_key = string
+val compare = String.compare
+end
+
+structure SS = BinarySetFn(SK)
+structure SM = BinaryMapFn(SK)
+
+structure RM = BinaryMapFn(struct
+                           type ord_key = (string * L'.typ) list
+                           fun compare (r1, r2) = MonoUtil.Typ.compare ((L'.TRecord r1, E.dummySpan),
+                                                                        (L'.TRecord r2, E.dummySpan))
                            end)
 
+val nextPvar = ref 0
+val pvars = ref (RM.empty : (int * (string * int * L'.typ) list) RM.map)
+val pvarDefs = ref ([] : L'.decl list)
+
+fun choosePvar () =
+    let
+        val n = !nextPvar
+    in
+        nextPvar := n + 1;
+        n
+    end
+
+fun pvar (r, loc) =
+    case RM.find (!pvars, r) of
+        NONE =>
+        let
+            val n = choosePvar ()
+            val fs = map (fn (x, t) => (x, choosePvar (), t)) r
+            val fs' = foldl (fn ((x, n, _), fs') => SM.insert (fs', x, n)) SM.empty fs
+        in
+            pvars := RM.insert (!pvars, r, (n, fs));
+            pvarDefs := (L'.DDatatype [("$poly" ^ Int.toString n, n, map (fn (x, n, t) => (x, n, SOME t)) fs)], loc) 
+                        :: !pvarDefs;
+            (n, fs)
+        end
+      | SOME v => v
+
 val singletons = SS.addList (SS.empty,
                              ["link",
                               "br",
@@ -120,6 +156,16 @@
                   | L.CApp ((L.CFfi ("Basis", "list"), _), t) =>
                     (L'.TList (mt env dtmap t), loc)
 
+                  | L.CApp ((L.CFfi ("Basis", "variant"), _), (L.CRecord ((L.KType, _), xts), _)) =>
+                    let
+                        val xts = map (fn (x, t) => (monoName env x, mt env dtmap t)) xts
+                        val xts = ListMergeSort.sort (fn ((x, _), (y, _)) => String.compare (x, y) = GREATER) xts
+                        val (n, cs) = pvar (xts, loc)
+                        val cs = map (fn (x, n, t) => (x, n, SOME t)) cs
+                    in
+                        (L'.TDatatype (n, ref (ElabUtil.classifyDatatype cs, cs)), loc)
+                    end
+
                   | L.CApp ((L.CFfi ("Basis", "monad"), _), _) =>
                     (L'.TRecord [], loc)
 
@@ -348,8 +394,24 @@
     decls = []
 }
 
+fun chooseNext count =
+    let
+        val n = !nextPvar
+    in
+        if count < n then
+            (count, count+1)
+        else
+            (nextPvar := n + 1;
+             (n, n+1))
+    end
+
 fun enter ({count, map, listMap, ...} : t) = {count = count, map = map, listMap = listMap, decls = []}
-fun freshName {count, map, listMap, decls} = (count, {count = count + 1, map = map, listMap = listMap, decls = decls})
+fun freshName {count, map, listMap, decls} =
+    let
+        val (next, count) = chooseNext count
+    in
+        (next, {count = count , map = map, listMap = listMap, decls = decls})
+    end
 fun decls ({decls, ...} : t) = decls
 
 fun lookup (t as {count, map, listMap, decls}) k n thunk =
@@ -752,6 +814,53 @@
             end
           | L.ECon _ => poly ()
 
+          | L.ECApp (
+            (L.ECApp (
+             (L.ECApp ((L.EFfi ("Basis", "make"), _), (L.CName nm, _)), _),
+             t), _),
+            (L.CRecord (_, xts), _)) =>
+            let
+                val t = monoType env t
+                val xts = map (fn (x, t) => (monoName env x, monoType env t)) xts
+                val xts = (nm, t) :: xts
+                val xts = ListMergeSort.sort (fn ((x, _), (y, _)) => String.compare (x, y) = GREATER) xts
+                val (n, cs) = pvar (xts, loc)
+                val cs' = map (fn (x, n, t) => (x, n, SOME t)) cs
+                val cl = ElabUtil.classifyDatatype cs'
+            in
+                case List.find (fn (nm', _, _) => nm' = nm) cs of
+                    NONE => raise Fail "Monoize: Polymorphic variant tag mismatch for 'make'"
+                  | SOME (_, n', _) => ((L'.EAbs ("x", t, (L'.TDatatype (n, ref (cl, cs')), loc),
+                                                  (L'.ECon (cl, L'.PConVar n', SOME (L'.ERel 0, loc)), loc)), loc),
+                                        fm)
+            end
+
+          | L.ECApp (
+            (L.ECApp ((L.EFfi ("Basis", "match"), _), (L.CRecord (_, xts), _)), _),
+            t) =>
+            let
+                val t = monoType env t
+                val xts = map (fn (x, t) => (monoName env x, monoType env t)) xts
+                val xts = ListMergeSort.sort (fn ((x, _), (y, _)) => String.compare (x, y) = GREATER) xts
+                val (n, cs) = pvar (xts, loc)
+                val cs' = map (fn (x, n, t) => (x, n, SOME t)) cs
+                val cl = ElabUtil.classifyDatatype cs'
+                val fs = (L'.TRecord (map (fn (x, t') => (x, (L'.TFun (t', t), loc))) xts), loc)
+                val dt = (L'.TDatatype (n, ref (cl, cs')), loc)
+            in
+                ((L'.EAbs ("v",
+                           dt,
+                           (L'.TFun (fs, t), loc),
+                           (L'.EAbs ("fs", fs, t,
+                                     (L'.ECase ((L'.ERel 1, loc),
+                                                map (fn (x, n', t') =>
+                                                        ((L'.PCon (cl, L'.PConVar n', SOME (L'.PVar ("x", t'), loc)), loc),
+                                                         (L'.EApp ((L'.EField ((L'.ERel 1, loc), x), loc),
+                                                                   (L'.ERel 0, loc)), loc))) cs,
+                                                {disc = dt, result = t}), loc)), loc)), loc),
+                 fm)
+            end
+
           | L.ECApp ((L.EFfi ("Basis", "eq"), _), t) =>
             let
                 val t = monoType env t
@@ -3821,6 +3930,8 @@
 
 fun monoize env file =
     let
+        val () = pvars := RM.empty
+
         (* Calculate which exported functions need cookie signature protection *)
         val rcook = foldl (fn ((d, _), rcook) =>
                               case d of
@@ -3958,6 +4069,9 @@
                             | _ => e) e file
             end
 
+        val mname = CoreUtil.File.maxName file + 1
+        val () = nextPvar := mname
+
         val (_, _, ds) = List.foldl (fn (d, (env, fm, ds)) =>
                                         case #1 d of
                                             L.DDatabase s =>
@@ -3984,14 +4098,17 @@
                                                                    :: ds)
                                             end
                                           | _ =>
-                                            case monoDecl (env, fm) d of
-                                                NONE => (env, fm, ds)
-                                              | SOME (env, fm, ds') =>
-                                                (env,
-                                                 Fm.enter fm,
-                                                 ds' @ Fm.decls fm @ ds))
-                                    (env, Fm.empty (CoreUtil.File.maxName file + 1), []) file
+                                            (pvarDefs := [];
+                                             case monoDecl (env, fm) d of
+                                                 NONE => (env, fm, ds)
+                                               | SOME (env, fm, ds') =>
+                                                 (env,
+                                                  Fm.enter fm,
+                                                  ds' @ Fm.decls fm @ !pvarDefs @ ds)))
+                                    (env, Fm.empty mname, []) file
     in
+        pvars := RM.empty;
+        pvarDefs := [];
         rev ds
     end