diff src/elab_util.sml @ 805:e2780d2f4afc

Mutual datatypes through Elaborate
author Adam Chlipala <adamc@hcoop.net>
date Sat, 16 May 2009 15:14:17 -0400
parents d20d6afc1206
children 7f871c03e3a1
line wrap: on
line diff
--- a/src/elab_util.sml	Sat May 16 13:10:52 2009 -0400
+++ b/src/elab_util.sml	Sat May 16 15:14:17 2009 -0400
@@ -568,15 +568,17 @@
                         S.map2 (con ctx c,
                              fn c' =>
                                 (SgiCon (x, n, k', c'), loc)))
-              | SgiDatatype (x, n, xs, xncs) =>
-                S.map2 (ListUtil.mapfold (fn (x, n, c) =>
-                                             case c of
-                                                 NONE => S.return2 (x, n, c)
-                                               | SOME c =>
-                                                 S.map2 (con ctx c,
-                                                      fn c' => (x, n, SOME c'))) xncs,
-                        fn xncs' =>
-                           (SgiDatatype (x, n, xs, xncs'), loc))
+              | SgiDatatype dts =>
+                S.map2 (ListUtil.mapfold (fn (x, n, xs, xncs) =>
+                                             S.map2 (ListUtil.mapfold (fn (x, n, c) =>
+                                                                          case c of
+                                                                              NONE => S.return2 (x, n, c)
+                                                                            | SOME c =>
+                                                                              S.map2 (con ctx c,
+                                                                                   fn c' => (x, n, SOME c'))) xncs,
+                                                  fn xncs' => (x, n, xs, xncs'))) dts,
+                        fn dts' =>
+                           (SgiDatatype dts', loc))
               | SgiDatatypeImp (x, n, m1, ms, s, xs, xncs) =>
                 S.map2 (ListUtil.mapfold (fn (x, n, c) =>
                                              case c of
@@ -627,8 +629,15 @@
                                                    bind (ctx, NamedC (x, n, k, NONE))
                                                  | SgiCon (x, n, k, c) =>
                                                    bind (ctx, NamedC (x, n, k, SOME c))
-                                                 | SgiDatatype (x, n, _, xncs) =>
-                                                   bind (ctx, NamedC (x, n, (KType, loc), NONE))
+                                                 | SgiDatatype dts =>
+                                                   foldl (fn ((x, n, ks, _), ctx) =>
+                                                             let
+                                                                 val k' = (KType, loc)
+                                                                 val k = foldl (fn (_, k) => (KArrow (k', k), loc))
+                                                                               k' ks
+                                                             in
+                                                                 bind (ctx, NamedC (x, n, k, NONE))
+                                                             end) ctx dts
                                                  | SgiDatatypeImp (x, n, m1, ms, s, _, _) =>
                                                    bind (ctx, NamedC (x, n, (KType, loc),
                                                                       SOME (CModProj (m1, ms, s), loc)))
@@ -753,29 +762,34 @@
                                               (case #1 d of
                                                    DCon (x, n, k, c) =>
                                                    bind (ctx, NamedC (x, n, k, SOME c))
-                                                 | DDatatype (x, n, xs, xncs) =>
+                                                 | DDatatype dts =>
                                                    let
-                                                       val ctx = bind (ctx, NamedC (x, n, (KType, loc), NONE))
+                                                       fun doOne ((x, n, xs, xncs), ctx) =
+                                                           let
+                                                               val ctx = bind (ctx, NamedC (x, n, (KType, loc), NONE))
+                                                           in
+                                                               foldl (fn ((x, _, co), ctx) =>
+                                                                         let
+                                                                             val t =
+                                                                                 case co of
+                                                                                     NONE => CNamed n
+                                                                                   | SOME t => TFun (t, (CNamed n, loc))
+                                                                                               
+                                                                             val k = (KType, loc)
+                                                                             val t = (t, loc)
+                                                                             val t = foldr (fn (x, t) =>
+                                                                                               (TCFun (Explicit,
+                                                                                                       x,
+                                                                                                       k,
+                                                                                                       t), loc))
+                                                                                           t xs
+                                                                         in
+                                                                             bind (ctx, NamedE (x, t))
+                                                                         end)
+                                                                     ctx xncs
+                                                           end
                                                    in
-                                                       foldl (fn ((x, _, co), ctx) =>
-                                                                 let
-                                                                     val t =
-                                                                         case co of
-                                                                             NONE => CNamed n
-                                                                           | SOME t => TFun (t, (CNamed n, loc))
-
-                                                                     val k = (KType, loc)
-                                                                     val t = (t, loc)
-                                                                     val t = foldr (fn (x, t) =>
-                                                                                       (TCFun (Explicit,
-                                                                                               x,
-                                                                                               k,
-                                                                                               t), loc))
-                                                                             t xs
-                                                                 in
-                                                                     bind (ctx, NamedE (x, t))
-                                                                 end)
-                                                       ctx xncs
+                                                       foldl doOne ctx dts
                                                    end
                                                  | DDatatypeImp (x, n, m, ms, x', _, _) =>
                                                    bind (ctx, NamedC (x, n, (KType, loc),
@@ -851,15 +865,18 @@
                             S.map2 (mfc ctx c,
                                     fn c' =>
                                        (DCon (x, n, k', c'), loc)))
-              | DDatatype (x, n, xs, xncs) =>
-                S.map2 (ListUtil.mapfold (fn (x, n, c) =>
-                                             case c of
-                                                 NONE => S.return2 (x, n, c)
-                                               | SOME c =>
-                                                 S.map2 (mfc ctx c,
-                                                      fn c' => (x, n, SOME c'))) xncs,
-                        fn xncs' =>
-                           (DDatatype (x, n, xs, xncs'), loc))
+              | DDatatype dts =>
+                S.map2 (ListUtil.mapfold (fn (x, n, xs, xncs) =>
+                                             S.map2 (ListUtil.mapfold (fn (x, n, c) =>
+                                                                          case c of
+                                                                              NONE => S.return2 (x, n, c)
+                                                                            | SOME c =>
+                                                                              S.map2 (mfc ctx c,
+                                                                                   fn c' => (x, n, SOME c'))) xncs,
+                                                     fn xncs' =>
+                                                        (x, n, xs, xncs'))) dts,
+                     fn dts' =>
+                        (DDatatype dts', loc))
               | DDatatypeImp (x, n, m1, ms, s, xs, xncs) =>
                 S.map2 (ListUtil.mapfold (fn (x, n, c) =>
                                              case c of
@@ -1059,9 +1076,10 @@
 and maxNameDecl (d, _) =
     case d of
         DCon (_, n, _, _) => n
-      | DDatatype (_, n, _, ns) =>
+      | DDatatype dts =>
+        foldl (fn ((_, n, _, ns), max) =>
                   foldl (fn ((_, n', _), m) => Int.max (n', m))
-                        n ns
+                        (Int.max (n, max)) ns) 0 dts
       | DDatatypeImp (_, n1, n2, _, _, _, ns) =>
         foldl (fn ((_, n', _), m) => Int.max (n', m))
               (Int.max (n1, n2)) ns
@@ -1101,9 +1119,10 @@
     case sgi of
         SgiConAbs (_, n, _) => n
       | SgiCon (_, n, _, _) => n
-      | SgiDatatype (_, n, _, ns) =>
-        foldl (fn ((_, n', _), m) => Int.max (n', m))
-              n ns
+      | SgiDatatype dts =>
+        foldl (fn ((_, n, _, ns), max) =>
+                  foldl (fn ((_, n', _), m) => Int.max (n', m))
+                        (Int.max (n, max)) ns) 0 dts
       | SgiDatatypeImp (_, n1, n2, _, _, _, ns) =>
         foldl (fn ((_, n', _), m) => Int.max (n', m))
               (Int.max (n1, n2)) ns