diff src/elab_env.sml @ 211:e86411f647c6

Initial type class support
author Adam Chlipala <adamc@hcoop.net>
date Sat, 16 Aug 2008 14:32:18 -0400
parents cb8f69556975
children 0343557355fc
line wrap: on
line diff
--- a/src/elab_env.sml	Sat Aug 16 12:35:46 2008 -0400
+++ b/src/elab_env.sml	Sat Aug 16 14:32:18 2008 -0400
@@ -31,6 +31,7 @@
 
 structure U = ElabUtil
 
+structure IS = IntBinarySet
 structure IM = IntBinaryMap
 structure SM = BinaryMapFn(struct
                            type ord_key = string
@@ -61,6 +62,22 @@
 
 val lift = liftConInCon 0
 
+val liftExpInExp =
+    U.Exp.mapB {kind = fn k => k,
+                con = fn _ => fn c => c,
+                exp = fn bound => fn e =>
+                                     case e of
+                                         ERel xn =>
+                                         if xn < bound then
+                                             e
+                                         else
+                                             ERel (xn + 1)
+                                       | _ => e,
+                bind = fn (bound, U.Exp.RelE _) => bound + 1
+                        | (bound, _) => bound}
+
+
+val liftExp = liftExpInExp 0
 
 (* Back to environments *)
 
@@ -75,6 +92,61 @@
 
 type datatyp = string list * (string * con option) IM.map
 
+datatype class_name =
+         ClNamed of int
+       | ClProj of int * string list * string
+
+structure CK = struct
+type ord_key = class_name
+open Order
+fun compare x =
+    case x of
+        (ClNamed n1, ClNamed n2) => Int.compare (n1, n2)
+      | (ClNamed _, _) => LESS
+      | (_, ClNamed _) => GREATER
+
+      | (ClProj (m1, ms1, x1), ClProj (m2, ms2, x2)) =>
+        join (Int.compare (m1, m2),
+              fn () => join (joinL String.compare (ms1, ms2),
+                             fn () => String.compare (x1, x2)))
+end
+
+structure CM = BinaryMapFn(CK)
+
+datatype class_key =
+         CkNamed of int
+       | CkRel of int
+       | CkProj of int * string list * string
+
+structure KK = struct
+type ord_key = class_key
+open Order
+fun compare x =
+    case x of
+        (CkNamed n1, CkNamed n2) => Int.compare (n1, n2)
+      | (CkNamed _, _) => LESS
+      | (_, CkNamed _) => GREATER
+
+      | (CkRel n1, CkRel n2) => Int.compare (n1, n2)
+      | (CkRel _, _) => LESS
+      | (_, CkRel _) => GREATER
+
+      | (CkProj (m1, ms1, x1), CkProj (m2, ms2, x2)) =>
+        join (Int.compare (m1, m2),
+              fn () => join (joinL String.compare (ms1, ms2),
+                             fn () => String.compare (x1, x2)))
+end
+
+structure KM = BinaryMapFn(KK)
+
+type class = {
+     ground : exp KM.map
+}
+
+val empty_class = {
+    ground = KM.empty
+}
+
 type env = {
      renameC : kind var' SM.map,
      relC : (string * kind) list,
@@ -83,6 +155,8 @@
      datatypes : datatyp IM.map,
      constructors : (datatype_kind * int * string list * con option * int) SM.map,
 
+     classes : class CM.map,
+
      renameE : con var' SM.map,
      relE : (string * con) list,
      namedE : (string * con) IM.map,
@@ -112,6 +186,8 @@
     datatypes = IM.empty,
     constructors = SM.empty,
 
+    classes = CM.empty,
+
     renameE = SM.empty,
     relE = [],
     namedE = IM.empty,
@@ -123,6 +199,12 @@
     str = IM.empty
 }
 
+fun liftClassKey ck =
+    case ck of
+        CkNamed _ => ck
+      | CkRel n => CkRel (n + 1)
+      | CkProj _ => ck
+
 fun pushCRel (env : env) x k =
     let
         val renameC = SM.map (fn Rel' (n, k) => Rel' (n+1, k)
@@ -135,6 +217,13 @@
          datatypes = #datatypes env,
          constructors = #constructors env,
 
+         classes = CM.map (fn class => {
+                              ground = KM.foldli (fn (ck, e, km) =>
+                                                     KM.insert (km, liftClassKey ck, e))
+                                                 KM.empty (#ground class)
+                          })
+                          (#classes env),
+
          renameE = #renameE env,
          relE = map (fn (x, c) => (x, lift c)) (#relE env),
          namedE = IM.map (fn (x, c) => (x, lift c)) (#namedE env),
@@ -159,6 +248,8 @@
      datatypes = #datatypes env,
      constructors = #constructors env,
 
+     classes = #classes env,
+
      renameE = #renameE env,
      relE = #relE env,
      namedE = #namedE env,
@@ -203,6 +294,8 @@
                                   SM.insert (cmap, x, (dk, n', xs, to, n)))
                               (#constructors env) xncs,
 
+         classes = #classes env,
+
          renameE = #renameE env,
          relE = #relE env,
          namedE = #namedE env,
@@ -229,10 +322,77 @@
 fun datatypeArgs (xs, _) = xs
 fun constructors (_, dt) = IM.foldri (fn (n, (x, to), ls) => (x, n, to) :: ls) [] dt
 
+fun pushClass (env : env) n =
+    {renameC = #renameC env,
+     relC = #relC env,
+     namedC = #namedC env,
+
+     datatypes = #datatypes env,
+     constructors = #constructors env,
+
+     classes = CM.insert (#classes env, ClNamed n, {ground = KM.empty}),
+
+     renameE = #renameE env,
+     relE = #relE env,
+     namedE = #namedE env,
+
+     renameSgn = #renameSgn env,
+     sgn = #sgn env,
+
+     renameStr = #renameStr env,
+     str = #str env}    
+
+fun class_name_in (c, _) =
+    case c of
+        CNamed n => SOME (ClNamed n)
+      | CModProj x => SOME (ClProj x)
+      | _ => NONE
+
+fun class_key_in (c, _) =
+    case c of
+        CRel n => SOME (CkRel n)
+      | CNamed n => SOME (CkNamed n)
+      | CModProj x => SOME (CkProj x)
+      | _ => NONE
+
+fun class_pair_in (c, _) =
+    case c of
+        CApp (f, x) =>
+        (case (class_name_in f, class_key_in x) of
+             (SOME f, SOME x) => SOME (f, x)
+           | _ => NONE)
+      | _ => NONE
+
+fun resolveClass (env : env) c =
+    case class_pair_in c of
+        SOME (f, x) =>
+        (case CM.find (#classes env, f) of
+             NONE => NONE
+           | SOME class =>
+             case KM.find (#ground class, x) of
+                 NONE => NONE
+               | SOME e => SOME e)
+      | _ => NONE
+
 fun pushERel (env : env) x t =
     let
         val renameE = SM.map (fn Rel' (n, t) => Rel' (n+1, t)
                                | x => x) (#renameE env)
+
+        val classes = CM.map (fn class => {
+                                 ground = KM.map liftExp (#ground class)
+                             }) (#classes env)
+        val classes = case class_pair_in t of
+                          NONE => classes
+                        | SOME (f, x) =>
+                          let
+                              val class = Option.getOpt (CM.find (classes, f), empty_class)
+                              val class = {
+                                  ground = KM.insert (#ground class, x, (ERel 0, #2 t))
+                              }
+                          in
+                              CM.insert (classes, f, class)
+                          end
     in
         {renameC = #renameC env,
          relC = #relC env,
@@ -241,6 +401,8 @@
          datatypes = #datatypes env,
          constructors = #constructors env,
 
+         classes = classes,
+
          renameE = SM.insert (renameE, x, Rel' (0, t)),
          relE = (x, t) :: #relE env,
          namedE = #namedE env,
@@ -257,22 +419,39 @@
     handle Subscript => raise UnboundRel n
 
 fun pushENamedAs (env : env) x n t =
-    {renameC = #renameC env,
-     relC = #relC env,
-     namedC = #namedC env,
+    let
+        val classes = #classes env
+        val classes = case class_pair_in t of
+                          NONE => classes
+                        | SOME (f, x) =>
+                          let
+                              val class = Option.getOpt (CM.find (classes, f), empty_class)
+                              val class = {
+                                  ground = KM.insert (#ground class, x, (ENamed n, #2 t))
+                              }
+                          in
+                              CM.insert (classes, f, class)
+                          end
+    in
+        {renameC = #renameC env,
+         relC = #relC env,
+         namedC = #namedC env,
 
-     datatypes = #datatypes env,
-     constructors = #constructors env,
+         datatypes = #datatypes env,
+         constructors = #constructors env,
 
-     renameE = SM.insert (#renameE env, x, Named' (n, t)),
-     relE = #relE env,
-     namedE = IM.insert (#namedE env, n, (x, t)),
+         classes = classes,
 
-     renameSgn = #renameSgn env,
-     sgn = #sgn env,
-     
-     renameStr = #renameStr env,
-     str = #str env}
+         renameE = SM.insert (#renameE env, x, Named' (n, t)),
+         relE = #relE env,
+         namedE = IM.insert (#namedE env, n, (x, t)),
+
+         renameSgn = #renameSgn env,
+         sgn = #sgn env,
+         
+         renameStr = #renameStr env,
+         str = #str env}
+    end
 
 fun pushENamed env x t =
     let
@@ -301,6 +480,8 @@
      datatypes = #datatypes env,
      constructors = #constructors env,
 
+     classes = #classes env,
+
      renameE = #renameE env,
      relE = #relE env,
      namedE = #namedE env,
@@ -326,32 +507,6 @@
 
 fun lookupSgn (env : env) x = SM.find (#renameSgn env, x)
 
-fun pushStrNamedAs (env : env) x n sgis =
-    {renameC = #renameC env,
-     relC = #relC env,
-     namedC = #namedC env,
-
-     datatypes = #datatypes env,
-     constructors = #constructors env,
-
-     renameE = #renameE env,
-     relE = #relE env,
-     namedE = #namedE env,
-
-     renameSgn = #renameSgn env,
-     sgn = #sgn env,
-
-     renameStr = SM.insert (#renameStr env, x, (n, sgis)),
-     str = IM.insert (#str env, n, (x, sgis))}
-
-fun pushStrNamed env x sgis =
-    let
-        val n = !namedCounter
-    in
-        namedCounter := n + 1;
-        (pushStrNamedAs env x n sgis, n)
-    end
-
 fun lookupStrNamed (env : env) n =
     case IM.find (#str env, n) of
         NONE => raise UnboundNamed n
@@ -359,57 +514,6 @@
 
 fun lookupStr (env : env) x = SM.find (#renameStr env, x)
 
-fun sgiBinds env (sgi, loc) =
-    case sgi of
-        SgiConAbs (x, n, k) => pushCNamedAs env x n k NONE
-      | SgiCon (x, n, k, c) => pushCNamedAs env x n k (SOME c)
-      | SgiDatatype (x, n, xs, xncs) =>
-        let
-            val env = pushCNamedAs env x n (KType, loc) NONE
-        in
-            foldl (fn ((x', n', to), env) =>
-                      let
-                          val t =
-                              case to of
-                                  NONE => (CNamed n, loc)
-                                | SOME t => (TFun (t, (CNamed n, loc)), loc)
-
-                          val k = (KType, loc)
-                          val t = foldr (fn (x, t) => (TCFun (Explicit, x, k, t), loc)) t xs
-                      in
-                          pushENamedAs env x' n' t
-                      end)
-            env xncs
-        end
-      | SgiDatatypeImp (x, n, m1, ms, x', xs, xncs) =>
-        let
-            val env = pushCNamedAs env x n (KType, loc) (SOME (CModProj (m1, ms, x'), loc))
-        in
-            foldl (fn ((x', n', to), env) =>
-                      let
-                          val t =
-                              case to of
-                                  NONE => (CNamed n, loc)
-                                | SOME t => (TFun (t, (CNamed n, loc)), loc)
-
-                          val k = (KType, loc)
-                          val t = foldr (fn (x, t) => (TCFun (Explicit, x, k, t), loc)) t xs
-                      in
-                          pushENamedAs env x' n' t
-                      end)
-            env xncs
-        end
-      | SgiVal (x, n, t) => pushENamedAs env x n t
-      | SgiStr (x, n, sgn) => pushStrNamedAs env x n sgn
-      | SgiSgn (x, n, sgn) => pushSgnNamedAs env x n sgn
-      | SgiConstraint _ => env
-
-      | SgiTable (tn, x, n, c) =>
-        let
-            val t = (CApp ((CModProj (tn, [], "table"), loc), c), loc)
-        in
-            pushENamedAs env x n t
-        end
 
 fun sgnSeek f sgis =
     let
@@ -439,6 +543,8 @@
                       | SgiStr (x, n, _) => seek (sgis, sgns, IM.insert (strs, n, x), cons)
                       | SgiConstraint _ => seek (sgis, sgns, strs, cons)
                       | SgiTable _ => seek (sgis, sgns, strs, cons)
+                      | SgiClassAbs (x, n) => seek (sgis, sgns, strs, IM.insert (cons, n, x))
+                      | SgiClass (x, n, _) => seek (sgis, sgns, strs, IM.insert (cons, n, x))
     in
         seek (sgis, IM.empty, IM.empty, IM.empty)
     end
@@ -500,17 +606,24 @@
              end)
       | _ => sgn
 
-fun sgnSubCon x =
-    ElabUtil.Con.map {kind = id,
-                      con = sgnS_con x}
-
 fun sgnSubSgn x =
     ElabUtil.Sgn.map {kind = id,
                       con = sgnS_con x,
                       sgn_item = id,
                       sgn = sgnS_sgn x}
 
-fun hnormSgn env (all as (sgn, loc)) =
+
+
+and projectSgn env {sgn, str, field} =
+    case #1 (hnormSgn env sgn) of
+        SgnConst sgis =>
+        (case sgnSeek (fn SgiSgn (x, _, sgn) => if x = field then SOME sgn else NONE | _ => NONE) sgis of
+             NONE => NONE
+           | SOME (sgn, subs) => SOME (sgnSubSgn (str, subs) sgn))
+      | SgnError => SOME (SgnError, ErrorMsg.dummySpan)
+      | _ => NONE
+
+and hnormSgn env (all as (sgn, loc)) =
     case sgn of
         SgnError => all
       | SgnVar n => hnormSgn env (#2 (lookupSgnNamed env n))
@@ -547,14 +660,117 @@
             end
           | _ => raise Fail "ElabEnv.hnormSgn: Can't reduce 'where' [2]"
 
-and projectSgn env {sgn, str, field} =
+fun enrichClasses env classes (m1, ms) sgn =
     case #1 (hnormSgn env sgn) of
         SgnConst sgis =>
-        (case sgnSeek (fn SgiSgn (x, _, sgn) => if x = field then SOME sgn else NONE | _ => NONE) sgis of
-             NONE => NONE
-           | SOME (sgn, subs) => SOME (sgnSubSgn (str, subs) sgn))
-      | SgnError => SOME (SgnError, ErrorMsg.dummySpan)
-      | _ => NONE
+        let
+            val (classes, _) =
+                foldl (fn (sgi, (classes, newClasses)) =>
+                          let
+                              fun found (x, n) =
+                                  (CM.insert (classes,
+                                              ClProj (m1, ms, x),
+                                              empty_class),
+                                   IS.add (newClasses, n))
+                          in
+                              case #1 sgi of
+                                  SgiClassAbs xn => found xn
+                                | SgiClass (x, n, _) => found (x, n)
+                                | _ => (classes, newClasses)
+                          end)
+                (classes, IS.empty) sgis
+        in
+            classes
+        end
+      | _ => classes
+
+fun pushStrNamedAs (env : env) x n sgn =
+    {renameC = #renameC env,
+     relC = #relC env,
+     namedC = #namedC env,
+
+     datatypes = #datatypes env,
+     constructors = #constructors env,
+
+     classes = enrichClasses env (#classes env) (n, []) sgn,
+
+     renameE = #renameE env,
+     relE = #relE env,
+     namedE = #namedE env,
+
+     renameSgn = #renameSgn env,
+     sgn = #sgn env,
+
+     renameStr = SM.insert (#renameStr env, x, (n, sgn)),
+     str = IM.insert (#str env, n, (x, sgn))}
+
+fun pushStrNamed env x sgn =
+    let
+        val n = !namedCounter
+    in
+        namedCounter := n + 1;
+        (pushStrNamedAs env x n sgn, n)
+    end
+
+fun sgiBinds env (sgi, loc) =
+    case sgi of
+        SgiConAbs (x, n, k) => pushCNamedAs env x n k NONE
+      | SgiCon (x, n, k, c) => pushCNamedAs env x n k (SOME c)
+      | SgiDatatype (x, n, xs, xncs) =>
+        let
+            val env = pushCNamedAs env x n (KType, loc) NONE
+        in
+            foldl (fn ((x', n', to), env) =>
+                      let
+                          val t =
+                              case to of
+                                  NONE => (CNamed n, loc)
+                                | SOME t => (TFun (t, (CNamed n, loc)), loc)
+
+                          val k = (KType, loc)
+                          val t = foldr (fn (x, t) => (TCFun (Explicit, x, k, t), loc)) t xs
+                      in
+                          pushENamedAs env x' n' t
+                      end)
+            env xncs
+        end
+      | SgiDatatypeImp (x, n, m1, ms, x', xs, xncs) =>
+        let
+            val env = pushCNamedAs env x n (KType, loc) (SOME (CModProj (m1, ms, x'), loc))
+        in
+            foldl (fn ((x', n', to), env) =>
+                      let
+                          val t =
+                              case to of
+                                  NONE => (CNamed n, loc)
+                                | SOME t => (TFun (t, (CNamed n, loc)), loc)
+
+                          val k = (KType, loc)
+                          val t = foldr (fn (x, t) => (TCFun (Explicit, x, k, t), loc)) t xs
+                      in
+                          pushENamedAs env x' n' t
+                      end)
+            env xncs
+        end
+      | SgiVal (x, n, t) => pushENamedAs env x n t
+      | SgiStr (x, n, sgn) => pushStrNamedAs env x n sgn
+      | SgiSgn (x, n, sgn) => pushSgnNamedAs env x n sgn
+      | SgiConstraint _ => env
+
+      | SgiTable (tn, x, n, c) =>
+        let
+            val t = (CApp ((CModProj (tn, [], "table"), loc), c), loc)
+        in
+            pushENamedAs env x n t
+        end
+
+      | SgiClassAbs (x, n) => pushCNamedAs env x n (KArrow ((KType, loc), (KType, loc)), loc) NONE
+      | SgiClass (x, n, c) => pushCNamedAs env x n (KArrow ((KType, loc), (KType, loc)), loc) (SOME c)
+        
+
+fun sgnSubCon x =
+    ElabUtil.Con.map {kind = id,
+                      con = sgnS_con x}
 
 fun projectStr env {sgn, str, field} =
     case #1 (hnormSgn env sgn) of
@@ -675,6 +891,8 @@
                   | SgiSgn (x, n, _) => seek (sgis, IM.insert (sgns, n, x), strs, cons, acc)
                   | SgiStr (x, n, _) => seek (sgis, sgns, IM.insert (strs, n, x), cons, acc)
                   | SgiTable _ => seek (sgis, sgns, strs, cons, acc)
+                  | SgiClassAbs (x, n) => seek (sgis, sgns, strs, IM.insert (cons, n, x), acc)
+                  | SgiClass (x, n, _) => seek (sgis, sgns, strs, IM.insert (cons, n, x), acc)
     in
         seek (sgis, IM.empty, IM.empty, IM.empty, [])
     end