diff src/core_util.sml @ 193:8a70e2919e86

Specialization of single-parameter datatypes
author Adam Chlipala <adamc@hcoop.net>
date Fri, 08 Aug 2008 17:55:51 -0400
parents 9bbf4d383381
children 890a61991263
line wrap: on
line diff
--- a/src/core_util.sml	Fri Aug 08 10:59:06 2008 -0400
+++ b/src/core_util.sml	Fri Aug 08 17:55:51 2008 -0400
@@ -39,6 +39,28 @@
 
 structure Kind = struct
 
+open Order
+
+fun compare ((k1, _), (k2, _)) =
+    case (k1, k2) of
+        (KType, KType) => EQUAL
+      | (KType, _) => LESS
+      | (_, KType) => GREATER
+
+      | (KArrow (d1, r1), KArrow (d2, r2)) => join (compare (d1, d2), fn () => compare (r1, r2))
+      | (KArrow _, _) => LESS
+      | (_, KArrow _) => GREATER
+
+      | (KName, KName) => EQUAL
+      | (KName, _) => LESS
+      | (_, KName) => GREATER
+
+      | (KRecord k1, KRecord k2) => compare (k1, k2)
+      | (KRecord _, _) => LESS
+      | (_, KRecord _) => GREATER
+
+      | (KUnit, KUnit) => EQUAL
+
 fun mapfold f =
     let
         fun mfk k acc =
@@ -85,6 +107,76 @@
 
 structure Con = struct
 
+open Order
+
+fun compare ((c1, _), (c2, _)) =
+    case (c1, c2) of
+        (TFun (d1, r1), TFun (d2, r2)) => join (compare (d1, d2), fn () => compare (r1, r2))
+      | (TFun _, _) => LESS
+      | (_, TFun _) => GREATER
+
+      | (TCFun (x1, k1, r1), TCFun (x2, k2, r2)) =>
+        join (String.compare (x1, x2),
+           fn () => join (Kind.compare (k1, k2),
+                          fn () => compare (r1, r2)))
+      | (TCFun _, _) => LESS
+      | (_, TCFun _) => GREATER
+
+      | (TRecord c1, TRecord c2) => compare (c1, c2)
+      | (TRecord _, _) => LESS
+      | (_, TRecord _) => GREATER
+
+      | (CRel n1, CRel n2) => Int.compare (n1, n2)
+      | (CRel _, _) => LESS
+      | (_, CRel _) => GREATER
+
+      | (CNamed n1, CNamed n2) => Int.compare (n1, n2)
+      | (CNamed _, _) => LESS
+      | (_, CNamed _) => GREATER
+
+      | (CFfi (m1, s1), CFfi (m2, s2)) => join (String.compare (m1, m2),
+                                                fn () => String.compare (s1, s2))
+      | (CFfi _, _) => LESS
+      | (_, CFfi _) => GREATER
+
+      | (CApp (f1, x1), CApp (f2, x2)) => join (compare (f1, f2),
+                                                fn () => compare (x1, x2))
+      | (CApp _, _) => LESS
+      | (_, CApp _) => GREATER
+
+      | (CAbs (x1, k1, b1), CAbs (x2, k2, b2)) =>
+        join (String.compare (x1, x2),
+              fn () => join (Kind.compare (k1, k2),
+                             fn () => compare (b1, b2)))
+      | (CAbs _, _) => LESS
+      | (_, CAbs _) => GREATER
+
+      | (CName s1, CName s2) => String.compare (s1, s2)
+      | (CName _, _) => LESS
+      | (_, CName _) => GREATER
+
+      | (CRecord (k1, xvs1), CRecord (k2, xvs2)) =>
+        join (Kind.compare (k1, k2),
+              fn () => joinL (fn ((x1, v1), (x2, v2)) =>
+                                 join (compare (x1, x2),
+                                       fn () => compare (v1, v2))) (xvs1, xvs2))
+      | (CRecord _, _) => LESS
+      | (_, CRecord _) => GREATER
+
+      | (CConcat (f1, s1), CConcat (f2, s2)) =>
+        join (compare (f1, f2),
+              fn () => compare (s1, s2))
+      | (CConcat _, _) => LESS
+      | (_, CConcat _) => GREATER
+
+      | (CFold (d1, r1), CFold (d2, r2)) =>
+        join (Kind.compare (d1, r2),
+              fn () => Kind.compare (r1, r2))
+      | (CFold _, _) => LESS
+      | (_, CFold _) => GREATER
+
+      | (CUnit, CUnit) => EQUAL
+
 datatype binder =
          Rel of string * kind
        | Named of string * int * kind * con option
