diff src/elab_env.sml @ 674:fab5998b840e

Type class reductions, but no inclusions yet
author Adam Chlipala <adamc@hcoop.net>
date Thu, 26 Mar 2009 14:37:31 -0400
parents 588b9d16b00a
children 43430b7190f4
line wrap: on
line diff
--- a/src/elab_env.sml	Tue Mar 24 15:35:46 2009 -0400
+++ b/src/elab_env.sml	Thu Mar 26 14:37:31 2009 -0400
@@ -1,4 +1,4 @@
-(* Copyright (c) 2008, Adam Chlipala
+(* Copyright (c) 2008-2009, Adam Chlipala
  * All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
@@ -197,12 +197,16 @@
       | CkProj (m, ms, x) => "Proj(" ^ Int.toString m ^ "," ^ String.concatWith "," ms ^ "," ^ x ^ ")"
       | CkApp (ck1, ck2) => "App(" ^ ck2s ck1 ^ ", " ^ ck2s ck2 ^ ")"
 
+type class_key_n = class_key * int
+
+fun ckn2s (ck, n) = ck2s ck ^ "[" ^ Int.toString n ^ "]"
+
 fun cp2s (cn, ck) = "(" ^ cn2s cn ^ "," ^ ck2s ck ^ ")"
 
 structure KK = struct
-type ord_key = class_key
+type ord_key = class_key_n
 open Order
-fun compare x =
+fun compare' x =
     case x of
         (CkNamed n1, CkNamed n2) => Int.compare (n1, n2)
       | (CkNamed _, _) => LESS
@@ -220,24 +224,22 @@
       | (_, CkProj _) => GREATER
 
       | (CkApp (f1, x1), CkApp (f2, x2)) =>
-        join (compare (f1, f2),
-              fn () => compare (x1, x2))
+        join (compare' (f1, f2),
+              fn () => compare' (x1, x2))
+fun compare ((k1, n1), (k2, n2)) =
+    join (Int.compare (n1, n2),
+       fn () => compare' (k1, k2))
 end
 
 structure KM = BinaryMapFn(KK)
 
-type class = {
-     ground : exp KM.map
-}
-
-val empty_class = {
-    ground = KM.empty
-}
+type class = ((class_name * class_key) list * exp) KM.map
+val empty_class = KM.empty
 
 fun printClasses cs = (print "Classes:\n";
-                       CM.appi (fn (cn, {ground = km}) =>
+                       CM.appi (fn (cn, km) =>
                                    (print (cn2s cn ^ ":");
-                                    KM.appi (fn (ck, _) => print (" " ^ ck2s ck)) km;
+                                    KM.appi (fn (ck, _) => print (" " ^ ckn2s ck)) km;
                                     print "\n")) cs)
 
 type env = {
@@ -298,12 +300,14 @@
     str = IM.empty
 }
 
-fun liftClassKey ck =
+fun liftClassKey' ck =
     case ck of
         CkNamed _ => ck
       | CkRel n => CkRel (n + 1)
       | CkProj _ => ck
-      | CkApp (ck1, ck2) => CkApp (liftClassKey ck1, liftClassKey ck2)
+      | CkApp (ck1, ck2) => CkApp (liftClassKey' ck1, liftClassKey' ck2)
+
+fun liftClassKey (ck, n) = (liftClassKey' ck, n)
 
 fun pushKRel (env : env) x =
     let
@@ -356,11 +360,10 @@
          datatypes = #datatypes env,
          constructors = #constructors env,
 
-         classes = CM.map (fn class => {
-                              ground = KM.foldli (fn (ck, e, km) =>
-                                                     KM.insert (km, liftClassKey ck, e))
-                                                 KM.empty (#ground class)
-                          })
+         classes = CM.map (fn class =>
+                              KM.foldli (fn (ck, e, km) =>
+                                            KM.insert (km, liftClassKey ck, e))
+                                        KM.empty class)
                           (#classes env),
 
          renameE = SM.map (fn Rel' (n, c) => Rel' (n, lift c)
@@ -479,7 +482,7 @@
      datatypes = #datatypes env,
      constructors = #constructors env,
 
-     classes = CM.insert (#classes env, ClNamed n, {ground = KM.empty}),
+     classes = CM.insert (#classes env, ClNamed n, KM.empty),
 
      renameE = #renameE env,
      relE = #relE env,
@@ -518,6 +521,18 @@
            | _ => NONE)
       | _ => NONE
 
+fun class_key_out loc =
+    let
+        fun cko k =
+            case k of
+                CkRel n => (CRel n, loc)
+              | CkNamed n => (CNamed n, loc)
+              | CkProj x => (CModProj x, loc)
+              | CkApp (k1, k2) => (CApp (cko k1, cko k2), loc)
+    in
+        cko
+    end
+
 fun class_pair_in (c, _) =
     case c of
         CApp (f, x) =>
@@ -527,25 +542,80 @@
       | CUnif (_, _, _, ref (SOME c)) => class_pair_in c
       | _ => NONE
 
+fun sub_class_key (n, c) =
+    let
+        fun csk k =
+            case k of
+                CkRel n' => if n' = n then
+                                c
+                            else
+                                k
+              | CkNamed _ => k
+              | CkProj _ => k
+              | CkApp (k1, k2) => CkApp (csk k1, csk k2)
+    in
+        csk
+    end
+
 fun resolveClass (env : env) c =
-    case class_pair_in c of
-        SOME (f, x) =>
-        (case CM.find (#classes env, f) of
-             NONE => NONE
-           | SOME class =>
-             case KM.find (#ground class, x) of
-                 NONE => NONE
-               | SOME e => SOME e)
-      | _ => NONE
+    let
+        fun doPair (f, x) =
+            case CM.find (#classes env, f) of
+                NONE => NONE
+              | SOME class =>
+                let
+                    val loc = #2 c
+                              
+                    fun tryRules (k, args) =
+                        let
+                            val len = length args
+                        in
+                            case KM.find (class, (k, length args)) of
+                                SOME (cs, e) =>
+                                let
+                                    val es = map (fn (cn, ck) =>
+                                                     let
+                                                         val ck = ListUtil.foldli (fn (i, arg, ck) =>
+                                                                                      sub_class_key (len - i - 1,
+                                                                                                     arg)
+                                                                                                    ck)
+                                                                                  ck args
+                                                     in
+                                                         doPair (cn, ck)
+                                                     end) cs
+                                in
+                                    if List.exists (not o Option.isSome) es then
+                                        NONE
+                                    else
+                                        let
+                                            val e = foldl (fn (arg, e) => (ECApp (e, class_key_out loc arg), loc))
+                                                          e args
+                                            val e = foldr (fn (pf, e) => (EApp (e, pf), loc))
+                                                          e (List.mapPartial (fn x => x) es)
+                                        in
+                                            SOME e
+                                        end
+                                end
+                              | NONE =>
+                                case k of
+                                    CkApp (k1, k2) => tryRules (k1, k2 :: args)
+                                  | _ => NONE
+                        end
+                in
+                    tryRules (x, [])
+                end
+    in
+        case class_pair_in c of
+            SOME p => doPair p
+          | _ => NONE
+    end
 
 fun pushERel (env : env) x t =
     let
         val renameE = SM.map (fn Rel' (n, t) => Rel' (n+1, t)
                                | x => x) (#renameE env)
 
-        val classes = CM.map (fn class => {
-                                 ground = KM.map liftExp (#ground class)
-                             }) (#classes env)
+        val classes = CM.map (KM.map (fn (ps, e) => (ps, liftExp e))) (#classes env)
         val classes = case class_pair_in t of
                           NONE => classes
                         | SOME (f, x) =>
@@ -553,9 +623,7 @@
                               NONE => classes
                             | SOME class =>
                               let
-                                  val class = {
-                                      ground = KM.insert (#ground class, x, (ERel 0, #2 t))
-                                  }
+                                  val class = KM.insert (class, (x, 0), ([], (ERel 0, #2 t)))
                               in
                                   CM.insert (classes, f, class)
                               end
@@ -587,19 +655,55 @@
     (List.nth (#relE env, n))
     handle Subscript => raise UnboundRel n
 
+fun rule_in c =
+    let
+        fun quantifiers (c, nvars) =
+            case #1 c of
+                TCFun (_, _, _, c) => quantifiers (c, nvars + 1)
+              | _ =>
+                let
+                    fun clauses (c, hyps) =
+                        case #1 c of
+                            TFun (hyp, c) =>
+                            (case class_pair_in hyp of
+                                 NONE => NONE
+                               | SOME p => clauses (c, p :: hyps))
+                          | _ =>
+                            case class_pair_in c of
+                                NONE => NONE
+                              | SOME (cn, ck) =>
+                                let
+                                    fun dearg (ck, i) =
+                                        if i >= nvars then
+                                            SOME (nvars, hyps, (cn, ck))
+                                        else case ck of
+                                                 CkApp (ck, CkRel i') =>
+                                                 if i' = i then
+                                                     dearg (ck, i + 1)
+                                                 else
+                                                     NONE
+                                               | _ => NONE
+                                in
+                                    dearg (ck, 0)
+                                end
+                in
+                    clauses (c, [])
+                end
+    in
+        quantifiers (c, 0)
+    end
+
 fun pushENamedAs (env : env) x n t =
     let
         val classes = #classes env
-        val classes = case class_pair_in t of
+        val classes = case rule_in t of
                           NONE => classes
-                        | SOME (f, x) =>
+                        | SOME (nvars, hyps, (f, x)) =>
                           case CM.find (classes, f) of
                               NONE => classes
                             | SOME class =>
                               let
-                                  val class = {
-                                      ground = KM.insert (#ground class, x, (ENamed n, #2 t))
-                                  }
+                                  val class = KM.insert (class, (x, nvars), (hyps, (ENamed n, #2 t)))
                               in
                                   CM.insert (classes, f, class)
                               end
@@ -784,6 +888,31 @@
                                (sgnS_con' arg (#1 c2), #2 c2))
       | _ => c
 
+fun sgnS_class_name (arg as (m1, ms', (sgns, strs, cons))) nm =
+    case nm of
+        ClProj (m1, ms, x) =>
+        (case IM.find (strs, m1) of
+             NONE => nm
+           | SOME m1x => ClProj (m1, ms' @ m1x :: ms, x))
+      | ClNamed n =>
+        (case IM.find (cons, n) of
+             NONE => nm
+           | SOME nx => ClProj (m1, ms', nx))
+
+fun sgnS_class_key (arg as (m1, ms', (sgns, strs, cons))) k =
+    case k of
+        CkProj (m1, ms, x) =>
+        (case IM.find (strs, m1) of
+             NONE => k
+           | SOME m1x => CkProj (m1, ms' @ m1x :: ms, x))
+      | CkNamed n =>
+        (case IM.find (cons, n) of
+             NONE => k
+           | SOME nx => CkProj (m1, ms', nx))
+      | CkApp (k1, k2) => CkApp (sgnS_class_key arg k1,
+                                 sgnS_class_key arg k2)
+      | _ => k
+
 fun sgnS_sgn (str, (sgns, strs, cons)) sgn =
     case sgn of
         SgnProj (m1, ms, x) =>
@@ -891,38 +1020,45 @@
 
                                 | SgiClassAbs (x, n, _) => found (x, n)
                                 | SgiClass (x, n, _, _) => found (x, n)
-                                | SgiVal (x, n, (CApp (f, a), _)) =>
-                                  let
-                                      fun unravel c =
-                                          case #1 c of
-                                              CUnif (_, _, _, ref (SOME c)) => unravel c
-                                            | CNamed n =>
-                                              ((case lookupCNamed env n of
-                                                    (_, _, SOME c) => unravel c
-                                                  | _ => c)
-                                               handle UnboundNamed _ => c)
-                                            | _ => c
+                                | SgiVal (x, n, c) =>
+                                  (case rule_in c of
+                                       NONE => default ()
+                                     | SOME (nvars, hyps, (cn, a)) =>
+                                       let
+                                           val globalize = sgnS_class_key (m1, ms, fmap)
+                                           val ck = globalize a
+                                           val hyps = map (fn (n, k) => (sgnS_class_name (m1, ms, fmap) n,
+                                                                         globalize k)) hyps
 
-                                      val nc =
-                                          case f of
-                                              (CNamed f, _) => IM.find (newClasses, f)
-                                            | _ => NONE
-                                  in
-                                      case nc of
-                                          NONE =>
-                                          (case (class_name_in (unravel f),
-                                                 class_key_in (sgnS_con' (m1, ms, fmap) (#1 a), #2 a)) of
-                                               (SOME cn, SOME ck) =>
+                                           fun unravel c =
+                                               case c of
+                                                   ClNamed n =>
+                                                   ((case lookupCNamed env n of
+                                                         (_, _, SOME c') =>
+                                                         (case class_name_in c' of
+                                                              NONE => c
+                                                            | SOME k => unravel k)
+                                                       | _ => c)
+                                                    handle UnboundNamed _ => c)
+                                                 | _ => c
+
+                                           val nc =
+                                               case cn of
+                                                   ClNamed f => IM.find (newClasses, f)
+                                                 | _ => NONE
+                                       in
+                                           case nc of
+                                               NONE =>
                                                let
                                                    val classes =
                                                        case CM.find (classes, cn) of
                                                            NONE => classes
                                                          | SOME class =>
                                                            let
-                                                               val class = {
-                                                                   ground = KM.insert (#ground class, ck,
-                                                                                       (EModProj (m1, ms, x), #2 sgn))
-                                                               }
+                                                               val class = KM.insert (class, (ck, nvars),
+                                                                                      (hyps,
+                                                                                       (EModProj (m1, ms, x),
+                                                                                        #2 sgn)))
                                                            in
                                                                CM.insert (classes, cn, class)
                                                            end
@@ -932,34 +1068,28 @@
                                                     fmap,
                                                     env)
                                                end
-                                             | _ => default ())
-                                        | SOME fx =>
-                                          case class_key_in (sgnS_con' (m1, ms, fmap) (#1 a), #2 a) of
-                                              NONE => default ()
-                                            | SOME ck =>
-                                              let
-                                                  val cn = ClProj (m1, ms, fx)
+                                             | SOME fx =>
+                                               let
+                                                   val cn = ClProj (m1, ms, fx)
 
-                                                  val classes =
-                                                      case CM.find (classes, cn) of
-                                                          NONE => classes
-                                                        | SOME class =>
-                                                          let
-                                                              val class = {
-                                                                  ground = KM.insert (#ground class, ck,
-                                                                                      (EModProj (m1, ms, x), #2 sgn))
-                                                              }
-                                                          in
-                                                              CM.insert (classes, cn, class)
-                                                          end
-                                              in
-                                                  (classes,
-                                                   newClasses,
-                                                   fmap,
-                                                   env)
-                                              end
-                                  end
-                                | SgiVal _ => default ()
+                                                   val classes =
+                                                       case CM.find (classes, cn) of
+                                                           NONE => classes
+                                                         | SOME class =>
+                                                           let
+                                                               val class = KM.insert (class, (ck, nvars),
+                                                                                      (hyps,
+                                                                                       (EModProj (m1, ms, x), #2 sgn)))
+                                                           in
+                                                               CM.insert (classes, cn, class)
+                                                           end
+                                               in
+                                                   (classes,
+                                                    newClasses,
+                                                    fmap,
+                                                    env)
+                                               end
+                                       end)
                                 | _ => default ()
                           end)
                       (classes, IM.empty, (IM.empty, IM.empty, IM.empty), env) sgis