changeset 1795:d28adceef22a

Allow type class instances with hypotheses via local ('let') definitions
author Adam Chlipala <adam@chlipala.net>
date Wed, 25 Jul 2012 14:04:59 -0400
parents 4671afac15af
children 0de0daab5fbb
files src/elab_env.sml src/elab_util.sig src/elab_util.sml src/source_print.sig
diffstat 4 files changed, 119 insertions(+), 72 deletions(-) [+]
line wrap: on
line diff
--- a/src/elab_env.sml	Wed Jul 25 08:20:15 2012 -0400
+++ b/src/elab_env.sml	Wed Jul 25 14:04:59 2012 -0400
@@ -163,6 +163,22 @@
                         | ((xn, rep), U.Exp.RelC _) => (xn, liftConInExp 0 rep)
                         | (ctx, _) => ctx}
 
+val openCon =
+    U.Con.existsB {kind = fn ((nk, _), k) =>
+                             case k of
+                                 KRel n => n >= nk
+                               | _ => false,
+                   con = fn ((_, nc), c) =>
+                            case c of
+                                CRel n => n >= nc
+                              | _ => false,
+                   bind = fn (all as (nk, nc), b) =>
+                             case b of
+                                 U.Con.RelK _ => (nk+1, nc)
+                               | U.Con.RelC _ => (nk, nc+1)
+                               | _ => all}
+    (0, 0)
+
 (* Back to environments *)
 
 datatype 'a var' =
@@ -208,10 +224,12 @@
 structure CS = BinarySetFn(CK)
 structure CM = BinaryMapFn(CK)
 
