diff src/elaborate.sml @ 156:34ccd7d2bea8

Start of datatype support
author Adam Chlipala <adamc@hcoop.net>
date Thu, 24 Jul 2008 15:02:03 -0400
parents cfe6f9db74aa
children adc4e42e3adc
line wrap: on
line diff
--- a/src/elaborate.sml	Thu Jul 24 11:32:01 2008 -0400
+++ b/src/elaborate.sml	Thu Jul 24 15:02:03 2008 -0400
@@ -1158,6 +1158,7 @@
        | UnmatchedSgi of L'.sgn_item
        | SgiWrongKind of L'.sgn_item * L'.kind * L'.sgn_item * L'.kind * kunify_error
        | SgiWrongCon of L'.sgn_item * L'.con * L'.sgn_item * L'.con * cunify_error
+       | SgiMismatchedDatatypes of L'.sgn_item * L'.sgn_item * (L'.con * L'.con * cunify_error) option
        | SgnWrongForm of L'.sgn * L'.sgn
        | UnWhereable of L'.sgn * string
        | WhereWrongKind of L'.kind * L'.kind * kunify_error
@@ -1189,6 +1190,15 @@
                      ("Con 1", p_con env c1),
                      ("Con 2", p_con env c2)];
          cunifyError env cerr)
+      | SgiMismatchedDatatypes (sgi1, sgi2, cerro) =>
+        (ErrorMsg.errorAt (#2 sgi1) "Mismatched 'datatype' specifications:";
+         eprefaces' [("Have", p_sgn_item env sgi1),
+                     ("Need", p_sgn_item env sgi2)];
+         Option.app (fn (c1, c2, ue) =>
+                        (eprefaces "Unification error"
+                                   [("Con 1", p_con env c1),
+                                    ("Con 2", p_con env c2)];
+                         cunifyError env ue)) cerro)
       | SgnWrongForm (sgn1, sgn2) =>
         (ErrorMsg.errorAt (#2 sgn1) "Incompatible signatures:";
          eprefaces' [("Sig 1", p_sgn env sgn1),
@@ -1223,6 +1233,7 @@
        | FunctorRebind of ErrorMsg.span
        | UnOpenable of L'.sgn
        | NotType of L'.kind * (L'.kind * L'.kind * kunify_error)
+       | DuplicateConstructor of string * ErrorMsg.span
 
 fun strError env err =
     case err of
@@ -1242,6 +1253,8 @@
                      ("Subkind 1", p_kind k1),
                      ("Subkind 2", p_kind k2)];
          kunifyError ue)
+      | DuplicateConstructor (x, loc) =>
+        ErrorMsg.errorAt loc ("Duplicate datatype constructor " ^ x)
 
 val hnormSgn = E.hnormSgn
 
