diff src/elab_env.sml @ 675:43430b7190f4

Type class inclusions
author Adam Chlipala <adamc@hcoop.net>
date Thu, 26 Mar 2009 15:13:36 -0400
parents fab5998b840e
children 81573f62d6c3
line wrap: on
line diff
--- a/src/elab_env.sml	Thu Mar 26 14:37:31 2009 -0400
+++ b/src/elab_env.sml	Thu Mar 26 15:13:36 2009 -0400
@@ -233,11 +233,13 @@
 
 structure KM = BinaryMapFn(KK)
 
-type class = ((class_name * class_key) list * exp) KM.map
-val empty_class = KM.empty
+type class = {ground : ((class_name * class_key) list * exp) KM.map,
+              inclusions : exp CM.map}
+val empty_class = {ground = KM.empty,
+                   inclusions = CM.empty}
 
 fun printClasses cs = (print "Classes:\n";
-                       CM.appi (fn (cn, km) =>
+                       CM.appi (fn (cn, {ground = km, ...} : class) =>
                                    (print (cn2s cn ^ ":");
                                     KM.appi (fn (ck, _) => print (" " ^ ckn2s ck)) km;
                                     print "\n")) cs)
@@ -361,9 +363,10 @@
          constructors = #constructors env,
 
          classes = CM.map (fn class =>
-                              KM.foldli (fn (ck, e, km) =>
-                                            KM.insert (km, liftClassKey ck, e))
-                                        KM.empty class)
+                              {ground = KM.foldli (fn (ck, e, km) =>
+                                                      KM.insert (km, liftClassKey ck, e))
+                                                  KM.empty (#ground class),
+                               inclusions = #inclusions class})
                           (#classes env),
 
          renameE = SM.map (fn Rel' (n, c) => Rel' (n, lift c)
@@ -482,7 +485,7 @@
      datatypes = #datatypes env,
      constructors = #constructors env,
 
-     classes = CM.insert (#classes env, ClNamed n, KM.empty),
+     classes = CM.insert (#classes env, ClNamed n, empty_class),
 
      renameE = #renameE env,
      relE = #relE env,
@@ -565,12 +568,36 @@
               | SOME class =>
                 let
                     val loc = #2 c
-                              
+
+                    fun tryIncs () =
+                        let
+                            fun tryIncs fs =
+                                case fs of
+                                    [] => NONE
+                                  | (f', e') :: fs =>
+                                    case doPair (f', x) of
+                                        NONE => tryIncs fs
+                                      | SOME e =>
+                                        let
+                                            val e' = (ECApp (e', class_key_out loc x), loc)
+                                            val e' = (EApp (e', e), loc)
+                                        in
+                                            SOME e'
+                                        end
+                        in
+                            tryIncs (CM.listItemsi (#inclusions class))
+                        end
+
                     fun tryRules (k, args) =
                         let
                             val len = length args
+
+                            fun tryNext () =
+                                case k of
+                                    CkApp (k1, k2) => tryRules (k1, k2 :: args)
+                                  | _ => tryIncs ()
                         in
-                            case KM.find (class, (k, length args)) of
+                            case KM.find (#ground class, (k, length args)) of
                                 SOME (cs, e) =>
                                 let
                                     val es = map (fn (cn, ck) =>
@@ -585,7 +612,7 @@
                                                      end) cs
                                 in
                                     if List.exists (not o Option.isSome) es then
-                                        NONE
+                                        tryNext ()
                                     else
                                         let
                                             val e = foldl (fn (arg, e) => (ECApp (e, class_key_out loc arg), loc))
@@ -596,10 +623,7 @@
                                             SOME e
                                         end
                                 end
-                              | NONE =>
-                                case k of
-                                    CkApp (k1, k2) => tryRules (k1, k2 :: args)
-                                  | _ => NONE
+                              | NONE => tryNext ()
                         end
                 in
                     tryRules (x, [])
@@ -615,7 +639,9 @@
         val renameE = SM.map (fn Rel' (n, t) => Rel' (n+1, t)
                                | x => x) (#renameE env)
 
-        val classes = CM.map (KM.map (fn (ps, e) => (ps, liftExp e))) (#classes env)
+        val classes = CM.map (fn class =>
+                                 {ground = KM.map (fn (ps, e) => (ps, liftExp e)) (#ground class),
+                                  inclusions = #inclusions class}) (#classes env)
         val classes = case class_pair_in t of
                           NONE => classes
                         | SOME (f, x) =>
@@ -623,7 +649,8 @@
                               NONE => classes
                             | SOME class =>
                               let
-                                  val class = KM.insert (class, (x, 0), ([], (ERel 0, #2 t)))
+                                  val class = {ground = KM.insert (#ground class, (x, 0), ([], (ERel 0, #2 t))),
+                                               inclusions = #inclusions class}
                               in
                                   CM.insert (classes, f, class)
                               end
@@ -655,6 +682,10 @@
     (List.nth (#relE env, n))
     handle Subscript => raise UnboundRel n
 
+datatype rule =
+         Normal of int * (class_name * class_key) list * class_key
+       | Inclusion of class_name
+
 fun rule_in c =
     let
         fun quantifiers (c, nvars) =
@@ -675,7 +706,7 @@
                                 let
                                     fun dearg (ck, i) =
                                         if i >= nvars then
-                                            SOME (nvars, hyps, (cn, ck))
+                                            SOME (cn, Normal (nvars, hyps, ck))
                                         else case ck of
                                                  CkApp (ck, CkRel i') =>
                                                  if i' = i then
@@ -690,7 +721,13 @@
                     clauses (c, [])
                 end
     in
-        quantifiers (c, 0)
+        case #1 c of
+            TCFun (_, _, _, (TFun ((CApp (f1, (CRel 0, _)), _),
+                                   (CApp (f2, (CRel 0, _)), _)), _)) =>
+            (case (class_name_in f1, class_name_in f2) of
+                 (SOME f1, SOME f2) => SOME (f2, Inclusion f1)
+               | _ => NONE)
+          | _ => quantifiers (c, 0)
     end
 
 fun pushENamedAs (env : env) x n t =
@@ -698,12 +735,21 @@
         val classes = #classes env
         val classes = case rule_in t of
                           NONE => classes
-                        | SOME (nvars, hyps, (f, x)) =>
+                        | SOME (f, rule) =>
                           case CM.find (classes, f) of
                               NONE => classes
                             | SOME class =>
                               let
-                                  val class = KM.insert (class, (x, nvars), (hyps, (ENamed n, #2 t)))
+                                  val e = (ENamed n, #2 t)
+
+                                  val class =
+                                      case rule of
+                                          Normal (nvars, hyps, x) =>
+                                          {ground = KM.insert (#ground class, (x, nvars), (hyps, e)),
+                                           inclusions = #inclusions class}
+                                        | Inclusion f' =>
+                                          {ground = #ground class,
+                                           inclusions = CM.insert (#inclusions class, f', e)}
                               in
                                   CM.insert (classes, f, class)
                               end
@@ -1023,12 +1069,10 @@
                                 | SgiVal (x, n, c) =>
                                   (case rule_in c of
                                        NONE => default ()
-                                     | SOME (nvars, hyps, (cn, a)) =>
+                                     | SOME (cn, rule) =>
                                        let
+                                           val globalizeN = sgnS_class_name (m1, ms, fmap)
                                            val globalize = sgnS_class_key (m1, ms, fmap)
-                                           val ck = globalize a
-                                           val hyps = map (fn (n, k) => (sgnS_class_name (m1, ms, fmap) n,
-                                                                         globalize k)) hyps
 
                                            fun unravel c =
                                                case c of
@@ -1055,10 +1099,22 @@
                                                            NONE => classes
                                                          | SOME class =>
                                                            let
-                                                               val class = KM.insert (class, (ck, nvars),
-                                                                                      (hyps,
-                                                                                       (EModProj (m1, ms, x),
-                                                                                        #2 sgn)))
+                                                               val e = (EModProj (m1, ms, x),
+                                                                                     #2 sgn)
+
+                                                               val class =
+                                                                   case rule of
+                                                                       Normal (nvars, hyps, a) =>
+                                                                       {ground = 
+                                                                        KM.insert (#ground class, (globalize a, nvars),
+                                                                                   (map (fn (n, k) =>
+                                                                                            (globalizeN n,
+                                                                                             globalize k)) hyps, e)),
+                                                                        inclusions = #inclusions class}
+                                                                     | Inclusion f' =>
+                                                                       {ground = #ground class,
+                                                                        inclusions = CM.insert (#inclusions class,
+                                                                                                globalizeN f', e)}
                                                            in
                                                                CM.insert (classes, cn, class)
                                                            end
@@ -1077,9 +1133,21 @@
                                                            NONE => classes
                                                          | SOME class =>
                                                            let
-                                                               val class = KM.insert (class, (ck, nvars),
-                                                                                      (hyps,
-                                                                                       (EModProj (m1, ms, x), #2 sgn)))
+                                                               val e = (EModProj (m1, ms, x), #2 sgn)
+
+                                                               val class = 
+                                                                   case rule of
+                                                                       Normal (nvars, hyps, a) =>
+                                                                       {ground =
+                                                                        KM.insert (#ground class, (globalize a, nvars),
+                                                                                   (map (fn (n, k) =>
+                                                                                            (globalizeN n,
+                                                                                             globalize k)) hyps, e)),
+                                                                        inclusions = #inclusions class}
+                                                                     | Inclusion f' =>
+                                                                       {ground = #ground class,
+                                                                        inclusions = CM.insert (#inclusions class,
+                                                                                                globalizeN f', e)}
                                                            in
                                                                CM.insert (classes, cn, class)
                                                            end