-type class = {ground : (con * exp) list,
-              rules : (int * con list * con * exp) list}
-val empty_class = {ground = [],
-                   rules = []}
+type rules = (int * con list * con * exp) list
+
+type class = {closedRules : rules,
+              openRules : rules}
+val empty_class = {closedRules = [],
+                   openRules = []}
 
 type env = {
      renameK : int SM.map,
@@ -286,11 +304,13 @@
          datatypes = #datatypes env,
          constructors = #constructors env,
 
-         classes = CM.map (fn cl => {ground = map (fn (c, e) =>
-                                                      (liftKindInCon 0 c,
-                                                       e))
-                                                  (#ground cl),
-                                     rules = #rules cl})
+         classes = CM.map (fn cl => {closedRules = #closedRules cl,
+                                     openRules = map (fn (nvs, cs, c, e) =>
+                                                         (nvs,
+                                                          map (liftKindInCon 0) cs,
+                                                          liftKindInCon 0 c,
+                                                          liftKindInExp 0 e))
+                                                     (#openRules cl)})
                           (#classes env),
 
          renameE = SM.map (fn Rel' (n, c) => Rel' (n, liftKindInCon 0 c)
@@ -328,11 +348,13 @@
          constructors = #constructors env,
 
          classes = CM.map (fn class =>
-                              {ground = map (fn (c, e) =>
-                                                (liftConInCon 0 c,
-                                                 e))
-                                            (#ground class),
-                               rules = #rules class})
+                              {closedRules = #closedRules class,
+                               openRules = map (fn (nvs, cs, c, e) =>
+                                                (nvs,
+                                                 map (liftConInCon 0) cs,
+                                                 liftConInCon 0 c,
+                                                 liftConInExp 0 e))
+                                            (#openRules class)})
                           (#classes env),
 
          renameE = SM.map (fn Rel' (n, c) => Rel' (n, lift c)
@@ -441,10 +463,9 @@
 fun constructors (_, dt) = IM.foldri (fn (n, (x, to), ls) => (x, n, to) :: ls) [] dt
 
 fun listClasses (env : env) =
-    map (fn (cn, {ground, rules}) =>
+    map (fn (cn, {closedRules, openRules}) =>
             (class_name_out cn,
-             ground
-             @ map (fn (nvs, cs, c, e) =>
+             map (fn (nvs, cs, c, e) =>
                        let
                            val loc = #2 c
                            val c = foldr (fn (c', c) => (TFun (c', c), loc)) c cs
@@ -455,7 +476,7 @@
                                                    c (List.tabulate (nvs, fn _ => ()))
                        in
                            (c, e)
-                       end) rules)) (CM.listItemsi (#classes env))
+                       end) (closedRules @ openRules))) (CM.listItemsi (#classes env))
 
 fun pushClass (env : env) n =
     {renameK = #renameK env,
@@ -653,6 +674,8 @@
                                      CRel n =>
                                      if n < d then
                                          c
+                                     else if n - d >= length rs then
+                                         CRel (n - d)
                                      else
                                          #1 (List.nth (rs, n - d))
                                    | _ => c,
@@ -729,7 +752,7 @@
                                 case rules of
                                     [] => notFound ()
                                   | (nRs, cs, c', e) :: rules' =>
-                                    case tryUnify hnorm nRs (c, c') of
+                                     case tryUnify hnorm nRs (c, c') of
                                         NONE => tryRules rules'
                                       | SOME rs =>
                                         let
@@ -749,18 +772,8 @@
                                                     SOME e
                                                 end
                                         end
-
-                            fun rules () = tryRules (#rules class)
-  
-                            fun tryGrounds ces =
-                                case ces of
-                                    [] => rules ()
-                                  | (c', e) :: ces' =>
-                                    case tryUnify hnorm 0 (c, c') of
-                                        NONE => tryGrounds ces'
-                                      | SOME _ => SOME e
                         in
-                            tryGrounds (#ground class)
+                            tryRules (#openRules class @ #closedRules class)
                         end
             in
                 if startsWithUnif c then
@@ -800,23 +813,55 @@
         resolve true
     end
 
+fun rule_in c =
+    let
+        fun quantifiers (c, nvars) =
+            case #1 c of
+                CUnif (_, _, _, _, ref (Known c)) => quantifiers (c, nvars)
+              | TCFun (_, _, _, c) => quantifiers (c, nvars + 1)
+              | _ =>
+                let
+                    fun clauses (c, hyps) =
+                        case #1 c of
+                            TFun (hyp, c) =>
+                            (case class_head_in hyp of
+                                 SOME _ => clauses (c, hyp :: hyps)
+                               | NONE => NONE)
+                          | _ =>
+                            case class_head_in c of
+                                NONE => NONE
+                              | SOME f => SOME (f, nvars, rev hyps, c)
+                in
+                    clauses (c, [])
+                end
+    in
+        quantifiers (c, 0)
+    end
+
 fun pushERel (env : env) x t =
     let
         val renameE = SM.map (fn Rel' (n, t) => Rel' (n+1, t)
                                | x => x) (#renameE env)
 
         val classes = CM.map (fn class =>
-                                 {ground = map (fn (c, e) => (c, liftExp e)) (#ground class),
-                                  rules = #rules class}) (#classes env)
-        val classes = case class_head_in t of
+                                 {openRules = map (fn (nvs, cs, c, e) => (nvs, cs, c, liftExp e)) (#openRules class),
+                                  closedRules = #closedRules class}) (#classes env)
+        val classes = case rule_in t of
                           NONE => classes
-                        | SOME f =>
+                        | SOME (f, nvs, cs, c) =>
                           case CM.find (classes, f) of
                               NONE => classes
                             | SOME class =>
                               let
-                                  val class = {ground = (t, (ERel 0, #2 t)) :: #ground class,
-                                               rules = #rules class}
+                                  val rule = (nvs, cs, c, (ERel 0, #2 t))
+
+                                  val class =
+                                      if openCon t then
+                                          {openRules = rule :: #openRules class,
+                                           closedRules = #closedRules class}
+                                      else
+                                          {closedRules = rule :: #closedRules class,
+                                           openRules = #openRules class}
                               in
                                   CM.insert (classes, f, class)
                               end
@@ -848,30 +893,6 @@
     (List.nth (#relE env, n))
     handle Subscript => raise UnboundRel n
 
-fun rule_in c =
-    let
-        fun quantifiers (c, nvars) =
-            case #1 c of
-                TCFun (_, _, _, c) => quantifiers (c, nvars + 1)
-              | _ =>
-                let
-                    fun clauses (c, hyps) =
-                        case #1 c of
-                            TFun (hyp, c) =>
-                            (case class_head_in hyp of
-                                 SOME _ => clauses (c, hyp :: hyps)
-                               | NONE => NONE)
-                          | _ =>
-                            case class_head_in c of
-                                NONE => NONE
-                              | SOME f => SOME (f, nvars, rev hyps, c)
-                in
-                    clauses (c, [])
-                end
-    in
-        quantifiers (c, 0)
-    end
-
 fun pushENamedAs (env : env) x n t =
     let
         val classes = #classes env
@@ -885,8 +906,8 @@
                                   val e = (ENamed n, #2 t)
 
                                   val class =
-                                      {ground = #ground class,
-                                       rules = (nvs, cs, c, e) :: #rules class}
+                                      {openRules = #openRules class,
+                                       closedRules = (nvs, cs, c, e) :: #closedRules class}
                               in
                                   CM.insert (classes, f, class)
                               end
@@ -1210,11 +1231,11 @@
                                                                val e = (EModProj (m1, ms, x), #2 sgn)
 
                                                                val class =
-                                                                   {ground = #ground class,
-                                                                    rules = (nvs,
-                                                                             map globalize cs,
-                                                                             globalize c,
-                                                                             e) :: #rules class}
+                                                                   {openRules = #openRules class,
+                                                                    closedRules = (nvs,
+                                                                                   map globalize cs,
+                                                                                   globalize c,
+                                                                                   e) :: #closedRules class}
                                                            in
                                                                CM.insert (classes, cn, class)
                                                            end
@@ -1236,11 +1257,11 @@
                                                                val e = (EModProj (m1, ms, x), #2 sgn)
 
                                                                val class = 
-                                                                   {ground = #ground class,
-                                                                    rules = (nvs,
-                                                                             map globalize cs,
-                                                                             globalize c,
-                                                                             e) :: #rules class}
+                                                                   {openRules = #openRules class,
+                                                                    closedRules = (nvs,
+                                                                                   map globalize cs,
+                                                                                   globalize c,
+                                                                                   e) :: #closedRules class}
                                                            in
                                                                CM.insert (classes, cn, class)
                                                            end
--- a/src/elab_util.sig	Wed Jul 25 08:20:15 2012 -0400
+++ b/src/elab_util.sig	Wed Jul 25 14:04:59 2012 -0400
@@ -112,6 +112,11 @@
     val exists : {kind : Elab.kind' -> bool,
                   con : Elab.con' -> bool,
                   exp : Elab.exp' -> bool} -> Elab.exp -> bool
+    val existsB : {kind : 'context * Elab.kind' -> bool,
+                   con : 'context * Elab.con' -> bool,
+                   exp : 'context * Elab.exp' -> bool,
+                   bind : 'context * binder -> 'context}
+                  -> 'context -> Elab.exp -> bool
 
     val foldB : {kind : 'context * Elab.kind' * 'state -> 'state,
                  con : 'context * Elab.con' * 'state -> 'state,
--- a/src/elab_util.sml	Wed Jul 25 08:20:15 2012 -0400
+++ b/src/elab_util.sml	Wed Jul 25 14:04:59 2012 -0400
@@ -568,6 +568,26 @@
               exp = fn () => fe,
               bind = fn ((), _) => ()} ()
 
+fun existsB {kind, con, exp, bind} ctx e =
+    case mapfoldB {kind = fn ctx => fn k => fn () =>
+                                               if kind (ctx, k) then
+                                                   S.Return ()
+                                               else
+                                                   S.Continue (k, ()),
+                   con = fn ctx => fn c => fn () =>
+                                              if con (ctx, c) then
+                                                  S.Return ()
+                                              else
+                                                  S.Continue (c, ()),
+                   exp = fn ctx => fn e => fn () =>
+                                              if exp (ctx, e) then
+                                                  S.Return ()
+                                              else
+                                                  S.Continue (e, ()),
+                   bind = bind} ctx e () of
+        S.Return _ => true
+      | S.Continue _ => false
+
 fun exists {kind, con, exp} k =
     case mapfold {kind = fn k => fn () =>
                                     if kind k then
--- a/src/source_print.sig	Wed Jul 25 08:20:15 2012 -0400
+++ b/src/source_print.sig	Wed Jul 25 14:04:59 2012 -0400
@@ -33,6 +33,7 @@
     val p_con : Source.con Print.printer
     val p_exp : Source.exp Print.printer
     val p_decl : Source.decl Print.printer
+    val p_edecl : Source.edecl Print.printer
     val p_sgn_item : Source.sgn_item Print.printer
     val p_str : Source.str Print.printer
     val p_file : Source.file Print.printer