@@ -1270,6 +1283,10 @@
             ([(L'.SgiCon (x, n, k', c'), loc)], (env', denv, gs' @ gs))
         end
 
+      | L.SgiDatatype _ => raise Fail "Elaborate SgiDatatype"
+
+      | L.SgiDatatypeImp _ => raise Fail "Elaborate SgiDatatypeImp"
+
       | L.SgiVal (x, c) =>
         let
             val (c', ck, gs') = elabCon (env, denv) c
@@ -1342,6 +1359,28 @@
                                    else
                                        ();
                                    (SS.add (cons, x), vals, sgns, strs))
+                                | L'.SgiDatatype (x, _, xncs) =>
+                                  let
+                                      val vals = foldl (fn ((x, _, _), vals) =>
+                                                           (if SS.member (vals, x) then
+                                                                sgnError env (DuplicateVal (loc, x))
+                                                            else
+                                                                ();
+                                                            SS.add (vals, x)))
+                                                       vals xncs
+                                  in
+                                      if SS.member (cons, x) then
+                                          sgnError env (DuplicateCon (loc, x))
+                                      else
+                                          ();
+                                      (SS.add (cons, x), vals, sgns, strs)
+                                  end
+                                | L'.SgiDatatypeImp (x, _, _, _, _) =>
+                                  (if SS.member (cons, x) then
+                                       sgnError env (DuplicateCon (loc, x))
+                                   else
+                                       ();
+                                   (SS.add (cons, x), vals, sgns, strs))
                                 | L'.SgiVal (x, _, _) =>
                                   (if SS.member (vals, x) then
                                        sgnError env (DuplicateVal (loc, x))
@@ -1476,6 +1515,22 @@
                                     | L'.SgiCon (x, n, k, c) =>
                                       ((L'.DCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc),
                                        (E.pushCNamedAs env' x n k (SOME c), denv'))
+                                    | L'.SgiDatatype (x, n, xncs) =>
+                                      let
+                                          val k = (L'.KType, loc)
+                                          val c = (L'.CModProj (str, strs, x), loc)
+                                      in
+                                          ((L'.DDatatypeImp (x, n, str, strs, x), loc),
+                                           (E.pushCNamedAs env' x n k (SOME c), denv'))
+                                      end
+                                    | L'.SgiDatatypeImp (x, n, m1, ms, x') =>
+                                      let
+                                          val k = (L'.KType, loc)
+                                          val c = (L'.CModProj (m1, ms, x'), loc)
+                                      in
+                                          ((L'.DCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc),
+                                           (E.pushCNamedAs env' x n k (SOME c), denv'))
+                                      end
                                     | L'.SgiVal (x, n, t) =>
                                       ((L'.DVal (x, n, t, (L'.EModProj (str, strs, x), loc)), loc),
                                        (E.pushENamedAs env' x n t, denv'))
@@ -1487,7 +1542,7 @@
                                        (E.pushSgnNamedAs env' x n sgn, denv'))
                                     | L'.SgiConstraint (c1, c2) =>
                                       ((L'.DConstraint (c1, c2), loc),
-                                       (env', denv (* D.assert env denv (c1, c2) *) )))
+                                       (env', denv)))
                               (env, denv) sgis
           | _ => (strError env (UnOpenable sgn);
                   ([], (env, denv)))
@@ -1528,6 +1583,8 @@
 fun sgiOfDecl (d, loc) =
     case d of
         L'.DCon (x, n, k, c) => [(L'.SgiCon (x, n, k, c), loc)]
+      | L'.DDatatype x => [(L'.SgiDatatype x, loc)]
+      | L'.DDatatypeImp x => [(L'.SgiDatatypeImp x, loc)]
       | L'.DVal (x, n, t, _) => [(L'.SgiVal (x, n, t), loc)]
       | L'.DValRec vis => map (fn (x, n, t, _) => (L'.SgiVal (x, n, t), loc)) vis
       | L'.DSgn (x, n, sgn) => [(L'.SgiSgn (x, n, sgn), loc)]
@@ -1551,7 +1608,7 @@
 
       | (L'.SgnConst sgis1, L'.SgnConst sgis2) =>
         let
-            fun folder (sgi2All as (sgi, _), (env, denv)) =
+            fun folder (sgi2All as (sgi, loc), (env, denv)) =
                 let
                     fun seek p =
                         let
@@ -1613,6 +1670,49 @@
                                          NONE
                                    | _ => NONE)
 
