changeset 2201:1091227f535a

Unnest properly in presence of kind polymorphism
author Adam Chlipala <adam@chlipala.net>
date Sun, 20 Dec 2015 13:41:35 -0500
parents fc1c89627178
children 6fb9232ade99
files src/elab_util.sig src/elab_util.sml src/unnest.sml
diffstat 3 files changed, 168 insertions(+), 79 deletions(-) [+]
line wrap: on
line diff
--- a/src/elab_util.sig	Tue Dec 15 19:58:52 2015 -0500
+++ b/src/elab_util.sig	Sun Dec 20 13:41:35 2015 -0500
@@ -41,6 +41,9 @@
     val mapB : {kind : 'context -> Elab.kind' -> Elab.kind',
                 bind : 'context * string -> 'context}
                -> 'context -> (Elab.kind -> Elab.kind)
+    val foldB : {kind : 'context * Elab.kind' * 'state -> 'state,
+                 bind : 'context * string -> 'context}
+                -> 'context -> 'state -> Elab.kind -> 'state
 end
 
 structure Con : sig
--- a/src/elab_util.sml	Tue Dec 15 19:58:52 2015 -0500
+++ b/src/elab_util.sml	Sun Dec 20 13:41:35 2015 -0500
@@ -116,6 +116,12 @@
         S.Return _ => true
       | S.Continue _ => false
 
+fun foldB {kind, bind} ctx st k =
+    case mapfoldB {kind = fn ctx => fn k => fn st => S.Continue (k, kind (ctx, k, st)),
+                   bind = bind} ctx k st of
+        S.Continue (_, st) => st
+      | S.Return _ => raise Fail "ElabUtil.Kind.foldB: Impossible"
+
 end
 
 val mliftConInCon = ref (fn n : int => fn c : con => (raise Fail "You didn't set ElabUtil.mliftConInCon!") : con)
--- a/src/unnest.sml	Tue Dec 15 19:58:52 2015 -0500
+++ b/src/unnest.sml	Sun Dec 20 13:41:35 2015 -0500
@@ -65,44 +65,71 @@
                         | ((xn, rep), U.Exp.RelC _) => (xn, E.liftConInExp 0 rep)
                         | (ctx, _) => ctx}
 
-val fvsCon = U.Con.foldB {kind = fn (_, _, st) => st,
-                          con = fn (cb, c, cvs) =>
+val fvsKind = U.Kind.foldB {kind = fn (kb, k, kvs) =>
+                                      case k of
+                                          KRel n =>
+                                          if n >= kb then
+                                              IS.add (kvs, n - kb)
+                                          else
+                                              kvs
+                                        | _ => kvs,
+                          bind = fn (kb, b) => kb + 1}
+                         0 IS.empty
+
+val fvsCon = U.Con.foldB {kind = fn ((kb, _), k, st as (kvs, cvs)) =>
+                                    case k of
+                                        KRel n =>
+                                        if n >= kb then
+                                            (IS.add (kvs, n - kb), cvs)
+                                        else
+                                            st
+                                      | _ => st,
+                          con = fn ((_, cb), c, st as (kvs, cvs)) =>
                                    case c of
                                        CRel n =>
                                        if n >= cb then
-                                           IS.add (cvs, n - cb)
+                                           (kvs, IS.add (cvs, n - cb))
                                        else
-                                           cvs
-                                     | _ => cvs,
-                          bind = fn (cb, b) =>
+                                           st
+                                     | _ => st,
+                          bind = fn (ctx as (kb, cb), b) =>
                                     case b of
-                                        U.Con.RelC _ => cb + 1
-                                      | _ => cb}
-                         0 IS.empty
+                                        U.Con.RelK _ => (kb + 1, cb + 1)
+                                      | U.Con.RelC _ => (kb, cb + 1)
+                                      | _ => ctx}
+                         (0, 0) (IS.empty, IS.empty)
 
