view src/unpoly.sml @ 2201:1091227f535a

Unnest properly in presence of kind polymorphism
author Adam Chlipala <adam@chlipala.net>
date Sun, 20 Dec 2015 13:41:35 -0500
parents 5b5c0b552f59
children
line wrap: on
line source
(* Copyright (c) 2008-2010, Adam Chlipala
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * - Redistributions of source code must retain the above copyright notice,
 *   this list of conditions and the following disclaimer.
 * - Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 * - The names of contributors may not be used to endorse or promote products
 *   derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *)

(* Simplify a Core program by repeating polymorphic function definitions *)

structure Unpoly :> UNPOLY = struct

open Core

structure E = CoreEnv
structure U = CoreUtil

structure IS = IntBinarySet
structure IM = IntBinaryMap


(** The actual specialization *)

val liftConInCon = E.liftConInCon
val subConInCon = E.subConInCon

val liftConInExp = E.liftConInExp
val subConInExp = E.subConInExp

val isOpen = U.Con.existsB {kind = fn _ => false,
                            con = fn (n, c) =>
                                     case c of
                                         CRel n' => n' >= n
                                       | _ => false,
                            bind = fn (n, b) =>
                                      case b of
                                          U.Con.RelC _ => n + 1
                                        | _ => n} 0

fun unpolyNamed (xn, rep) =
    U.Exp.map {kind = fn k => k,
               con = fn c => c,
               exp = fn e =>
                        case e of
                            ECApp (e', _) =>
                            let
                                fun isTheOne (e, _) =
                                    case e of
                                        ENamed xn' => xn' = xn
                                      | ECApp (e, _) => isTheOne e
                                      | _ => false
                            in
                                if isTheOne e' then
                                    rep
                                else
                                    e
                            end
                          | _ => e}

structure M = BinaryMapFn(struct
                          type ord_key = con list
                          val compare = Order.joinL U.Con.compare
                          end)

type func = {
     kinds : kind list,
     defs : (string * int * con * exp * string) list,
     replacements : int M.map
}

type state = {
     funcs : func IM.map,
     decls : decl list,
     nextName : int
}

fun kind (k, st) = (k, st)

fun con (c, st) = (c, st)

fun exp (e, st : state) =
    case e of
        ECApp _ =>
        let
            fun unravel (e, cargs) =
                case e of
                    ECApp ((e, _), c) => unravel (e, c :: cargs)
                  | ENamed n => SOME (n, rev cargs)
                  | _ => NONE
        in
            case unravel (e, []) of
                NONE => (e, st)
              | SOME (n, cargs) =>
                if List.exists isOpen cargs then
                    (e, st)
                else
                    case IM.find (#funcs st, n) of
                        NONE => (e, st)
                      | SOME {kinds = ks, defs = vis, replacements} =>
                        let
                            val cargs = map ReduceLocal.reduceCon cargs
                        in
                            case M.find (replacements, cargs) of
                                SOME n => (ENamed n, st)
                              | NONE =>
                                let
                                    val old_vis = vis
                                    val (vis, (thisName, nextName)) =
                                        ListUtil.foldlMap
                                            (fn ((x, n', t, e, s), (thisName, nextName)) =>
                                                ((x, nextName, n', t, e, s),
                                                 (if n' = n then nextName else thisName,
                                                  nextName + 1)))
                                            (0, #nextName st) vis

                                    fun specialize (x, n, n_old, t, e, s) =
                                        let
                                            fun trim (t, e, cargs) =
                                                case (t, e, cargs) of
                                                    ((TCFun (_, _, t), _),
                                                     (ECAbs (_, _, e), _),
                                                     carg :: cargs) =>
                                                    let
                                                        val t = subConInCon (length cargs, carg) t
                                                        val e = subConInExp (length cargs, carg) e
                                                    in
                                                        trim (t, e, cargs)
                                                    end
                                                  | (_, _, []) => SOME (t, e)
                                                  | _ => NONE
                                        in
                                            (*Print.prefaces "specialize"
                                                             [("n", Print.PD.string (Int.toString n)),
                                                              ("nold", Print.PD.string (Int.toString n_old)),
                                                              ("t", CorePrint.p_con CoreEnv.empty t),
                                                              ("e", CorePrint.p_exp CoreEnv.empty e),
                                                              ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
                                            Option.map (fn (t, e) => (x, n, n_old, t, e, s))
                                                       (trim (t, e, cargs))
                                        end

                                    val vis = List.map specialize vis
                                in
                                    if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
                                        (e, st)
                                    else
                                        let
                                            val vis = List.mapPartial (fn x => x) vis

                                            val vis = map (fn (x, n, n_old, t, e, s) =>
                                                              (x ^ "_unpoly", n, n_old, t, e, s)) vis
                                            val vis' = map (fn (x, n, _, t, e, s) =>
                                                               (x, n, t, e, s)) vis

                                            val funcs = foldl (fn ((_, n, n_old, _, _, _), funcs) =>
                                                                  let
                                                                      val replacements = case IM.find (funcs, n_old) of
                                                                                             NONE => M.empty
                                                                                           | SOME {replacements = r,
                                                                                                   ...} => r
                                                                  in
                                                                      IM.insert (funcs, n_old,
                                                                                 {kinds = ks,
                                                                                  defs = old_vis,
                                                                                  replacements = M.insert (replacements,
                                                                                                           cargs,
                                                                                                           n)})
                                                                  end) (#funcs st) vis

                                            val ks' = List.drop (ks, length cargs)

                                            val st = {funcs = foldl (fn (vi, funcs) =>
                                                                        IM.insert (funcs, #2 vi,
                                                                                   {kinds = ks',
                                                                                    defs = vis',
                                                                                    replacements = M.empty}))
                                                                    funcs vis',
                                                      decls = #decls st,
                                                      nextName = nextName}

                                            val (vis', st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
                                                                                   let
                                                                                       val (e, st) = polyExp (e, st)
                                                                                   in
                                                                                       ((x, n, t, e, s), st)
                                                                                   end)
                                                                               st vis'
                                        in
                                            (ENamed thisName,
                                             {funcs = #funcs st,
                                              decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
                                              nextName = #nextName st})
                                        end
                                end
                        end
        end
      | _ => (e, st)

and polyExp (x, st) = U.Exp.foldMap {kind = kind, con = con, exp = exp} st x

fun decl (d, st : state) =
    let
        fun unravel (e, cargs) =
            case e of
                (ECAbs (_, k, e), _) =>
                unravel (e, k :: cargs)
              | _ => rev cargs
    in
        case d of
            DVal (vi as (x, n, t, e, s)) =>
            let
                val cargs = unravel (e, [])

                val ns = IS.singleton n
            in
                (d, {funcs = IM.insert (#funcs st, n, {kinds = cargs,
                                                       defs = [vi],
                                                       replacements = M.empty}),
                     decls = #decls st,
                     nextName = #nextName st})
            end
          | DValRec (vis as ((x, n, t, e, s) :: rest)) =>
            let
                val cargs = unravel (e, [])

                fun unravel (e, cargs) =
                    case (e, cargs) of
                        ((ECAbs (_, k, e), _), k' :: cargs) =>
                        U.Kind.compare (k, k') = EQUAL
                        andalso unravel (e, cargs)
                      | (_, []) => true
                      | _ => false
                             
                fun deAbs (e, cargs) =
                    case (e, cargs) of
                        ((ECAbs (_, _, e), _), _ :: cargs) => deAbs (e, cargs)
                      | (_, []) => e
                      | _ => raise Fail "Unpoly: deAbs"

            in
                if List.exists (fn vi => not (unravel (#4 vi, cargs))) rest then
                    (d, st)
                else
                    let
                        val ns = IS.addList (IS.empty, map #2 vis)
                        val nargs = length cargs

                        (** Verifying lack of polymorphic recursion *)

                        fun kind _ = false
                        fun con _ = false

                        fun exp (cn, e) =
                            case e of
                                orig as ECApp (e, c) =>
                                let
                                    fun isIrregular (e, pos) =
                                        case #1 e of
                                            ENamed n =>
                                            IS.member (ns, n)
                                            andalso
                                            (case #1 c of
                                                 CRel i => i <> nargs - pos + cn
                                               | _ => true)
                                          | ECApp (e, _) => isIrregular (e, pos + 1)
                                          | _ => false
                                in
                                    isIrregular (e, 1)
                                end
                              | _ => false

                        fun bind (cn, b) =
                            case b of
                                U.Exp.RelC _ => cn+1
                              | _ => cn

                        val irregular = U.Exp.existsB {kind = kind, con = con, exp = exp, bind = bind} 0
                    in
                        if List.exists (fn x => irregular (deAbs (#4 x, cargs))) vis then
                            (d, st)
                        else
                            (d, {funcs = foldl (fn (vi, funcs) =>
                                                   IM.insert (funcs, #2 vi, {kinds = cargs,
                                                                             defs = vis,
                                                                             replacements = M.empty}))
                                               (#funcs st) vis,
                                 decls = #decls st,
                                 nextName = #nextName st})
                    end
            end

          | _ => (d, st)
    end

val polyDecl = U.Decl.foldMap {kind = kind, con = con, exp = exp, decl = decl}

fun unpoly file =
    let
        fun doDecl (d : decl, st : state) =
            let
                val (d, st) = polyDecl st d
            in
                (rev (d :: #decls st),
                 {funcs = #funcs st,
                  decls = [],
                  nextName = #nextName st})
            end

        val (ds, _) = ListUtil.foldlMapConcat doDecl
                      {funcs = IM.empty,
                       decls = [],
                       nextName = U.File.maxName file + 1} file
    in
        ds
    end

end