+                      | L'.SgiDatatype (x, n2, xncs2) =>
+                        seek (fn sgi1All as (sgi1, _) =>
+                                 case sgi1 of
+                                     L'.SgiDatatype (x', n1, xncs1) =>
+                                     let
+                                         fun mismatched ue =
+                                             (sgnError env (SgiMismatchedDatatypes (sgi1All, sgi2All, ue));
+                                              SOME (env, denv))
+
+                                         fun good () =
+                                             let
+                                                 val env = E.sgiBinds env sgi2All
+                                                 val env = if n1 = n2 then
+                                                               env
+                                                           else
+                                                               E.pushCNamedAs env x n1 (L'.KType, loc)
+                                                                              (SOME (L'.CNamed n1, loc))
+                                             in
+                                                 SOME (env, denv)
+                                             end
+
+                                         fun xncBad ((x1, _, t1), (x2, _, t2)) =
+                                             String.compare (x1, x2) <> EQUAL
+                                             orelse case (t1, t2) of
+                                                        (NONE, NONE) => false
+                                                      | (SOME t1, SOME t2) =>
+                                                        not (List.null (unifyCons (env, denv) t1 t2))
+                                                      | _ => true
+                                     in
+                                         (if x = x' then
+                                             if length xncs1 <> length xncs2
+                                                orelse ListPair.exists xncBad (xncs1, xncs2) then
+                                                 mismatched NONE
+                                             else
+                                                 good ()
+                                          else
+                                              NONE)
+                                         handle CUnify ue => mismatched (SOME ue)
+                                     end
+                                   | _ => NONE)
+
+                      | L'.SgiDatatypeImp _ => raise Fail "SgiDatatypeImp in subsgn"
+
                       | L'.SgiVal (x, n2, c2) =>
                         seek (fn sgi1All as (sgi1, _) =>
                                  case sgi1 of
@@ -1722,6 +1822,40 @@
 
             ([(L'.DCon (x, n, k', c'), loc)], (env', denv, gs' @ gs))
         end
+      | L.DDatatype (x, xcs) =>
+        let
+            val k = (L'.KType, loc)
+            val (env, n) = E.pushCNamed env x k NONE
+            val t = (L'.CNamed n, loc)
+
+            val (xcs, (used, env, gs)) =
+                ListUtil.foldlMap
+                (fn ((x, to), (used, env, gs)) =>
+                    let
+                        val (to, t, gs') = case to of
+                                           NONE => (NONE, t, gs)
+                                         | SOME t' =>
+                                           let
+                                               val (t', tk, gs') = elabCon (env, denv) t'
+                                           in
+                                               checkKind env t' tk k;
+                                               (SOME t', (L'.TFun (t', t), loc), gs' @ gs)
+                                           end
+
+                        val (env, n') = E.pushENamed env x t
+                    in
+                        if SS.member (used, x) then
+                            strError env (DuplicateConstructor (x, loc))
+                        else
+                            ();
+                        ((x, n', to), (SS.add (used, x), env, gs'))
+                    end)
+                (SS.empty, env, []) xcs
+        in
+            ([(L'.DDatatype (x, n, xcs), loc)], (env, denv, gs))
+        end
+
+      | L.DDatatypeImp _ => raise Fail "Elaborate DDatatypeImp"
       | L.DVal (x, co, e) =>
         let
             val (c', _, gs1) = case co of
@@ -1975,6 +2109,35 @@
                               in
                                   ((L'.SgiCon (x, n, k, c), loc) :: sgis, cons, vals, sgns, strs)
                               end
+                            | L'.SgiDatatype (x, n, xncs) =>
+                              let
+                                  val (cons, x) =
+                                      if SS.member (cons, x) then
+                                          (cons, "?" ^ x)
+                                      else
+                                          (SS.add (cons, x), x)
+
+                                  val (xncs, vals) =
+                                      ListUtil.foldlMap
+                                          (fn ((x, n, t), vals) =>
+                                              if SS.member (vals, x) then
+                                                  (("?" ^ x, n, t), vals)
+                                              else
+                                                  ((x, n, t), SS.add (vals, x)))
+                                      vals xncs
+                              in
+                                  ((L'.SgiDatatype (x, n, xncs), loc) :: sgis, cons, vals, sgns, strs)
+                              end
+                            | L'.SgiDatatypeImp (x, n, m1, ms, x') =>
+                              let
+                                  val (cons, x) =
+                                      if SS.member (cons, x) then
+                                          (cons, "?" ^ x)
+                                      else
+                                          (SS.add (cons, x), x)
+                              in
+                                  ((L'.SgiDatatypeImp (x, n, m1, ms, x'), loc) :: sgis, cons, vals, sgns, strs)
+                              end
                             | L'.SgiVal (x, n, c) =>
                               let
                                   val (vals, x) =