diff src/elab_util.sml @ 623:588b9d16b00a

Start of kind polymorphism, up to the point where demo/hello elaborates with updated Basis/Top
author Adam Chlipala <adamc@hcoop.net>
date Sun, 22 Feb 2009 16:10:25 -0500
parents 8998114760c1
children 12b73f3c108e
line wrap: on
line diff
--- a/src/elab_util.sml	Sat Feb 21 16:11:56 2009 -0500
+++ b/src/elab_util.sml	Sun Feb 22 16:10:25 2009 -0500
@@ -43,44 +43,60 @@
 
 structure Kind = struct
 
-fun mapfold f =
+fun mapfoldB {kind, bind} =
     let
-        fun mfk k acc =
-            S.bindP (mfk' k acc, f)
+        fun mfk ctx k acc =
+            S.bindP (mfk' ctx k acc, kind ctx)
 
-        and mfk' (kAll as (k, loc)) =
+        and mfk' ctx (kAll as (k, loc)) =
             case k of
                 KType => S.return2 kAll
 
               | KArrow (k1, k2) =>
-                S.bind2 (mfk k1,
+                S.bind2 (mfk ctx k1,
                       fn k1' =>
-                         S.map2 (mfk k2,
+                         S.map2 (mfk ctx k2,
                               fn k2' =>
                                  (KArrow (k1', k2'), loc)))
 
               | KName => S.return2 kAll
 
               | KRecord k =>
-                S.map2 (mfk k,
+                S.map2 (mfk ctx k,
                         fn k' =>
                            (KRecord k', loc))
 
               | KUnit => S.return2 kAll
 
               | KTuple ks =>
-                S.map2 (ListUtil.mapfold mfk ks,
+                S.map2 (ListUtil.mapfold (mfk ctx) ks,
                         fn ks' =>
                            (KTuple ks', loc))
 
               | KError => S.return2 kAll
 
-              | KUnif (_, _, ref (SOME k)) => mfk' k
+              | KUnif (_, _, ref (SOME k)) => mfk' ctx k
               | KUnif _ => S.return2 kAll
+
+              | KRel _ => S.return2 kAll
+              | KFun (x, k) =>
+                S.map2 (mfk (bind (ctx, x)) k,
+                        fn k' =>
+                           (KFun (x, k'), loc))
     in
         mfk
     end
 
+fun mapfold fk =
+    mapfoldB {kind = fn () => fk,
+              bind = fn ((), _) => ()} ()
+
+fun mapB {kind, bind} ctx k =
+    case mapfoldB {kind = fn ctx => fn k => fn () => S.Continue (kind ctx k, ()),
+                   bind = bind} ctx k () of
+        S.Continue (k, ()) => k
+      | S.Return _ => raise Fail "ElabUtil.Kind.mapB: Impossible"
+
 fun exists f k =
     case mapfold (fn k => fn () =>
                              if f k then
@@ -95,12 +111,13 @@
 structure Con = struct
 
 datatype binder =
-         Rel of string * Elab.kind
-       | Named of string * int * Elab.kind
+         RelK of string
+       | RelC of string * Elab.kind
+       | NamedC of string * int * Elab.kind
 
 fun mapfoldB {kind = fk, con = fc, bind} =
     let
-        val mfk = Kind.mapfold fk
+        val mfk = Kind.mapfoldB {kind = fk, bind = fn (ctx, s) => bind (ctx, RelK s)}
 
         fun mfc ctx c acc =
             S.bindP (mfc' ctx c acc, fc ctx)
@@ -114,9 +131,9 @@
                               fn c2' =>
                                  (TFun (c1', c2'), loc)))
               | TCFun (e, x, k, c) =>
-                S.bind2 (mfk k,
+                S.bind2 (mfk ctx k,
                       fn k' =>
-                         S.map2 (mfc (bind (ctx, Rel (x, k))) c,
+                         S.map2 (mfc (bind (ctx, RelC (x, k))) c,
                               fn c' =>
                                  (TCFun (e, x, k', c'), loc)))
               | CDisjoint (ai, c1, c2, c3) =>
@@ -142,16 +159,16 @@
                               fn c2' =>
                                  (CApp (c1', c2'), loc)))
               | CAbs (x, k, c) =>
-                S.bind2 (mfk k,
+                S.bind2 (mfk ctx k,
                       fn k' =>
-                         S.map2 (mfc (bind (ctx, Rel (x, k))) c,
+                         S.map2 (mfc (bind (ctx, RelC (x, k))) c,
                               fn c' =>
                                  (CAbs (x, k', c'), loc)))
 
               | CName _ => S.return2 cAll
 
               | CRecord (k, xcs) =>
-                S.bind2 (mfk k,
+                S.bind2 (mfk ctx k,
                       fn k' =>
                          S.map2 (ListUtil.mapfold (fn (x, c) =>
                                                       S.bind2 (mfc ctx x,
@@ -169,9 +186,9 @@
                               fn c2' =>
                                  (CConcat (c1', c2'), loc)))
               | CMap (k1, k2) =>
-                S.bind2 (mfk k1,
+                S.bind2 (mfk ctx k1,
                          fn k1' =>
-                            S.map2 (mfk k2,
+                            S.map2 (mfk ctx k2,
                                     fn k2' =>
                                        (CMap (k1', k2'), loc)))
 
@@ -190,17 +207,32 @@
               | CError => S.return2 cAll
               | CUnif (_, _, _, ref (SOME c)) => mfc' ctx c
               | CUnif _ => S.return2 cAll
+
+              | CKAbs (x, c) =>
+                S.map2 (mfc (bind (ctx, RelK x)) c,
+                        fn c' =>
+                           (CKAbs (x, c'), loc))
+              | CKApp (c, k) =>
+                S.bind2 (mfc ctx c,
+                      fn c' =>
+                         S.map2 (mfk ctx k,
+                                 fn k' =>
+                                    (CKApp (c', k'), loc)))
+              | TKFun (x, c) =>
+                S.map2 (mfc (bind (ctx, RelK x)) c,
+                        fn c' =>
+                           (TKFun (x, c'), loc))
     in
         mfc
     end
 
 fun mapfold {kind = fk, con = fc} =
-    mapfoldB {kind = fk,
+    mapfoldB {kind = fn () => fk,
               con = fn () => fc,
               bind = fn ((), _) => ()} ()
 
 fun mapB {kind, con, bind} ctx c =
-    case mapfoldB {kind = fn k => fn () => S.Continue (kind k, ()),
+    case mapfoldB {kind = fn ctx => fn k => fn () => S.Continue (kind ctx k, ()),
                    con = fn ctx => fn c => fn () => S.Continue (con ctx c, ()),
                    bind = bind} ctx c () of
         S.Continue (c, ()) => c
@@ -227,7 +259,7 @@
       | S.Continue _ => false
 
 fun foldB {kind, con, bind} ctx st c =
-    case mapfoldB {kind = fn k => fn st => S.Continue (k, kind (k, st)),
+    case mapfoldB {kind = fn ctx => fn k => fn st => S.Continue (k, kind (ctx, k, st)),
                    con = fn ctx => fn c => fn st => S.Continue (c, con (ctx, c, st)),
                    bind = bind} ctx c st of
         S.Continue (_, st) => st
@@ -238,20 +270,22 @@
 structure Exp = struct
 
 datatype binder =
-         RelC of string * Elab.kind
+         RelK of string
+       | RelC of string * Elab.kind
        | NamedC of string * int * Elab.kind
        | RelE of string * Elab.con
        | NamedE of string * Elab.con
 
 fun mapfoldB {kind = fk, con = fc, exp = fe, bind} =
     let
-        val mfk = Kind.mapfold fk
+        val mfk = Kind.mapfoldB {kind = fk, bind = fn (ctx, x) => bind (ctx, RelK x)}
 
         fun bind' (ctx, b) =
             let
                 val b' = case b of
-                             Con.Rel x => RelC x
-                           | Con.Named x => NamedC x
+                             Con.RelK x => RelK x
+                           | Con.RelC x => RelC x
+                           | Con.NamedC x => NamedC x
             in
                 bind (ctx, b')
             end
@@ -288,7 +322,7 @@
                               fn c' =>
                                  (ECApp (e', c'), loc)))
               | ECAbs (expl, x, k, e) =>
-                S.bind2 (mfk k,
+                S.bind2 (mfk ctx k,
                       fn k' =>
                          S.map2 (mfe (bind (ctx, RelC (x, k))) e,
                               fn e' =>
@@ -347,11 +381,6 @@
                                       fn rest' =>
                                          (ECutMulti (e', c', {rest = rest'}), loc))))
 
-              | EFold k =>
-                S.map2 (mfk k,
-                         fn k' =>
-                            (EFold k', loc))
-
               | ECase (e, pes, {disc, result}) =>
                 S.bind2 (mfe ctx e,
                          fn e' =>
@@ -406,6 +435,17 @@
                                        (ELet (des', e'), loc)))
                 end
 
+              | EKAbs (x, e) =>
+                S.map2 (mfe (bind (ctx, RelK x)) e,
+                        fn e' =>
+                           (EKAbs (x, e'), loc))
+              | EKApp (e, k) =>
+                S.bind2 (mfe ctx e,
+                        fn e' =>
+                           S.map2 (mfk ctx k,
+                                   fn k' =>
+                                      (EKApp (e', k'), loc)))
+
         and mfed ctx (dAll as (d, loc)) =
             case d of
                 EDVal vi =>
@@ -432,7 +472,7 @@
     end
 
 fun mapfold {kind = fk, con = fc, exp = fe} =
-    mapfoldB {kind = fk,
+    mapfoldB {kind = fn () => fk,
               con = fn () => fc,
               exp = fn () => fe,
               bind = fn ((), _) => ()} ()
@@ -457,7 +497,7 @@
       | S.Continue _ => false
 
 fun mapB {kind, con, exp, bind} ctx e =
-    case mapfoldB {kind = fn k => fn () => S.Continue (kind k, ()),
+    case mapfoldB {kind = fn ctx => fn k => fn () => S.Continue (kind ctx k, ()),
                    con = fn ctx => fn c => fn () => S.Continue (con ctx c, ()),
                    exp = fn ctx => fn e => fn () => S.Continue (exp ctx e, ()),
                    bind = bind} ctx e () of
@@ -465,7 +505,7 @@
       | S.Return _ => raise Fail "ElabUtil.Exp.mapB: Impossible"
 
 fun foldB {kind, con, exp, bind} ctx st e =
-    case mapfoldB {kind = fn k => fn st => S.Continue (k, kind (k, st)),
+    case mapfoldB {kind = fn ctx => fn k => fn st => S.Continue (k, kind (ctx, k, st)),
                    con = fn ctx => fn c => fn st => S.Continue (c, con (ctx, c, st)),
                    exp = fn ctx => fn e => fn st => S.Continue (e, exp (ctx, e, st)),
                    bind = bind} ctx e st of
@@ -477,7 +517,8 @@
 structure Sgn = struct
 
 datatype binder =
-         RelC of string * Elab.kind
+         RelK of string
+       | RelC of string * Elab.kind
        | NamedC of string * int * Elab.kind
        | Str of string * Elab.sgn
        | Sgn of string * Elab.sgn
@@ -487,14 +528,15 @@
         fun bind' (ctx, b) =
             let
                 val b' = case b of
-                             Con.Rel x => RelC x
-                           | Con.Named x => NamedC x
+                             Con.RelK x => RelK x
+                           | Con.RelC x => RelC x
+                           | Con.NamedC x => NamedC x
             in
                 bind (ctx, b')
             end
         val con = Con.mapfoldB {kind = kind, con = con, bind = bind'}
 
-        val kind = Kind.mapfold kind
+        val kind = Kind.mapfoldB {kind = kind, bind = fn (ctx, x) => bind (ctx, RelK x)}
 
         fun sgi ctx si acc =
             S.bindP (sgi' ctx si acc, sgn_item ctx)
@@ -502,11 +544,11 @@
         and sgi' ctx (siAll as (si, loc)) =
             case si of
                 SgiConAbs (x, n, k) =>
-                S.map2 (kind k,
+                S.map2 (kind ctx k,
                      fn k' =>
                         (SgiConAbs (x, n, k'), loc))
               | SgiCon (x, n, k, c) =>
-                S.bind2 (kind k,
+                S.bind2 (kind ctx k,
                      fn k' =>
                         S.map2 (con ctx c,
                              fn c' =>
@@ -548,11 +590,11 @@
                                     fn c2' =>
                                        (SgiConstraint (c1', c2'), loc)))
               | SgiClassAbs (x, n, k) =>
-                S.map2 (kind k,
+                S.map2 (kind ctx k,
                         fn k' =>
                            (SgiClassAbs (x, n, k'), loc))
               | SgiClass (x, n, k, c) =>
-                S.bind2 (kind k,
+                S.bind2 (kind ctx k,
                       fn k' => 
                          S.map2 (con ctx c,
                               fn c' =>
@@ -608,7 +650,7 @@
     end
 
 fun mapfold {kind, con, sgn_item, sgn} =
-    mapfoldB {kind = kind,
+    mapfoldB {kind = fn () => kind,
               con = fn () => con,
               sgn_item = fn () => sgn_item,
               sgn = fn () => sgn,
@@ -627,7 +669,8 @@
 structure Decl = struct
 
 datatype binder =
-         RelC of string * Elab.kind
+         RelK of string
+       | RelC of string * Elab.kind
        | NamedC of string * int * Elab.kind
        | RelE of string * Elab.con
        | NamedE of string * Elab.con
@@ -636,13 +679,14 @@
 
 fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = fst, decl = fd, bind} =
     let
-        val mfk = Kind.mapfold fk
+        val mfk = Kind.mapfoldB {kind = fk, bind = fn (ctx, x) => bind (ctx, RelK x)}
 
         fun bind' (ctx, b) =
             let
                 val b' = case b of
-                             Con.Rel x => RelC x
-                           | Con.Named x => NamedC x
+                             Con.RelK x => RelK x
+                           | Con.RelC x => RelC x
+                           | Con.NamedC x => NamedC x
             in
                 bind (ctx, b')
             end
@@ -651,7 +695,8 @@
         fun bind' (ctx, b) =
             let
                 val b' = case b of
-                             Exp.RelC x => RelC x
+                             Exp.RelK x => RelK x
+                           | Exp.RelC x => RelC x
                            | Exp.NamedC x => NamedC x
                            | Exp.RelE x => RelE x
                            | Exp.NamedE x => NamedE x
@@ -663,7 +708,8 @@
         fun bind' (ctx, b) =
             let
                 val b' = case b of
-                             Sgn.RelC x => RelC x
+                             Sgn.RelK x => RelK x
+                           | Sgn.RelC x => RelC x
                            | Sgn.NamedC x => NamedC x
                            | Sgn.Sgn x => Sgn x
                            | Sgn.Str x => Str x
@@ -760,7 +806,7 @@
         and mfd' ctx (dAll as (d, loc)) =
             case d of
                 DCon (x, n, k, c) =>
-                S.bind2 (mfk k,
+                S.bind2 (mfk ctx k,
                          fn k' =>
                             S.map2 (mfc ctx c,
                                     fn c' =>
@@ -825,7 +871,7 @@
               | DSequence _ => S.return2 dAll
 
               | DClass (x, n, k, c) =>
-                S.bind2 (mfk k,
+                S.bind2 (mfk ctx k,
                          fn k' =>
                             S.map2 (mfc ctx c,
                                  fn c' =>
@@ -849,7 +895,7 @@
     end
 
 fun mapfold {kind, con, exp, sgn_item, sgn, str, decl} =
-    mapfoldB {kind = kind,
+    mapfoldB {kind = fn () => kind,
               con = fn () => con,
               exp = fn () => exp,
               sgn_item = fn () => sgn_item,
@@ -938,7 +984,7 @@
       | S.Continue _ => NONE
 
 fun foldMapB {kind, con, exp, sgn_item, sgn, str, decl, bind} ctx st d =
-    case mapfoldB {kind = fn x => fn st => S.Continue (kind (x, st)),
+    case mapfoldB {kind = fn ctx => fn x => fn st => S.Continue (kind (ctx, x, st)),
                    con = fn ctx => fn x => fn st => S.Continue (con (ctx, x, st)),
                    exp = fn ctx => fn x => fn st => S.Continue (exp (ctx, x, st)),
                    sgn_item = fn ctx => fn x => fn st => S.Continue (sgn_item (ctx, x, st)),