diff src/elaborate.sml @ 88:7bab29834cd6

Constraints in modules
author Adam Chlipala <adamc@hcoop.net>
date Tue, 01 Jul 2008 15:58:02 -0400
parents 7f9bcc8bfa1e
children 94ef20a31550
line wrap: on
line diff
--- a/src/elaborate.sml	Tue Jul 01 13:23:46 2008 -0400
+++ b/src/elaborate.sml	Tue Jul 01 15:58:02 2008 -0400
@@ -715,7 +715,9 @@
         fun isRecord () = unifyRecordCons (env, denv) (c1All, c2All)
     in
         case (c1, c2) of
-            (L'.TFun (d1, r1), L'.TFun (d2, r2)) =>
+            (L'.CUnit, L'.CUnit) => []
+
+          | (L'.TFun (d1, r1), L'.TFun (d2, r2)) =>
             unifyCons' (env, denv) d1 d2
             @ unifyCons' (env, denv) r1 r2
           | (L'.TCFun (expl1, x1, d1, r1), L'.TCFun (expl2, _, d2, r2)) =>
@@ -1137,6 +1139,7 @@
        | DuplicateVal of ErrorMsg.span * string
        | DuplicateSgn of ErrorMsg.span * string
        | DuplicateStr of ErrorMsg.span * string
+       | NotConstraintsable of L'.sgn
 
 fun sgnError env err =
     case err of
@@ -1183,6 +1186,9 @@
         ErrorMsg.errorAt loc ("Duplicate signature " ^ s ^ " in signature")
       | DuplicateStr (loc, s) =>
         ErrorMsg.errorAt loc ("Duplicate structure " ^ s ^ " in signature")
