diff src/specialize.sml @ 193:8a70e2919e86

Specialization of single-parameter datatypes
author Adam Chlipala <adamc@hcoop.net>
date Fri, 08 Aug 2008 17:55:51 -0400
parents
children df5fd8f6913a
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/specialize.sml	Fri Aug 08 17:55:51 2008 -0400
@@ -0,0 +1,276 @@
+(* Copyright (c) 2008, Adam Chlipala
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * - Redistributions of source code must retain the above copyright notice,
+ *   this list of conditions and the following disclaimer.
+ * - Redistributions in binary form must reproduce the above copyright notice,
+ *   this list of conditions and the following disclaimer in the documentation
+ *   and/or other materials provided with the distribution.
+ * - The names of contributors may not be used to endorse or promote products
+ *   derived from this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ *)
+
+(* Simplify a Core program algebraically *)
+
+structure Specialize :> SPECIALIZE = struct
+
+open Core
+
+structure E = CoreEnv
+structure U = CoreUtil
+
+val liftConInCon = E.liftConInCon
+val subConInCon = E.subConInCon
+
+structure CK = struct
+type ord_key = con list
+val compare = Order.joinL U.Con.compare
+end
+
+structure CM = BinaryMapFn(CK)
+structure IM = IntBinaryMap
+
+type datatyp' = {
+     name : int,
+     constructors : int IM.map
+}
+
+type datatyp = {
+     name : string,
+     params : int,
+     constructors : (string * int * con option) list,
+     specializations : datatyp' CM.map
+}
+
+type state = {
+     count : int,
+     datatypes : datatyp IM.map,
+     constructors : int IM.map,
+     decls : decl list
+}
+
+fun kind (k, st) = (k, st)
+
+val isOpen = U.Con.exists {kind = fn _ => false,
+                           con = fn c =>
+                                    case c of
+                                        CRel _ => true
+                                      | _ => false}
+
+fun considerSpecialization (st : state, n, args, dt : datatyp) =
+    case CM.find (#specializations dt, args) of
+        SOME dt' => (#name dt', #constructors dt', st)
+      | NONE =>
+        let
+            val n' = #count st
+
+            fun sub t = ListUtil.foldli (fn (i, arg, t) =>
+                                            subConInCon (i, arg) t) t args
+
+            val (cons, (count, cmap)) =
+                ListUtil.foldlMap (fn ((x, n, to), (count, cmap)) =>
+                                      let
+                                          val to = Option.map sub to
+                                      in
+                                          ((x, count, to),
+                                           (count + 1,
+                                            IM.insert (cmap, n, count)))
+                                      end) (n' + 1, IM.empty) (#constructors dt)
+
+            val st = {count = count,
+                      datatypes = IM.insert (#datatypes st, n,
+                                             {name = #name dt,
+                                              params = #params dt,
+                                              constructors = #constructors dt,
+                                              specializations = CM.insert (#specializations dt,
+                                                                           args,
+                                                                           {name = n',
+                                                                            constructors = cmap})}),
+                      constructors = #constructors st,
+                      decls = #decls st}
+
+            val (cons, st) = ListUtil.foldlMap (fn ((x, n, NONE), st) => ((x, n, NONE), st)
+                                                 | ((x, n, SOME t), st) =>
+                                                   let
+                                                       val (t, st) = specCon st t
+                                                   in
+                                                       ((x, n, SOME t), st)
+                                                   end) st cons
+
+            val d = (DDatatype (#name dt ^ "_s",
+                                n',
+                                [],
+                                cons), #2 (List.hd args))
+        in
+            (n', cmap, {count = #count st,
+                        datatypes = #datatypes st,
+                        constructors = #constructors st,
+                        decls = d :: #decls st})
+        end
+
+and con (c, st : state) =
+    let
+        fun findApp (c, args) =
+            case c of
+                CApp ((c', _), arg) => findApp (c', arg :: args)
+              | CNamed n => SOME (n, args)
+              | _ => NONE
+    in
+        case findApp (c, []) of
+            SOME (n, args as (_ :: _)) =>
+            if List.exists isOpen args then
+                (c, st)
+            else
+                (case IM.find (#datatypes st, n) of
+                     NONE => (c, st)
+                   | SOME dt =>
+                     if length args <> #params dt then
+                         (c, st)
+                     else
+                         let
+                             val (n, _, st) = considerSpecialization (st, n, args, dt)
+                         in
+                             (CNamed n, st)
+                         end)
+          | _ => (c, st)
+    end
+
+and specCon st = U.Con.foldMap {kind = kind, con = con} st
+
+fun pat (p, st) =
+    case #1 p of
+        PWild => (p, st)
+      | PVar _ => (p, st)
+      | PPrim _ => (p, st)
+      | PCon (dk, PConVar pn, args as (_ :: _), po) =>
+        let
+            val (po, st) =
+                case po of
+                    NONE => (NONE, st)
+                  | SOME p =>
+                    let
+                        val (p, st) = pat (p, st)
+                    in
+                        (SOME p, st)
+                    end
+            val p = (PCon (dk, PConVar pn, args, po), #2 p)
+        in
+            if List.exists isOpen args then
+                (p, st)
+            else
+                case IM.find (#constructors st, pn) of
+                    NONE => (p, st)
+                  | SOME n =>
+                    case IM.find (#datatypes st, n) of
+                        NONE => (p, st)
+                      | SOME dt =>
+                        let
+                            val (n, cmap, st) = considerSpecialization (st, n, args, dt)
+                        in
+                            case IM.find (cmap, pn) of
+                                NONE => raise Fail "Specialize: Missing datatype constructor (pat)"
+                              | SOME pn' => ((PCon (dk, PConVar pn', [], po), #2 p), st)
+                        end
+        end
+      | PCon _ => (p, st)
+      | PRecord xps =>
+        let
+            val (xps, st) = ListUtil.foldlMap (fn ((x, p, t), st) =>
+                                                  let
+                                                      val (p, st) = pat (p, st)
+                                                  in
+                                                      ((x, p, t), st)
+                                                  end)
+                            st xps
+        in
+            ((PRecord xps, #2 p), st)
+        end
+
+fun exp (e, st) =
+    case e of
+        ECon (dk, PConVar pn, args as (_ :: _), eo) =>
+        if List.exists isOpen args then
+            (e, st)
+        else
+            (case IM.find (#constructors st, pn) of
+                 NONE => (e, st)
+               | SOME n =>
+                 case IM.find (#datatypes st, n) of
+                     NONE => (e, st)
+                   | SOME dt =>
+                     let
+                         val (n, cmap, st) = considerSpecialization (st, n, args, dt)
+                     in
+                         case IM.find (cmap, pn) of
+                             NONE => raise Fail "Specialize: Missing datatype constructor"
+                           | SOME pn' => (ECon (dk, PConVar pn', [], eo), st)
+                     end)
+      | ECase (e, pes, r) =>
+        let
+            val (pes, st) = ListUtil.foldlMap (fn ((p, e), st) =>
+                                                  let
+                                                      val (p, st) = pat (p, st)
+                                                  in
+                                                      ((p, e), st)
+                                                  end) st pes
+        in
+            (ECase (e, pes, r), st)
+        end
+      | _ => (e, st)
+
+fun decl (d, st) = (d, st)
+
+val specDecl = U.Decl.foldMap {kind = kind, con = con, exp = exp, decl = decl}
+
+fun specialize file =
+    let
+        fun doDecl (all as (d, _), st : state) =
+            case d of
+                DDatatype (x, n, xs, xnts) =>
+                ([all], {count = #count st,
+                         datatypes = IM.insert (#datatypes st, n,
+                                                {name = x,
+                                                 params = length xs,
+                                                 constructors = xnts,
+                                                 specializations = CM.empty}),
+                         constructors = foldl (fn ((_, n', _), constructors) =>
+                                                  IM.insert (constructors, n', n))
+                                              (#constructors st) xnts,
+                         decls = []})
+              | _ =>
+                let
+                    val (d, st) = specDecl st all
+                in
+                    (rev (d :: #decls st),
+                     {count = #count st,
+                      datatypes = #datatypes st,
+                      constructors = #constructors st,
+                      decls = []})
+                end
+
+        val (ds, _) = ListUtil.foldlMapConcat doDecl
+                      {count = U.File.maxName file + 1,
+                       datatypes = IM.empty,
+                       constructors = IM.empty,
+                       decls = []} file
+    in
+        ds
+    end
+
+
+end