changeset 158:b4b70de488e9

More datatype module stuff
author Adam Chlipala <adamc@hcoop.net>
date Thu, 24 Jul 2008 16:36:41 -0400
parents adc4e42e3adc
children 1e382d10e832
files src/elab_env.sig src/elab_env.sml src/elaborate.sml tests/datatypeMod.lac
diffstat 4 files changed, 168 insertions(+), 26 deletions(-) [+]
line wrap: on
line diff
--- a/src/elab_env.sig	Thu Jul 24 15:49:30 2008 -0400
+++ b/src/elab_env.sig	Thu Jul 24 16:36:41 2008 -0400
@@ -93,4 +93,6 @@
 
     val newNamed : unit -> int
 
+    val chaseMpath : env -> (int * string list) -> Elab.str * Elab.sgn
+
 end
--- a/src/elab_env.sml	Thu Jul 24 15:49:30 2008 -0400
+++ b/src/elab_env.sml	Thu Jul 24 16:36:41 2008 -0400
@@ -366,7 +366,16 @@
                 [] => NONE
               | (sgi, _) :: sgis =>
                 case f sgi of
-                    SOME v => SOME (v, (sgns, strs, cons))
+                    SOME v =>
+                    let
+                        val cons =
+                            case sgi of
+                                SgiDatatype (x, n, _) => IM.insert (cons, n, x)
+                              | SgiDatatypeImp (x, n, _, _, _) => IM.insert (cons, n, x)
+                              | _ => cons
+                    in
+                        SOME (v, (sgns, strs, cons))
+                    end
                   | NONE =>
                     case sgi of
                         SgiConAbs (x, n, _) => seek (sgis, sgns, strs, IM.insert (cons, n, x))
@@ -503,12 +512,28 @@
       | SgnError => SOME (SgnError, ErrorMsg.dummySpan)
       | _ => NONE
 