+      | NotConstraintsable sgn =>
+        (ErrorMsg.errorAt (#2 sgn) "Invalid signature for 'open constraints'";
+         eprefaces' [("Signature", p_sgn env sgn)])
 
 datatype str_error =
          UnboundStr of ErrorMsg.span * string
@@ -1212,7 +1218,7 @@
 
 val hnormSgn = E.hnormSgn
 
-fun elabSgn_item denv ((sgi, loc), (env, gs)) =
+fun elabSgn_item ((sgi, loc), (env, denv, gs)) =
     case sgi of
         L.SgiConAbs (x, k) =>
         let
@@ -1220,7 +1226,7 @@
 
             val (env', n) = E.pushCNamed env x k' NONE
         in
-            ([(L'.SgiConAbs (x, n, k'), loc)], (env', gs))
+            ([(L'.SgiConAbs (x, n, k'), loc)], (env', denv, gs))
         end
 
       | L.SgiCon (x, ko, c) =>
@@ -1234,7 +1240,7 @@
         in
             checkKind env c' ck k';
 
-            ([(L'.SgiCon (x, n, k', c'), loc)], (env', gs' @ gs))
+            ([(L'.SgiCon (x, n, k', c'), loc)], (env', denv, gs' @ gs))
         end
 
       | L.SgiVal (x, c) =>
@@ -1246,7 +1252,7 @@
             (unifyKinds ck ktype
              handle KUnify ue => strError env (NotType (ck, ue)));
 
-            ([(L'.SgiVal (x, n, c'), loc)], (env', gs' @ gs))
+            ([(L'.SgiVal (x, n, c'), loc)], (env', denv, gs' @ gs))
         end
 
       | L.SgiStr (x, sgn) =>
@@ -1254,7 +1260,7 @@
             val (sgn', gs') = elabSgn (env, denv) sgn
             val (env', n) = E.pushStrNamed env x sgn'
         in
-            ([(L'.SgiStr (x, n, sgn'), loc)], (env', gs' @ gs))
+            ([(L'.SgiStr (x, n, sgn'), loc)], (env', denv, gs' @ gs))
         end
 
       | L.SgiSgn (x, sgn) =>
@@ -1262,7 +1268,7 @@
             val (sgn', gs') = elabSgn (env, denv) sgn
             val (env', n) = E.pushSgnNamed env x sgn'
         in
-            ([(L'.SgiSgn (x, n, sgn'), loc)], (env', gs' @ gs))
+            ([(L'.SgiSgn (x, n, sgn'), loc)], (env', denv, gs' @ gs))
         end
 
       | L.SgiInclude sgn =>
@@ -1271,16 +1277,29 @@
         in
             case #1 (hnormSgn env sgn') of
                 L'.SgnConst sgis =>
-                (sgis, (foldl (fn (sgi, env) => E.sgiBinds env sgi) env sgis, gs' @ gs))
+                (sgis, (foldl (fn (sgi, env) => E.sgiBinds env sgi) env sgis, denv, gs' @ gs))
               | _ => (sgnError env (NotIncludable sgn');
-                      ([], (env, [])))
+                      ([], (env, denv, [])))
+        end
+
+      | L.SgiConstraint (c1, c2) =>
+        let
+            val (c1', k1, gs1) = elabCon (env, denv) c1
+            val (c2', k2, gs2) = elabCon (env, denv) c2
+
+            val denv = D.assert env denv (c1', c2')
+        in
+            checkKind env c1' k1 (L'.KRecord (kunif loc), loc);
+            checkKind env c2' k2 (L'.KRecord (kunif loc), loc);
+
+            ([(L'.SgiConstraint (c1', c2'), loc)], (env, denv, gs1 @ gs2))
         end
 
 and elabSgn (env, denv) (sgn, loc) =
     case sgn of
         L.SgnConst sgis =>
         let
-            val (sgis', (_, gs)) = ListUtil.foldlMapConcat (elabSgn_item denv) (env, []) sgis
+            val (sgis', (_, _, gs)) = ListUtil.foldlMapConcat elabSgn_item (env, denv, []) sgis
 
             val _ = foldl (fn ((sgi, loc), (cons, vals, sgns, strs)) =>
                               case sgi of
@@ -1313,7 +1332,8 @@
                                        sgnError env (DuplicateStr (loc, x))
                                    else
                                        ();
-                                   (cons, vals, sgns, SS.add (strs, x))))
+                                   (cons, vals, sgns, SS.add (strs, x)))
+                                | L'.SgiConstraint _ => (cons, vals, sgns, strs))
                     (SS.empty, SS.empty, SS.empty, SS.empty) sgis'
         in
             ((L'.SgnConst sgis', loc), gs)
@@ -1410,35 +1430,65 @@
           | SOME (str, strs) => selfify env {sgn = sgn, str = str, strs = strs}
     end
 
-fun dopen env {str, strs, sgn} =
+fun dopen (env, denv) {str, strs, sgn} =
     let
         val m = foldl (fn (m, str) => (L'.StrProj (str, m), #2 sgn))
                 (L'.StrVar str, #2 sgn) strs
     in
         case #1 (hnormSgn env sgn) of
             L'.SgnConst sgis =>
-            ListUtil.foldlMap (fn ((sgi, loc), env') =>
+            ListUtil.foldlMap (fn ((sgi, loc), (env', denv')) =>
                                   case sgi of
                                       L'.SgiConAbs (x, n, k) =>
-                                      ((L'.DCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc),
-                                       E.pushCNamedAs env' x n k NONE)
+                                      let
+                                          val c = (L'.CModProj (str, strs, x), loc)
+                                      in
+                                          ((L'.DCon (x, n, k, c), loc),
+                                           (E.pushCNamedAs env' x n k (SOME c), denv'))
+                                      end
                                     | L'.SgiCon (x, n, k, c) =>
                                       ((L'.DCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc),
-                                       E.pushCNamedAs env' x n k (SOME c))
+                                       (E.pushCNamedAs env' x n k (SOME c), denv'))
                                     | L'.SgiVal (x, n, t) =>
                                       ((L'.DVal (x, n, t, (L'.EModProj (str, strs, x), loc)), loc),
-                                       E.pushENamedAs env' x n t)
+                                       (E.pushENamedAs env' x n t, denv'))
                                     | L'.SgiStr (x, n, sgn) =>
                                       ((L'.DStr (x, n, sgn, (L'.StrProj (m, x), loc)), loc),
-                                       E.pushStrNamedAs env' x n sgn)
+                                       (E.pushStrNamedAs env' x n sgn, denv'))
                                     | L'.SgiSgn (x, n, sgn) =>
                                       ((L'.DSgn (x, n, (L'.SgnProj (str, strs, x), loc)), loc),
-                                       E.pushSgnNamedAs env' x n sgn))
-                              env sgis
+                                       (E.pushSgnNamedAs env' x n sgn, denv'))
+                                    | L'.SgiConstraint (c1, c2) =>
+                                      ((L'.DConstraint (c1, c2), loc),
+                                       (env', denv (* D.assert env denv (c1, c2) *) )))
+                              (env, denv) sgis
           | _ => (strError env (UnOpenable sgn);
-                  ([], env))
+                  ([], (env, denv)))
     end
 
+fun dopenConstraints (loc, env, denv) {str, strs} =
+    case E.lookupStr env str of
+        NONE => (strError env (UnboundStr (loc, str));
+                 denv)
+      | SOME (n, sgn) =>
+        let
+            val (st, sgn) = foldl (fn (m, (str, sgn)) =>
+                                      case E.projectStr env {str = str, sgn = sgn, field = m} of
+                                          NONE => (strError env (UnboundStr (loc, m));
+                                                   (strerror, sgnerror))
+                                        | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
+                                  ((L'.StrVar n, loc), sgn) strs
+                            
+            val cso = E.projectConstraints env {sgn = sgn, str = st}
+
+            val denv = case cso of
+                           NONE => (strError env (UnboundStr (loc, str));
+                                    denv)
+                         | SOME cs => foldl (fn ((c1, c2), denv) => D.assert env denv (c1, c2)) denv cs
+        in
+            denv
+        end
+
 fun sgiOfDecl (d, loc) =
     case d of
         L'.DCon (x, n, k, c) => (L'.SgiCon (x, n, k, c), loc)
@@ -1446,6 +1496,12 @@
       | L'.DSgn (x, n, sgn) => (L'.SgiSgn (x, n, sgn), loc)
       | L'.DStr (x, n, sgn, _) => (L'.SgiStr (x, n, sgn), loc)
       | L'.DFfiStr (x, n, sgn) => (L'.SgiStr (x, n, sgn), loc)
+      | L'.DConstraint cs => (L'.SgiConstraint cs, loc)
+
+fun sgiBindsD (env, denv) (sgi, _) =
+    case sgi of
+        L'.SgiConstraint (c1, c2) => D.assert env denv (c1, c2)
+      | _ => denv
 
 fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) =
     case (#1 (hnormSgn env sgn1), #1 (hnormSgn env sgn2)) of
@@ -1454,20 +1510,20 @@
 
       | (L'.SgnConst sgis1, L'.SgnConst sgis2) =>
         let
-            fun folder (sgi2All as (sgi, _), env) =
+            fun folder (sgi2All as (sgi, _), (env, denv)) =
                 let
                     fun seek p =
                         let
-                            fun seek env ls =
+                            fun seek (env, denv) ls =
                                 case ls of
                                     [] => (sgnError env (UnmatchedSgi sgi2All);
-                                           env)
+                                           (env, denv))
                                   | h :: t =>
                                     case p h of
-                                        NONE => seek (E.sgiBinds env h) t
-                                      | SOME env => env
+                                        NONE => seek (E.sgiBinds env h, sgiBindsD (env, denv) h) t
+                                      | SOME envs => envs
                         in
-                            seek env sgis1
+                            seek (env, denv) sgis1
                         end
                 in
                     case sgi of
@@ -1485,7 +1541,8 @@
                                                  SOME (if n1 = n2 then
                                                            env
                                                        else
-                                                           E.pushCNamedAs env x n2 k2 (SOME (L'.CNamed n1, loc2)))
+                                                           E.pushCNamedAs env x n2 k2 (SOME (L'.CNamed n1, loc2)),
+                                                       denv)
                                              end
                                          else
                                              NONE
@@ -1502,7 +1559,7 @@
                                      L'.SgiCon (x', n1, k1, c1) =>
                                      if x = x' then
                                          let
-                                             fun good () = SOME (E.pushCNamedAs env x n2 k2 (SOME c2))
+                                             fun good () = SOME (E.pushCNamedAs env x n2 k2 (SOME c2), denv)
                                          in
                                              (case unifyCons (env, denv) c1 c2 of
                                                   [] => good ()
@@ -1521,11 +1578,11 @@
                                      L'.SgiVal (x', n1, c1) =>
                                      if x = x' then
                                          (case unifyCons (env, denv) c1 c2 of
-                                              [] => SOME env
+                                              [] => SOME (env, denv)
                                             | _ => NONE)
                                          handle CUnify (c1, c2, err) =>
                                                 (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err));
-                                                 SOME env)
+                                                 SOME (env, denv))
                                      else
                                          NONE
                                    | _ => NONE)
@@ -1545,7 +1602,7 @@
                                                                             (selfifyAt env {str = (L'.StrVar n1, #2 sgn2),
                                                                                             sgn = sgn2})
                                          in
-                                             SOME env
+                                             SOME (env, denv)
                                          end
                                      else
                                          NONE
@@ -1566,14 +1623,24 @@
                                                        else
                                                            E.pushSgnNamedAs env x n1 sgn2
                                          in
-                                             SOME env
+                                             SOME (env, denv)
                                          end
                                      else
                                          NONE
                                    | _ => NONE)
+
+                      | L'.SgiConstraint (c2, d2) =>
+                        seek (fn sgi1All as (sgi1, _) =>
+                                 case sgi1 of
+                                     L'.SgiConstraint (c1, d1) =>
+                                     if consEq (env, denv) (c1, c2) andalso consEq (env, denv) (d1, d2) then
+                                         SOME (env, D.assert env denv (c2, d2))
+                                     else
+                                         NONE
+                                   | _ => NONE)
                 end
         in
-            ignore (foldl folder env sgis2)
+            ignore (foldl folder (env, denv) sgis2)
         end
 
       | (L'.SgnFun (m1, n1, dom1, ran1), L'.SgnFun (m2, n2, dom2, ran2)) =>
@@ -1591,7 +1658,7 @@
       | _ => sgnError env (SgnWrongForm (sgn1, sgn2))
 
 
-fun elabDecl denv ((d, loc), (env, gs)) =
+fun elabDecl ((d, loc), (env, denv, gs)) =
     case d of
         L.DCon (x, ko, c) =>
         let
@@ -1604,7 +1671,7 @@
         in
             checkKind env c' ck k';
 
-            ([(L'.DCon (x, n, k', c'), loc)], (env', gs' @ gs))
+            ([(L'.DCon (x, n, k', c'), loc)], (env', denv, gs' @ gs))
         end
       | L.DVal (x, co, e) =>
         let
@@ -1617,7 +1684,7 @@
 
             val gs3 = checkCon (env, denv) e' et c'
         in
-            ([(L'.DVal (x, n, c', e'), loc)], (env', gs1 @ gs2 @ gs3 @ gs))
+            ([(L'.DVal (x, n, c', e'), loc)], (env', denv, gs1 @ gs2 @ gs3 @ gs))
         end
 
       | L.DSgn (x, sgn) =>
@@ -1625,7 +1692,7 @@
             val (sgn', gs') = elabSgn (env, denv) sgn
             val (env', n) = E.pushSgnNamed env x sgn'
         in
-            ([(L'.DSgn (x, n, sgn'), loc)], (env', gs' @ gs))
+            ([(L'.DSgn (x, n, sgn'), loc)], (env', denv, gs' @ gs))
         end
 
       | L.DStr (x, sgno, str) =>
@@ -1691,7 +1758,7 @@
                    | _ => strError env (FunctorRebind loc))
               | _ => ();
 
-            ([(L'.DStr (x, n, sgn', str'), loc)], (env', gs' @ gs))
+            ([(L'.DStr (x, n, sgn', str'), loc)], (env', denv, gs' @ gs))
         end
 
       | L.DFfiStr (x, sgn) =>
@@ -1700,32 +1767,54 @@
 
             val (env', n) = E.pushStrNamed env x sgn'
         in
-            ([(L'.DFfiStr (x, n, sgn'), loc)], (env', gs' @ gs))
+            ([(L'.DFfiStr (x, n, sgn'), loc)], (env', denv, gs' @ gs))
         end
 
       | L.DOpen (m, ms) =>
-        case E.lookupStr env m of
-            NONE => (strError env (UnboundStr (loc, m));
-                     ([], (env, [])))
-          | SOME (n, sgn) =>
-            let
-                val (_, sgn) = foldl (fn (m, (str, sgn)) =>
-                                         case E.projectStr env {str = str, sgn = sgn, field = m} of
-                                             NONE => (strError env (UnboundStr (loc, m));
-                                                      (strerror, sgnerror))
-                                           | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
-                                     ((L'.StrVar n, loc), sgn) ms
+        (case E.lookupStr env m of
+             NONE => (strError env (UnboundStr (loc, m));
+                      ([], (env, denv, [])))
+           | SOME (n, sgn) =>
+             let
+                 val (_, sgn) = foldl (fn (m, (str, sgn)) =>
+                                          case E.projectStr env {str = str, sgn = sgn, field = m} of
+                                              NONE => (strError env (UnboundStr (loc, m));
+                                                       (strerror, sgnerror))
+                                            | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
+                                      ((L'.StrVar n, loc), sgn) ms
 
-                val (ds, env') = dopen env {str = n, strs = ms, sgn = sgn}
-            in
-                (ds, (env', []))
-            end
+                 val (ds, (env', denv')) = dopen (env, denv) {str = n, strs = ms, sgn = sgn}
+                 val denv' = dopenConstraints (loc, env', denv') {str = m, strs = ms}
+             in
+                 (ds, (env', denv', []))
+             end)
+
+      | L.DConstraint (c1, c2) =>
+        let
+            val (c1', k1, gs1) = elabCon (env, denv) c1
+            val (c2', k2, gs2) = elabCon (env, denv) c2
+            val gs3 = map (fn cs => (loc, env, denv, cs)) (D.prove env denv (c1', c2', loc))
+
+            val denv' = D.assert env denv (c1', c2')
+        in
+            checkKind env c1' k1 (L'.KRecord (kunif loc), loc);
+            checkKind env c2' k2 (L'.KRecord (kunif loc), loc);
+
+            ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3))
+        end
+
+      | L.DOpenConstraints (m, ms) =>
+        let
+            val denv = dopenConstraints (loc, env, denv) {str = m, strs = ms}
+        in
+            ([], (env, denv, []))
+        end
 
 and elabStr (env, denv) (str, loc) =
     case str of
         L.StrConst ds =>
         let
-            val (ds', (env', gs)) = ListUtil.foldlMapConcat (elabDecl denv) (env, []) ds
+            val (ds', (_, _, gs)) = ListUtil.foldlMapConcat elabDecl (env, denv, []) ds
             val sgis = map sgiOfDecl ds'
 
             val (sgis, _, _, _, _) =
@@ -1781,7 +1870,8 @@
                                           (SS.add (strs, x), x)
                               in
                                   ((L'.SgiStr (x, n, sgn), loc) :: sgis, cons, vals, sgns, strs)
-                              end)
+                              end
+                            | L'.SgiConstraint _ => ((sgi, loc) :: sgis, cons, vals, sgns, strs))
 
                 ([], SS.empty, SS.empty, SS.empty, SS.empty) sgis
         in
@@ -1852,7 +1942,7 @@
 
         val (env', basis_n) = E.pushStrNamed env "Basis" sgn
 
-        val (ds, env') = dopen env' {str = basis_n, strs = [], sgn = sgn}
+        val (ds, (env', _)) = dopen (env', D.empty) {str = basis_n, strs = [], sgn = sgn}
 
         fun discoverC r x =
             case E.lookupC env' x of
@@ -1868,7 +1958,7 @@
             let
                 val () = resetKunif ()
                 val () = resetCunif ()
-                val (ds, (env, gs)) = elabDecl D.empty (d, (env, gs))
+                val (ds, (env, _, gs)) = elabDecl (d, (env, D.empty, gs))
             in
                 if ErrorMsg.anyErrors () then
                     ()