-fun fvsExp nr = U.Exp.foldB {kind = fn (_, _, st) => st,
-                             con = fn ((cb, eb), c, st as (cvs, evs)) =>
+fun fvsExp nr = U.Exp.foldB {kind = fn ((kb, _, _), k, st as (kvs, cvs, evs)) =>
+                                       case k of
+                                           KRel n =>
+                                           if n >= kb then
+                                               (IS.add (kvs, n - kb), cvs, evs)
+                                           else
+                                               st
+                                         | _ => st,
+                             con = fn ((kb, cb, eb), c, st as (kvs, cvs, evs)) =>
                                       case c of
                                           CRel n =>
                                           if n >= cb then
-                                              (IS.add (cvs, n - cb), evs)
+                                              (kvs, IS.add (cvs, n - cb), evs)
                                           else
                                               st
                                         | _ => st,
-                             exp = fn ((cb, eb), e, st as (cvs, evs)) =>
+                             exp = fn ((kb, cb, eb), e, st as (kvs, cvs, evs)) =>
                                       case e of
                                           ERel n =>
                                           if n >= eb then
-                                              (cvs, IS.add (evs, n - eb))
+                                              (kvs, cvs, IS.add (evs, n - eb))
                                           else
                                               st
                                         | _ => st,
-                             bind = fn (ctx as (cb, eb), b) =>
+                             bind = fn (ctx as (kb, cb, eb), b) =>
                                        case b of
-                                           U.Exp.RelC _ => (cb + 1, eb)
-                                         | U.Exp.RelE _ => (cb, eb + 1)
+                                           U.Exp.RelK _ => (kb + 1, cb, eb)
+                                         | U.Exp.RelC _ => (kb, cb + 1, eb)
+                                         | U.Exp.RelE _ => (kb, cb, eb + 1)
                                          | _ => ctx}
-                            (0, nr) (IS.empty, IS.empty)
+                            (0, 0, nr) (IS.empty, IS.empty, IS.empty)
 
 fun positionOf (x : int) ls =
     let
@@ -123,46 +150,62 @@
                                      ^ ")")
     end
 
-fun squishCon cfv =
-    U.Con.mapB {kind = fn _ => fn k => k,
-                con = fn cb => fn c =>
-                                  case c of
-                                      CRel n =>
-                                      if n >= cb then
-                                          CRel (positionOf (n - cb) cfv + cb)
-                                      else
-                                          c
-                                    | _ => c,
-                bind = fn (cb, b) =>
+fun squishCon (kfv, cfv) =
+    U.Con.mapB {kind = fn (kb, _) => fn k =>
+                                        case k of
+                                            KRel n =>
+                                            if n >= kb then
+                                                KRel (positionOf (n - kb) kfv + kb)
+                                            else
+                                                k
+                                          | _ => k,
+                con = fn (_, cb) => fn c =>
+                                       case c of
+                                           CRel n =>
+                                           if n >= cb then
+                                               CRel (positionOf (n - cb) cfv + cb)
+                                           else
+                                               c
+                                         | _ => c,
+                bind = fn (ctx as (kb, cb), b) =>
                           case b of
-                              U.Con.RelC _ => cb + 1
-                            | _ => cb}
-               0
+                              U.Con.RelK _ => (kb + 1, cb)
+                            | U.Con.RelC _ => (kb, cb + 1)
+                            | _ => ctx}
+               (0, 0)
 
