diff src/elaborate.sml @ 191:aa54250f58ac

Parametrized datatypes through explify
author Adam Chlipala <adamc@hcoop.net>
date Fri, 08 Aug 2008 10:28:32 -0400
parents 8e9f97508f0d
children df5fd8f6913a
line wrap: on
line diff
--- a/src/elaborate.sml	Thu Aug 07 13:09:26 2008 -0400
+++ b/src/elaborate.sml	Fri Aug 08 10:28:32 2008 -0400
@@ -933,19 +933,32 @@
         val pterror = (perror, terror)
         val rerror = (pterror, (env, bound))
 
-        fun pcon (pc, po, to, dn, dk) =
+        fun pcon (pc, po, xs, to, dn, dk) =
             case (po, to) of
                 (NONE, SOME _) => (expError env (PatHasNoArg loc);
                                    rerror)
               | (SOME _, NONE) => (expError env (PatHasArg loc);
                                    rerror)
-              | (NONE, NONE) => (((L'.PCon (dk, pc, NONE), loc), dn),
-                                 (env, bound))
+              | (NONE, NONE) =>
+                let
+                    val k = (L'.KType, loc)
+                    val unifs = map (fn _ => cunif (loc, k)) xs
+                    val dn = foldl (fn (u, dn) => (L'.CApp (dn, u), loc)) dn unifs
+                in
+                    (((L'.PCon (dk, pc, unifs, NONE), loc), dn),
+                     (env, bound))
+                end
               | (SOME p, SOME t) =>
                 let
                     val ((p', pt), (env, bound)) = elabPat (p, (env, denv, bound))
+
+                    val k = (L'.KType, loc)
+                    val unifs = map (fn _ => cunif (loc, k)) xs
+                    val t = ListUtil.foldli (fn (i, u, t) => subConInCon (i, u) t) t unifs
+                    val dn = foldl (fn (u, dn) => (L'.CApp (dn, u), loc)) dn unifs
                 in
-                    (((L'.PCon (dk, pc, SOME p'), loc), dn),
+                    ignore (checkPatCon (env, denv) p' pt t);
+                    (((L'.PCon (dk, pc, unifs, SOME p'), loc), dn),
                      (env, bound))
                 end
     in
@@ -969,7 +982,7 @@
             (case E.lookupConstructor env x of
                  NONE => (expError env (UnboundConstructor (loc, [], x));
                           rerror)
-               | SOME (dk, n, to, dn) => pcon (L'.PConVar n, po, to, (L'.CNamed dn, loc), dk))
+               | SOME (dk, n, xs, to, dn) => pcon (L'.PConVar n, po, xs, to, (L'.CNamed dn, loc), dk))
           | L.PCon (m1 :: ms, x, po) =>
             (case E.lookupStr env m1 of
                  NONE => (expError env (UnboundStrInExp (loc, m1));
@@ -985,7 +998,7 @@
                      case E.projectConstructor env {str = str, sgn = sgn, field = x} of
                          NONE => (expError env (UnboundConstructor (loc, m1 :: ms, x));
                                   rerror)
-                       | SOME (dk, _, to, dn) => pcon (L'.PConProj (n, ms, x), po, to, dn, dk)
+                       | SOME (dk, _, xs, to, dn) => pcon (L'.PConProj (n, ms, x), po, xs, to, dn, dk)
                  end)
 
           | L.PRecord (xps, flex) =>
@@ -1035,7 +1048,7 @@
                 in
                     case E.projectConstructor env {str = str, sgn = sgn, field = x} of
                         NONE => raise Fail "exhaustive: Can't project constructor"
-                      | SOME (_, n, _, _) => n
+                      | SOME (_, n, _, _, _) => n
                 end
 
         fun coverage (p, _) =
@@ -1043,8 +1056,8 @@
                 L'.PWild => Wild
               | L'.PVar _ => Wild
               | L'.PPrim _ => None
-              | L'.PCon (_, pc, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild))
-              | L'.PCon (_, pc, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p))
+              | L'.PCon (_, pc, _, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild))
+              | L'.PCon (_, pc, _, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p))
               | L'.PRecord xps => Record [foldl (fn ((x, p, _), fmap) =>
                                                     SM.insert (fmap, x, coverage p)) SM.empty xps]
 
@@ -1158,8 +1171,13 @@
                                               (total, gs' @ gs)
                                           end)
                               (true, gs) cons
+
+                    fun unapp t =
+                        case t of
+                            L'.CApp ((t, _), _) => unapp t
+                          | _ => t
                 in
-                    case t of
+                    case unapp t of
                         L'.CNamed n =>
                         let
                             val dt = E.lookupDatatype env n
@@ -1173,7 +1191,7 @@
                         in
                             case E.projectDatatype env {str = str, sgn = sgn, field = x} of
                                 NONE => raise Fail "isTotal: Can't project datatype"
-                              | SOME cons => dtype cons
+                              | SOME (_, cons) => dtype cons
                         end
                       | L'.CError => (true, gs)
                       | _ => raise Fail "isTotal: Not a datatype"
@@ -1206,7 +1224,11 @@
                  (expError env (UnboundExp (loc, s));
                   (eerror, cerror, []))
                | E.Rel (n, t) => ((L'.ERel n, loc), t, [])
-               | E.Named (n, t) => ((L'.ENamed n, loc), t, []))
+               | E.Named (n, t) =>
+                 if Char.isUpper (String.sub (s, 0)) then
+                     elabHead (env, denv) (L'.ENamed n, loc) t
+                 else
+                     ((L'.ENamed n, loc), t, []))
           | L.EVar (m1 :: ms, s) =>
             (case E.lookupStr env m1 of
                  NONE => (expError env (UnboundStrInExp (loc, m1));
@@ -1572,11 +1594,13 @@
             ([(L'.SgiCon (x, n, k', c'), loc)], (env', denv, gs' @ gs))
         end
 
-      | L.SgiDatatype (x, xcs) =>
+      | L.SgiDatatype (x, xs, xcs) =>
         let
             val k = (L'.KType, loc)
-            val (env, n) = E.pushCNamed env x k NONE
+            val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs
+            val (env, n) = E.pushCNamed env x k' NONE
             val t = (L'.CNamed n, loc)
+            val t = ListUtil.foldli (fn (i, _, t) => (L'.CApp (t, (L'.CRel i, loc)), loc)) t xs
 
             val (xcs, (used, env, gs)) =
                 ListUtil.foldlMap
@@ -1591,6 +1615,7 @@
                                                checkKind env t' tk k;
                                                (SOME t', (L'.TFun (t', t), loc), gs' @ gs)
                                            end
+                        val t = foldl (fn (x, t) => (L'.TCFun (L'.Implicit, x, k, t), loc)) t xs
 
                         val (env, n') = E.pushENamed env x t
                     in
@@ -1601,8 +1626,10 @@
                         ((x, n', to), (SS.add (used, x), env, gs'))
                     end)
                 (SS.empty, env, []) xcs
+
+            val env = E.pushDatatype env n xs xcs
         in
-            ([(L'.SgiDatatype (x, n, xcs), loc)], (env, denv, gs))
+            ([(L'.SgiDatatype (x, n, xs, xcs), loc)], (env, denv, gs))
         end
 
       | L.SgiDatatypeImp (_, [], _) => raise Fail "Empty SgiDatatypeImp"
@@ -1625,12 +1652,14 @@
                      (case E.projectDatatype env {sgn = sgn, str = str, field = s} of
                           NONE => (conError env (UnboundDatatype (loc, s));
                                    ([], (env, denv, gs)))
-                        | SOME xncs =>
+                        | SOME (xs, xncs) =>
                           let
                               val k = (L'.KType, loc)
+                              val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs
+
                               val t = (L'.CModProj (n, ms, s), loc)
-                              val (env, n') = E.pushCNamed env x k (SOME t)
-                              val env = E.pushDatatype env n' xncs
+                              val (env, n') = E.pushCNamed env x k' (SOME t)
+                              val env = E.pushDatatype env n' xs xncs
 
                               val t = (L'.CNamed n', loc)
                               val env = foldl (fn ((x, n, to), env) =>
@@ -1638,11 +1667,15 @@
                                                       val t = case to of
                                                                   NONE => t
                                                                 | SOME t' => (L'.TFun (t', t), loc)
+
+                                                      val t = foldr (fn (x, t) =>
+                                                                        (L'.TCFun (L'.Implicit, x, k, t), loc))
+                                                              t xs
                                                   in
                                                       E.pushENamedAs env x n t
                                                   end) env xncs
                           in
-                              ([(L'.SgiDatatypeImp (x, n', n, ms, s, xncs), loc)], (env, denv, gs))
+                              ([(L'.SgiDatatypeImp (x, n', n, ms, s, xs, xncs), loc)], (env, denv, gs))
                           end)
                    | _ => (strError env (NotDatatype loc);
                            ([], (env, denv, [])))
@@ -1720,7 +1753,7 @@
                                    else
                                        ();
                                    (SS.add (cons, x), vals, sgns, strs))
-                                | L'.SgiDatatype (x, _, xncs) =>
+                                | L'.SgiDatatype (x, _, _, xncs) =>
                                   let
                                       val vals = foldl (fn ((x, _, _), vals) =>
                                                            (if SS.member (vals, x) then
@@ -1736,7 +1769,7 @@
                                           ();
                                       (SS.add (cons, x), vals, sgns, strs)
                                   end
-                                | L'.SgiDatatypeImp (x, _, _, _, _, _) =>
+                                | L'.SgiDatatypeImp (x, _, _, _, _, _, _) =>
                                   (if SS.member (cons, x) then
                                        sgnError env (DuplicateCon (loc, x))
                                    else
@@ -1828,8 +1861,8 @@
       | L'.SgnConst sgis =>
         (L'.SgnConst (map (fn (L'.SgiConAbs (x, n, k), loc) =>
                               (L'.SgiCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc)
-                            | (L'.SgiDatatype (x, n, xncs), loc) =>
-                              (L'.SgiDatatypeImp (x, n, str, strs, x, xncs), loc)
+                            | (L'.SgiDatatype (x, n, xs, xncs), loc) =>
+                              (L'.SgiDatatypeImp (x, n, str, strs, x, xs, xncs), loc)
                             | (L'.SgiStr (x, n, sgn), loc) =>
                               (L'.SgiStr (x, n, selfify env {str = str, strs = strs @ [x], sgn = sgn}), loc)
                             | x => x) sgis), #2 sgn)
@@ -1878,10 +1911,10 @@
                                               end
                                             | L'.SgiCon (x, n, k, c) =>
                                               (L'.DCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc)
-                                            | L'.SgiDatatype (x, n, xncs) =>
-                                              (L'.DDatatypeImp (x, n, str, strs, x, xncs), loc)
-                                            | L'.SgiDatatypeImp (x, n, m1, ms, x', xncs) =>
-                                              (L'.DDatatypeImp (x, n, m1, ms, x', xncs), loc)
+                                            | L'.SgiDatatype (x, n, xs, xncs) =>
+                                              (L'.DDatatypeImp (x, n, str, strs, x, xs, xncs), loc)
+                                            | L'.SgiDatatypeImp (x, n, m1, ms, x', xs, xncs) =>
+                                              (L'.DDatatypeImp (x, n, m1, ms, x', xs, xncs), loc)
                                             | L'.SgiVal (x, n, t) =>
                                               (L'.DVal (x, n, t, (L'.EModProj (str, strs, x), loc)), loc)
                                             | L'.SgiStr (x, n, sgn) =>
@@ -1998,9 +2031,20 @@
                                      case sgi1 of
                                          L'.SgiConAbs (x', n1, k1) => found (x', n1, k1, NONE)
                                        | L'.SgiCon (x', n1, k1, c1) => found (x', n1, k1, SOME c1)
-                                       | L'.SgiDatatype (x', n1, _) => found (x', n1, (L'.KType, loc), NONE)
-                                       | L'.SgiDatatypeImp (x', n1, m1, ms, s, _) =>
-                                         found (x', n1, (L'.KType, loc), SOME (L'.CModProj (m1, ms, s), loc))
+                                       | L'.SgiDatatype (x', n1, xs, _) =>
+                                         let
+                                             val k = (L'.KType, loc)
+                                             val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs
+                                         in
+                                             found (x', n1, k', NONE)
+                                         end
+                                       | L'.SgiDatatypeImp (x', n1, m1, ms, s, xs, _) =>
+                                         let
+                                             val k = (L'.KType, loc)
+                                             val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs
+                                         in
+                                             found (x', n1, k', SOME (L'.CModProj (m1, ms, s), loc))
+                                         end
                                        | _ => NONE
                                  end)
 
@@ -2023,15 +2067,18 @@
                                          NONE
                                    | _ => NONE)
 
-                      | L'.SgiDatatype (x, n2, xncs2) =>
+                      | L'.SgiDatatype (x, n2, xs2, xncs2) =>
                         seek (fn sgi1All as (sgi1, _) =>
                                  let
-                                     fun found (n1, xncs1) =
+                                     fun found (n1, xs1, xncs1) =
                                          let
                                              fun mismatched ue =
                                                  (sgnError env (SgiMismatchedDatatypes (sgi1All, sgi2All, ue));
                                                   SOME (env, denv))
 
+                                             val k = (L'.KType, loc)
+                                             val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs1
+
                                              fun good () =
                                                  let
                                                      val env = E.sgiBinds env sgi2All
@@ -2044,6 +2091,7 @@
                                                      SOME (env, denv)
                                                  end
 
+                                             val env = foldl (fn (x, env) => E.pushCRel env x k) env xs1
                                              fun xncBad ((x1, _, t1), (x2, _, t2)) =
                                                  String.compare (x1, x2) <> EQUAL
                                                  orelse case (t1, t2) of
@@ -2052,7 +2100,8 @@
                                                             not (List.null (unifyCons (env, denv) t1 t2))
                                                           | _ => true
                                          in
-                                             (if length xncs1 <> length xncs2
+                                             (if xs1 <> xs2
+                                                 orelse length xncs1 <> length xncs2
                                                  orelse ListPair.exists xncBad (xncs1, xncs2) then
                                                   mismatched NONE
                                               else
@@ -2061,33 +2110,34 @@
                                          end
                                  in
                                      case sgi1 of
-                                         L'.SgiDatatype (x', n1, xncs1) =>
+                                         L'.SgiDatatype (x', n1, xs, xncs1) =>
                                          if x' = x then
-                                             found (n1, xncs1)
+                                             found (n1, xs, xncs1)
                                          else
                                              NONE
-                                       | L'.SgiDatatypeImp (x', n1, _, _, _, xncs1) =>
+                                       | L'.SgiDatatypeImp (x', n1, _, _, _, xs, xncs1) =>
                                          if x' = x then
-                                             found (n1, xncs1)
+                                             found (n1, xs, xncs1)
                                          else
                                              NONE
                                        | _ => NONE
                                  end)
 
-                      | L'.SgiDatatypeImp (x, n2, m12, ms2, s2, _) =>
+                      | L'.SgiDatatypeImp (x, n2, m12, ms2, s2, xs, _) =>
                         seek (fn sgi1All as (sgi1, _) =>
                                  case sgi1 of
-                                     L'.SgiDatatypeImp (x', n1, m11, ms1, s1, _) =>
+                                     L'.SgiDatatypeImp (x', n1, m11, ms1, s1, _, _) =>
                                      if x = x' then
                                          let
                                              val k = (L'.KType, loc)
+                                             val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs
                                              val t1 = (L'.CModProj (m11, ms1, s1), loc)
                                              val t2 = (L'.CModProj (m12, ms2, s2), loc)
 
                                              fun good () =
                                                  let
-                                                     val env = E.pushCNamedAs env x n1 k (SOME t1)
-                                                     val env = E.pushCNamedAs env x n2 k (SOME t2)
+                                                     val env = E.pushCNamedAs env x n1 k' (SOME t1)
+                                                     val env = E.pushCNamedAs env x n2 k' (SOME t2)
                                                  in
                                                      SOME (env, denv)
                                                  end
@@ -2213,11 +2263,17 @@
 
             ([(L'.DCon (x, n, k', c'), loc)], (env', denv, gs' @ gs))
         end
-      | L.DDatatype (x, xcs) =>
+      | L.DDatatype (x, xs, xcs) =>
         let
             val k = (L'.KType, loc)
-            val (env, n) = E.pushCNamed env x k NONE
+            val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs
+            val (env, n) = E.pushCNamed env x k' NONE
             val t = (L'.CNamed n, loc)
+            val t = ListUtil.foldli (fn (i, _, t) => (L'.CApp (t, (L'.CRel i, loc)), loc)) t xs
+
+            val (env', denv') = foldl (fn (x, (env', denv')) =>
+                                          (E.pushCRel env' x k,
+                                           D.enter denv')) (env, denv) xs
 
             val (xcs, (used, env, gs)) =
                 ListUtil.foldlMap
@@ -2227,11 +2283,12 @@
                                            NONE => (NONE, t, gs)
                                          | SOME t' =>
                                            let
-                                               val (t', tk, gs') = elabCon (env, denv) t'
+                                               val (t', tk, gs') = elabCon (env', denv') t'
                                            in
-                                               checkKind env t' tk k;
+                                               checkKind env' t' tk k;
                                                (SOME t', (L'.TFun (t', t), loc), gs' @ gs)
                                            end
+                        val t = foldr (fn (x, t) => (L'.TCFun (L'.Implicit, x, k, t), loc)) t xs
 
                         val (env, n') = E.pushENamed env x t
                     in
@@ -2243,9 +2300,9 @@
                     end)
                 (SS.empty, env, []) xcs
 
-            val env = E.pushDatatype env n xcs
+            val env = E.pushDatatype env n xs xcs
         in
-            ([(L'.DDatatype (x, n, xcs), loc)], (env, denv, gs))
+            ([(L'.DDatatype (x, n, xs, xcs), loc)], (env, denv, gs))
         end
 
       | L.DDatatypeImp (_, [], _) => raise Fail "Empty DDatatypeImp"
@@ -2268,12 +2325,13 @@
                      (case E.projectDatatype env {sgn = sgn, str = str, field = s} of
                           NONE => (conError env (UnboundDatatype (loc, s));
                                    ([], (env, denv, gs)))
-                        | SOME xncs =>
+                        | SOME (xs, xncs) =>
                           let
                               val k = (L'.KType, loc)
+                              val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs
                               val t = (L'.CModProj (n, ms, s), loc)
-                              val (env, n') = E.pushCNamed env x k (SOME t)
-                              val env = E.pushDatatype env n' xncs
+                              val (env, n') = E.pushCNamed env x k' (SOME t)
+                              val env = E.pushDatatype env n' xs xncs
 
                               val t = (L'.CNamed n', loc)
                               val env = foldl (fn ((x, n, to), env) =>
@@ -2281,11 +2339,15 @@
                                                       val t = case to of
                                                                   NONE => t
                                                                 | SOME t' => (L'.TFun (t', t), loc)
+
+                                                      val t = foldr (fn (x, t) =>
+                                                                        (L'.TCFun (L'.Implicit, x, k, t), loc))
+                                                              t xs
                                                   in
                                                       E.pushENamedAs env x n t
                                                   end) env xncs
                           in
-                              ([(L'.DDatatypeImp (x, n', n, ms, s, xncs), loc)], (env, denv, gs))
+                              ([(L'.DDatatypeImp (x, n', n, ms, s, xs, xncs), loc)], (env, denv, gs))
                           end)
                    | _ => (strError env (NotDatatype loc);
                            ([], (env, denv, [])))
@@ -2544,7 +2606,7 @@
                               in
                                   ((L'.SgiCon (x, n, k, c), loc) :: sgis, cons, vals, sgns, strs)
                               end
-                            | L'.SgiDatatype (x, n, xncs) =>
+                            | L'.SgiDatatype (x, n, xs, xncs) =>
                               let
                                   val (cons, x) =
                                       if SS.member (cons, x) then
@@ -2561,9 +2623,9 @@
                                                   ((x, n, t), SS.add (vals, x)))
                                       vals xncs
                               in
-                                  ((L'.SgiDatatype (x, n, xncs), loc) :: sgis, cons, vals, sgns, strs)
+                                  ((L'.SgiDatatype (x, n, xs, xncs), loc) :: sgis, cons, vals, sgns, strs)
                               end
-                            | L'.SgiDatatypeImp (x, n, m1, ms, x', xncs) =>
+                            | L'.SgiDatatypeImp (x, n, m1, ms, x', xs, xncs) =>
                               let
                                   val (cons, x) =
                                       if SS.member (cons, x) then
@@ -2571,7 +2633,7 @@
                                       else
                                           (SS.add (cons, x), x)
                               in
-                                  ((L'.SgiDatatypeImp (x, n, m1, ms, x', xncs), loc) :: sgis, cons, vals, sgns, strs)
+                                  ((L'.SgiDatatypeImp (x, n, m1, ms, x', xs, xncs), loc) :: sgis, cons, vals, sgns, strs)
                               end
                             | L'.SgiVal (x, n, c) =>
                               let