+fun chaseMpath env (n, ms) =
+    let
+        val (_, sgn) = lookupStrNamed env n
+    in
+        foldl (fn (m, (str, sgn)) =>
+                                   case projectStr env {sgn = sgn, str = str, field = m} of
+                                       NONE => raise Fail "kindof: Unknown substructure"
+                                     | SOME sgn => ((StrProj (str, m), #2 sgn), sgn))
+                               ((StrVar n, #2 sgn), sgn) ms
+    end
+
 fun projectCon env {sgn, str, field} =
     case #1 (hnormSgn env sgn) of
         SgnConst sgis =>
         (case sgnSeek (fn SgiConAbs (x, _, k) => if x = field then SOME (k, NONE) else NONE
                         | SgiCon (x, _, k, c) => if x = field then SOME (k, SOME c) else NONE
                         | SgiDatatype (x, _, _) => if x = field then SOME ((KType, #2 sgn), NONE) else NONE
+                        | SgiDatatypeImp (x, _, m1, ms, x') =>
+                          if x = field then
+                              SOME ((KType, #2 sgn), SOME (CModProj (m1, ms, x'), #2 sgn))
+                          else
+                              NONE
                         | _ => NONE) sgis of
              NONE => NONE
            | SOME ((k, co), subs) => SOME (k, Option.map (sgnSubCon (str, subs)) co))
@@ -519,6 +544,15 @@
     case #1 (hnormSgn env sgn) of
         SgnConst sgis =>
         (case sgnSeek (fn SgiDatatype (x, _, xncs) => if x = field then SOME xncs else NONE
+                        | SgiDatatypeImp (x, _, m1, ms, x') =>
+                          if x = field then
+                              let
+                                  val (str, sgn) = chaseMpath env (m1, ms)
+                              in
+                                  projectDatatype env {sgn = sgn, str = str, field = x'}
+                              end
+                          else
+                              NONE
                         | _ => NONE) sgis of
              NONE => NONE
            | SOME (xncs, subs) => SOME (map (fn (x, n, to) => (x, n, Option.map (sgnSubCon (str, subs)) to)) xncs))
@@ -527,7 +561,31 @@
 fun projectVal env {sgn, str, field} =
     case #1 (hnormSgn env sgn) of
         SgnConst sgis =>
-        (case sgnSeek (fn SgiVal (x, _, c) => if x = field then SOME c else NONE | _ => NONE) sgis of
+        (case sgnSeek (fn SgiVal (x, _, c) => if x = field then SOME c else NONE
+                        | SgiDatatype (_, n, xncs) =>
+                          ListUtil.search (fn (x, _, to) =>
+                                              if x = field then
+                                                  SOME (case to of
+                                                            NONE => (CNamed n, #2 sgn)
+                                                          | SOME t => (TFun (t, (CNamed n, #2 sgn)), #2 sgn))
+                                              else
+                                                  NONE) xncs
+                        | SgiDatatypeImp (_, n, m1, ms, x') =>
+                          let
+                              val (str, sgn) = chaseMpath env (m1, ms)
+                          in
+                              case projectDatatype env {sgn = sgn, str = str, field = x'} of
+                                  NONE => NONE
+                                | SOME xncs =>
+                                  ListUtil.search (fn (x, _, to) =>
+                                                      if x = field then
+                                                          SOME (case to of
+                                                                    NONE => (CNamed n, #2 sgn)
+                                                                  | SOME t => (TFun (t, (CNamed n, #2 sgn)), #2 sgn))
+                                                      else
+                                                          NONE) xncs
+                          end
+                        | _ => NONE) sgis of
              NONE => NONE
            | SOME (c, subs) => SOME (sgnSubCon (str, subs) c))
       | SgnError => SOME (CError, ErrorMsg.dummySpan)
--- a/src/elaborate.sml	Thu Jul 24 15:49:30 2008 -0400
+++ b/src/elaborate.sml	Thu Jul 24 16:36:41 2008 -0400
@@ -1237,6 +1237,7 @@
        | UnOpenable of L'.sgn
        | NotType of L'.kind * (L'.kind * L'.kind * kunify_error)
        | DuplicateConstructor of string * ErrorMsg.span
+       | NotDatatype of ErrorMsg.span
 
 fun strError env err =
     case err of
@@ -1258,6 +1259,8 @@
          kunifyError ue)
       | DuplicateConstructor (x, loc) =>
         ErrorMsg.errorAt loc ("Duplicate datatype constructor " ^ x)
+      | NotDatatype loc =>
+        ErrorMsg.errorAt loc "Trying to import non-datatype as a datatype"
 
 val hnormSgn = E.hnormSgn
 
@@ -1319,7 +1322,44 @@
             ([(L'.SgiDatatype (x, n, xcs), loc)], (env, denv, gs))
         end
 
-      | L.SgiDatatypeImp _ => raise Fail "Elaborate SgiDatatypeImp"
+      | L.SgiDatatypeImp (_, [], _) => raise Fail "Empty SgiDatatypeImp"
+
+      | L.SgiDatatypeImp (x, m1 :: ms, s) =>
+        (case E.lookupStr env m1 of
+             NONE => (strError env (UnboundStr (loc, m1));
+                      ([], (env, denv, gs)))
+           | SOME (n, sgn) =>
+             let
+                 val (str, sgn) = foldl (fn (m, (str, sgn)) =>
+                                     case E.projectStr env {sgn = sgn, str = str, field = m} of
+                                         NONE => (conError env (UnboundStrInCon (loc, m));
+                                                  (strerror, sgnerror))
+                                       | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
+                                  ((L'.StrVar n, loc), sgn) ms
+             in
+                 case E.projectDatatype env {sgn = sgn, str = str, field = s} of
+                     NONE => (conError env (UnboundDatatype (loc, s));
+                              ([], (env, denv, gs)))
+                   | SOME xncs =>
+                     let
+                         val k = (L'.KType, loc)
+                         val t = (L'.CModProj (n, ms, s), loc)
+                         val (env, n') = E.pushCNamed env x k (SOME t)
+                         val env = E.pushDatatype env n' xncs
+
+                         val t = (L'.CNamed n', loc)
+                         val env = foldl (fn ((x, n, to), env) =>
+                                             let
+                                                 val t = case to of
+                                                             NONE => t
+                                                           | SOME t' => (L'.TFun (t', t), loc)
+                                             in
+                                                 E.pushENamedAs env x n t
+                                             end) env xncs
+                     in
+                         ([(L'.SgiDatatypeImp (x, n', n, ms, s), loc)], (env, denv, []))
+                     end
+             end)
 
       | L.SgiVal (x, c) =>
         let
@@ -1501,6 +1541,8 @@
       | L'.SgnConst sgis =>
         (L'.SgnConst (map (fn (L'.SgiConAbs (x, n, k), loc) =>
                               (L'.SgiCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc)
+                            | (L'.SgiDatatype (x, n, xncs), loc) =>
+                              (L'.SgiDatatypeImp (x, n, str, strs, x), loc)
                             | (L'.SgiStr (x, n, sgn), loc) =>
                               (L'.SgiStr (x, n, selfify env {str = str, strs = strs @ [x], sgn = sgn}), loc)
                             | x => x) sgis), #2 sgn)
@@ -1745,7 +1787,35 @@
                                      end
                                    | _ => NONE)
 
-                      | L'.SgiDatatypeImp _ => raise Fail "SgiDatatypeImp in subsgn"
+                      | L'.SgiDatatypeImp (x, n2, m11, ms1, s1) =>
+                        seek (fn sgi1All as (sgi1, _) =>
+                                 case sgi1 of
+                                     L'.SgiDatatypeImp (x', n1, m12, ms2, s2) =>
+                                     if x = x' then
+                                         let
+                                             val k = (L'.KType, loc)
+                                             val t1 = (L'.CModProj (m11, ms1, s1), loc)
+                                             val t2 = (L'.CModProj (m12, ms2, s2), loc)
+
+                                             fun good () =
+                                                 let
+                                                     val env = E.pushCNamedAs env x n1 k (SOME t1)
+                                                     val env = E.pushCNamedAs env x n2 k (SOME t2)
+                                                 in
+                                                     SOME (env, denv)
+                                                 end
+                                         in
+                                             (case unifyCons (env, denv) t1 t2 of
+                                                  [] => good ()
+                                                | _ => NONE)
+                                             handle CUnify (c1, c2, err) =>
+                                                    (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err));
+                                                     good ())
+                                         end
+                                     else
+                                         NONE
+
+                                   | _ => NONE)
 
                       | L'.SgiVal (x, n2, c2) =>
                         seek (fn sgi1All as (sgi1, _) =>
@@ -1904,28 +1974,32 @@
                                        | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
                                   ((L'.StrVar n, loc), sgn) ms
              in
-                 case E.projectDatatype env {sgn = sgn, str = str, field = s} of
-                     NONE => (conError env (UnboundDatatype (loc, s));
-                              ([], (env, denv, gs)))
-                   | SOME xncs =>
-                     let
-                         val k = (L'.KType, loc)
-                         val t = (L'.CModProj (n, ms, s), loc)
-                         val (env, n') = E.pushCNamed env x k (SOME t)
-                         val env = E.pushDatatype env n' xncs
+                 case hnormCon (env, denv) (L'.CModProj (n, ms, s), loc) of
+                     ((L'.CModProj (n, ms, s), _), gs) =>
+                     (case E.projectDatatype env {sgn = sgn, str = str, field = s} of
+                          NONE => (conError env (UnboundDatatype (loc, s));
+                                   ([], (env, denv, gs)))
+                        | SOME xncs =>
+                          let
+                              val k = (L'.KType, loc)
+                              val t = (L'.CModProj (n, ms, s), loc)
+                              val (env, n') = E.pushCNamed env x k (SOME t)
+                              val env = E.pushDatatype env n' xncs
 
-                         val t = (L'.CNamed n', loc)
-                         val env = foldl (fn ((x, n, to), env) =>
-                                             let
-                                                 val t = case to of
-                                                             NONE => t
-                                                           | SOME t' => (L'.TFun (t', t), loc)
-                                             in
-                                                 E.pushENamedAs env x n t
-                                             end) env xncs
-                     in
-                         ([(L'.DDatatypeImp (x, n', n, ms, s), loc)], (env, denv, []))
-                     end
+                              val t = (L'.CNamed n', loc)
+                              val env = foldl (fn ((x, n, to), env) =>
+                                                  let
+                                                      val t = case to of
+                                                                  NONE => t
+                                                                | SOME t' => (L'.TFun (t', t), loc)
+                                                  in
+                                                      E.pushENamedAs env x n t
+                                                  end) env xncs
+                          in
+                              ([(L'.DDatatypeImp (x, n', n, ms, s), loc)], (env, denv, gs))
+                          end)
+                   | _ => (strError env (NotDatatype loc);
+                           ([], (env, denv, [])))
              end)
 
       | L.DVal (x, co, e) =>
@@ -2035,7 +2109,7 @@
 
                         val (str', actual, gs2) = elabStr (env, denv) str
                     in
-                        subSgn (env, denv) actual formal;
+                        subSgn (env, denv) (selfifyAt env {str = str', sgn = actual}) formal;
                         (str', formal, gs1 @ gs2)
                     end
 
--- a/tests/datatypeMod.lac	Thu Jul 24 15:49:30 2008 -0400
+++ b/tests/datatypeMod.lac	Thu Jul 24 16:36:41 2008 -0400
@@ -2,7 +2,15 @@
         datatype t = A | B
 end
 
+val a = M.A
+
 datatype u = datatype M.t
 
 val a : M.t = A
 val a2 : u = a
+
+structure M2 = M
+structure M3 : sig datatype t = datatype M.t end = M2
+structure M4 : sig datatype t = datatype M.t end = M
+
+val b : M3.t = M4.B