-fun squishExp (nr, cfv, efv) =
-    U.Exp.mapB {kind = fn _ => fn k => k,
-                con = fn (cb, eb) => fn c =>
-                                        case c of
-                                            CRel n =>
-                                            if n >= cb then
-                                                CRel (positionOf (n - cb) cfv + cb)
-                                            else
-                                                c
-                                          | _ => c,
-                exp = fn (cb, eb) => fn e =>
-                                        case e of
-                                            ERel n =>
-                                            if n >= eb then
-                                                ERel (positionOf (n - eb) efv + eb - nr)
-                                            else
-                                                e
-                                          | _ => e,
-                bind = fn (ctx as (cb, eb), b) =>
+fun squishExp (nr, kfv, cfv, efv) =
+    U.Exp.mapB {kind = fn (kb, _, _) => fn k =>
+                                           case k of
+                                               KRel n =>
+                                               if n >= kb then
+                                                   KRel (positionOf (n - kb) kfv + kb)
+                                               else
+                                                   k
+                                             | _ => k,
+                con = fn (_, cb, _) => fn c =>
+                                          case c of
+                                              CRel n =>
+                                              if n >= cb then
+                                                  CRel (positionOf (n - cb) cfv + cb)
+                                              else
+                                                  c
+                                            | _ => c,
+                exp = fn (_, _, eb) => fn e =>
+                                          case e of
+                                              ERel n =>
+                                              if n >= eb then
+                                                  ERel (positionOf (n - eb) efv + eb - nr)
+                                              else
+                                                  e
+                                            | _ => e,
+                bind = fn (ctx as (kb, cb, eb), b) =>
                           case b of
-                              U.Exp.RelC _ => (cb + 1, eb)
-                            | U.Exp.RelE _ => (cb, eb + 1)
+                              U.Exp.RelK _ => (kb + 1, cb, eb)
+                            | U.Exp.RelC _ => (kb, cb + 1, eb)
+                            | U.Exp.RelE _ => (kb, cb, eb + 1)
                             | _ => ctx}
-               (0, nr)
+               (0, 0, nr)
 
 type state = {
      maxName : int,
@@ -173,7 +216,7 @@
 
 val basis = ref 0
 
-fun exp ((ks, ts), e as old, st : state) =
+fun exp ((ns, ks, ts), e as old, st : state) =
     case e of
         ELet (eds, e, t) =>
         let
@@ -249,21 +292,23 @@
                             val vis = map (fn (x, t, e) =>
                                               (x, t, doSubst' (e, subsLocal))) vis
 
-                            val (cfv, efv) = foldl (fn ((_, t, e), (cfv, efv)) =>
-                                                       let
-                                                           val (cfv', efv') = fvsExp nr e
-                                                           (*val () = Print.prefaces "fvsExp"
-                                                                    [("e", ElabPrint.p_exp E.empty e),
-                                                                     ("cfv", Print.PD.string
-                                                                                 (Int.toString (IS.numItems cfv'))),
-                                                                     ("efv", Print.PD.string
-                                                                                 (Int.toString (IS.numItems efv')))]*)
-                                                           val cfv'' = fvsCon t
-                                                       in
-                                                           (IS.union (cfv, IS.union (cfv', cfv'')),
-                                                            IS.union (efv, efv'))
-                                                       end)
-                                                   (IS.empty, IS.empty) vis
+                            val (kfv, cfv, efv) =
+                                foldl (fn ((_, t, e), (kfv, cfv, efv)) =>
+                                          let
+                                              val (kfv', cfv', efv') = fvsExp nr e
+                                              (*val () = Print.prefaces "fvsExp"
+                                                         [("e", ElabPrint.p_exp E.empty e),
+                                                          ("cfv", Print.PD.string
+                                                                      (Int.toString (IS.numItems cfv'))),
+                                                          ("efv", Print.PD.string
+                                                                      (Int.toString (IS.numItems efv')))]*)
+                                              val (kfv'', cfv'') = fvsCon t
+                                          in
+                                              (IS.union (kfv, IS.union (kfv', kfv'')),
+                                               IS.union (cfv, IS.union (cfv', cfv'')),
+                                               IS.union (efv, efv'))
+                                          end)
+                                      (IS.empty, IS.empty, IS.empty) vis
 
                             (*val () = Print.prefaces "Letto" [("e", ElabPrint.p_exp E.empty (old, ErrorMsg.dummySpan))]*)
                             (*val () = print ("A: " ^ Int.toString (length ts) ^ ", " ^ Int.toString (length ks) ^ "\n")*)
@@ -272,12 +317,30 @@
                                                                    ("t", ElabPrint.p_con E.empty t)]) ts
                             val () = IS.app (fn n => print ("Free: " ^ Int.toString n ^ "\n")) efv*)
 
+                            val kfv = IS.foldl (fn (x, kfv) =>
+                                                   let
+                                                       (*val () = print (Int.toString x ^ "\n")*)
+                                                       val (_, k) = List.nth (ks, x)
+                                                   in
+                                                       IS.union (kfv, fvsKind k)
+                                                   end)
+                                               kfv cfv
+
+                            val kfv = IS.foldl (fn (x, kfv) =>
+                                                   let
+                                                       (*val () = print (Int.toString x ^ "\n")*)
+                                                       val (_, t) = List.nth (ts, x)
+                                                   in
+                                                       IS.union (kfv, #1 (fvsCon t))
+                                                   end)
+                                               kfv efv
+
                             val cfv = IS.foldl (fn (x, cfv) =>
                                                    let
                                                        (*val () = print (Int.toString x ^ "\n")*)
                                                        val (_, t) = List.nth (ts, x)
                                                    in
-                                                       IS.union (cfv, fvsCon t)
+                                                       IS.union (cfv, #2 (fvsCon t))
                                                    end)
                                                cfv efv
                             (*val () = print "B\n"*)
@@ -299,6 +362,10 @@
                                                               val e = (ENamed n, loc)
 
                                                               val e = IS.foldr (fn (x, e) =>
+                                                                                   (EKApp (e, (KRel x, loc)), loc))
+                                                                               e kfv
+
+                                                              val e = IS.foldr (fn (x, e) =>
                                                                                    (ECApp (e, (CRel x, loc)), loc))
                                                                                e cfv
 
@@ -311,6 +378,7 @@
                                                           end)
                                                       vis
 
+                            val kfv = IS.listItems kfv
                             val cfv = IS.listItems cfv
                             val efv = IS.listItems efv
 
@@ -324,17 +392,17 @@
 
                                                   (*val () = Print.prefaces "squishCon"
                                                                           [("t", ElabPrint.p_con E.empty t)]*)
-                                                  val t = squishCon cfv t
+                                                  val t = squishCon (kfv, cfv) t
                                                   (*val () = Print.prefaces "squishExp"
                                                                           [("e", ElabPrint.p_exp E.empty e)]*)
-                                                  val e = squishExp (nr, cfv, efv) e
+                                                  val e = squishExp (nr, kfv, cfv, efv) e
 
                                                   (*val () = print ("Avail: " ^ Int.toString (length ts) ^ "\n")*)
                                                   val (e, t) = foldl (fn (ex, (e, t)) =>
                                                                          let
                                                                              (*val () = print (Int.toString ex ^ "\n")*)
                                                                              val (name, t') = List.nth (ts, ex)
-                                                                             val t' = squishCon cfv t'
+                                                                             val t' = squishCon (kfv, cfv) t'
                                                                          in
                                                                              ((EAbs (name,
                                                                                      t',
@@ -360,6 +428,17 @@
                                                                                       t), loc))
                                                                          end)
                                                                      (e, t) cfv
+
+                                                  val (e, t) = foldl (fn (kx, (e, t)) =>
+                                                                         let
+                                                                             val name = List.nth (ns, kx)
+                                                                         in
+                                                                             ((EKAbs (name,
+                                                                                      e), loc),
+                                                                              (TKFun (name,
+                                                                                      t), loc))
+                                                                         end)
+                                                                     (e, t) kfv
                                               in
                                                   (*Print.prefaces "Have a vi"
                                                                  [("x", Print.PD.string x),
@@ -391,11 +470,12 @@
 
 fun default (ctx, d, st) = (d, st)
 
-fun bind ((ks, ts), b) =
+fun bind ((ns, ks, ts), b) =
     case b of
-        U.Decl.RelC p => (p :: ks, map (fn (name, t) => (name, E.liftConInCon 0 t)) ts)
-      | U.Decl.RelE p => (ks, p :: ts)
-      | _ => (ks, ts)                        
+        U.Decl.RelK x => (x :: ns, ks, ts)
+      | U.Decl.RelC p => (ns, p :: ks, map (fn (name, t) => (name, E.liftConInCon 0 t)) ts)
+      | U.Decl.RelE p => (ns, ks, p :: ts)
+      | _ => (ns, ks, ts)                        
 
 val unnestDecl = U.Decl.foldMapB {kind = kind,
                                   con = default,
@@ -405,7 +485,7 @@
                                   str = default,
                                   decl = default,
                                   bind = bind}
-                                 ([], [])
+                                 ([], [], [])
 
 fun unnest file =
     let