@@ -201,6 +293,12 @@
         S.Return _ => true
       | S.Continue _ => false
 
+fun foldMap {kind, con} s c =
+    case mapfold {kind = fn k => fn s => S.Continue (kind (k, s)),
+                  con = fn c => fn s => S.Continue (con (c, s))} c s of
+        S.Continue v => v
+      | S.Return _ => raise Fail "CoreUtil.Con.foldMap: Impossible"
+
 end
 
 structure Exp = struct
@@ -317,8 +415,22 @@
                 S.bind2 (mfe ctx e,
                          fn e' =>
                             S.bind2 (ListUtil.mapfold (fn (p, e) =>
-                                                         S.map2 (mfe ctx e,
-                                                              fn e' => (p, e'))) pes,
+                                                          let
+                                                              fun pb ((p, _), ctx) =
+                                                                  case p of
+                                                                      PWild => ctx
+                                                                    | PVar (x, t) => bind (ctx, RelE (x, t))
+                                                                    | PPrim _ => ctx
+                                                                    | PCon (_, _, _, NONE) => ctx
+                                                                    | PCon (_, _, _, SOME p) => pb (p, ctx)
+                                                                    | PRecord xps => foldl (fn ((_, p, _), ctx) =>
+                                                                                               pb (p, ctx)) ctx xps
+                                                          in
+                                                              S.bind2 (mfp ctx p,
+                                                                       fn p' =>
+                                                                          S.map2 (mfe (pb (p', ctx)) e,
+                                                                               fn e' => (p', e')))
+                                                          end) pes,
                                     fn pes' =>
                                        S.bind2 (mfc ctx disc,
                                                 fn disc' =>
@@ -335,6 +447,45 @@
                 S.map2 (ListUtil.mapfold (mfe ctx) es,
                      fn es' =>
                         (EClosure (n, es'), loc))
+
+        and mfp ctx (pAll as (p, loc)) =
+            case p of
+                PWild => S.return2 pAll
+              | PVar (x, t) =>
+                S.map2 (mfc ctx t,
+                        fn t' =>
+                           (PVar (x, t'), loc))
+              | PPrim _ => S.return2 pAll
+              | PCon (dk, pc, args, po) =>
+                S.bind2 (mfpc ctx pc,
+                         fn pc' =>
+                            S.bind2 (ListUtil.mapfold (mfc ctx) args,
+                                     fn args' =>
+                                        S.map2 ((case po of
+                                                     NONE => S.return2 NONE
+                                                   | SOME p => S.map2 (mfp ctx p, SOME)),
+                                                fn po' =>
+                                                   (PCon (dk, pc', args', po'), loc))))
+              | PRecord xps =>
+                S.map2 (ListUtil.mapfold (fn (x, p, c) =>
+                                              S.bind2 (mfp ctx p,
+                                                       fn p' =>
+                                                          S.map2 (mfc ctx c,
+                                                                  fn c' =>
+                                                                     (x, p', c')))) xps,
+                         fn xps' =>
+                            (PRecord xps', loc))
+
+        and mfpc ctx pc =
+            case pc of
+                PConVar _ => S.return2 pc
+              | PConFfi {mod = m, datatyp, params, con, arg, kind} =>
+                S.map2 ((case arg of
+                             NONE => S.return2 NONE
+                           | SOME c => S.map2 (mfc ctx c, SOME)),
+                        fn arg' =>
+                           PConFfi {mod = m, datatyp = datatyp, params = params,
+                                    con = con, arg = arg', kind = kind})
     in
         mfe
     end