(* Copyright (c) 2008-2010, Adam Chlipala
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * - Redistributions of source code must retain the above copyright notice,
 *   this list of conditions and the following disclaimer.
 * - Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 * - The names of contributors may not be used to endorse or promote products
 *   derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
*)

(* Simplify a Core program by repeating polymorphic function definitions *)

structure Unpoly :> UNPOLY = struct

open Core

structure E = CoreEnv
structure U = CoreUtil

structure IS = IntBinarySet
structure IM = IntBinaryMap


(** The actual specialization *)

val liftConInCon = E.liftConInCon
val subConInCon = E.subConInCon

val liftConInExp = E.liftConInExp
val subConInExp = E.subConInExp

(* [isOpen c] is true when [c] mentions a constructor variable ([CRel]) whose
 * index is not covered by a [RelC] binder inside [c] itself.  The binder depth
 * starts at 0 and grows by one under each [RelC] binder. *)
val isOpen = U.Con.existsB {kind = fn _ => false,
                            con = fn (n, c) =>
                                     case c of
                                         CRel n' => n' >= n
                                       | _ => false,
                            bind = fn (n, b) =>
                                      case b of
                                          U.Con.RelC _ => n + 1
                                        | _ => n} 0

(* [unpolyNamed (xn, rep) e] maps over [e], replacing every [ECApp] node whose
 * chain of constructor applications bottoms out in [ENamed xn] with [rep].
 * NOTE(review): not referenced elsewhere in this file. *)
fun unpolyNamed (xn, rep) =
    U.Exp.map {kind = fn k => k,
               con = fn c => c,
               exp = fn e =>
                        case e of
                            ECApp (e', _) =>
                            let
                                fun isTheOne (e, _) =
                                    case e of
                                        ENamed xn' => xn' = xn
                                      | ECApp (e, _) => isTheOne e
                                      | _ => false
                            in
                                if isTheOne e' then
                                    rep
                                else
                                    e
                            end
                          | _ => e}

(* Finite maps keyed by constructor-argument lists, compared element-wise with
 * [U.Con.compare] (combined via [Order.joinL]). *)
structure M = BinaryMapFn(struct
                          type ord_key = con list
                          val compare = Order.joinL U.Con.compare
                          end)

(* What is known about one specializable function:
 *   kinds        - kinds of its leading constructor abstractions, outermost first
 *   defs         - the whole definition group it belongs to
 *   replacements - constructor-argument list |-> name of an existing copy *)
type func = {
     kinds : kind list,
     defs : (string * int * con * exp * string) list,
     replacements : int M.map
}

(* Pass state:
 *   funcs    - specializable functions, keyed by name
 *   decls    - generated declarations, most recently pushed first
 *   nextName - next unused name *)
type state = {
     funcs : func IM.map,
     decls : decl list,
     nextName : int
}

fun kind (k, st) = (k, st)

fun con (c, st) = (c, st)

(* Expression callback for [polyExp].  For [f [c1] ... [cn]] where [f] is an
 * [ENamed] function recorded in [#funcs st] and no [ci] is open, returns
 * [ENamed f'] for a monomorphic copy [f'].  An existing copy is reused when
 * [replacements] already maps the (ReduceLocal-reduced) argument list;
 * otherwise every definition of [f]'s group is copied with the arguments
 * substituted, the copies get fresh names and an _unpoly suffix, they are
 * recorded as the replacements of their originals, their bodies are processed
 * recursively with [polyExp], and the new [DValRec] is pushed onto
 * [#decls st].  Anything that cannot be specialized is returned unchanged. *)
fun exp (e, st : state) =
    case e of
        ECApp _ =>
        let
            (* Peel constructor applications off [e].  Returns the head name
             * and the arguments (last-applied argument first, because of the
             * final [rev]), or NONE if the head is not [ENamed]. *)
            fun unravel (e, cargs) =
                case e of
                    ECApp ((e, _), c) => unravel (e, c :: cargs)
                  | ENamed n => SOME (n, rev cargs)
                  | _ => NONE
        in
            case unravel (e, []) of
                NONE => (e, st)
              | SOME (n, cargs) =>
                if List.exists isOpen cargs then
                    (* Arguments mention unbound constructor variables. *)
                    (e, st)
                else
                    case IM.find (#funcs st, n) of
                        NONE => (e, st)
                      | SOME {kinds = ks, defs = vis, replacements} =>
                        let
                            val cargs = map ReduceLocal.reduceCon cargs
                        in
                            case M.find (replacements, cargs) of
                                SOME n => (ENamed n, st)
                              | NONE =>
                                let
                                    val old_vis = vis

                                    (* Give each group member a fresh name, keeping
                                     * its old name as [n'], and remember which
                                     * fresh name stands for [n] itself. *)
                                    val (vis, (thisName, nextName)) =
                                        ListUtil.foldlMap
                                            (fn ((x, n', t, e, s), (thisName, nextName)) =>
                                                ((x, nextName, n', t, e, s),
                                                 (if n' = n then nextName else thisName,
                                                  nextName + 1)))
                                            (0, #nextName st) vis

                                    (* Instantiate the leading [TCFun]/[ECAbs]
                                     * pairs of one definition with [cargs];
                                     * NONE if it has too few of them.
                                     * NOTE(review): for |cargs| > 1 this
                                     * substitutes at index [length cargs] into a
                                     * term that still contains the remaining
                                     * abstractions, with [cargs] in last-applied-
                                     * first order.  That looks consistent only
                                     * for a single argument.  If [U.Exp.foldMap]
                                     * visits inner [ECApp] nodes first (TODO
                                     * confirm), applications are specialized one
                                     * argument at a time, and this path may never
                                     * see more than one.  Worth double-checking. *)
                                    fun specialize (x, n, n_old, t, e, s) =
                                        let
                                            fun trim (t, e, cargs) =
                                                case (t, e, cargs) of
                                                    ((TCFun (_, _, t), _),
                                                     (ECAbs (_, _, e), _),
                                                     carg :: cargs) =>
                                                    let
                                                        val t = subConInCon (length cargs, carg) t
                                                        val e = subConInExp (length cargs, carg) e
                                                    in
                                                        trim (t, e, cargs)
                                                    end
                                                  | (_, _, []) => SOME (t, e)
                                                  | _ => NONE
                                        in
                                            (*Print.prefaces "specialize"
                                                             [("n", Print.PD.string (Int.toString n)),
                                                              ("nold", Print.PD.string (Int.toString n_old)),
                                                              ("t", CorePrint.p_con CoreEnv.empty t),
                                                              ("e", CorePrint.p_exp CoreEnv.empty e),
                                                              ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
                                            Option.map (fn (t, e) => (x, n, n_old, t, e, s))
                                                       (trim (t, e, cargs))
                                        end

                                    val vis = List.map specialize vis
                                in
                                    if List.exists (not o Option.isSome) vis
                                       orelse length cargs > length ks then
                                        (e, st)
                                    else
                                        let
                                            val vis = List.mapPartial (fn x => x) vis

                                            val vis = map (fn (x, n, n_old, t, e, s) =>
                                                              (x ^ "_unpoly", n, n_old, t, e, s)) vis
                                            val vis' = map (fn (x, n, _, t, e, s) =>
                                                               (x, n, t, e, s)) vis

                                            (* Record each copy as the replacement
                                             * of its original for [cargs]. *)
                                            val funcs =
                                                foldl (fn ((_, n, n_old, _, _, _), funcs) =>
                                                          let
                                                              val replacements =
                                                                  case IM.find (funcs, n_old) of
                                                                      NONE => M.empty
                                                                    | SOME {replacements = r, ...} => r
                                                          in
                                                              IM.insert (funcs, n_old,
                                                                         {kinds = ks,
                                                                          defs = old_vis,
                                                                          replacements = M.insert (replacements,
                                                                                                   cargs,
                                                                                                   n)})
                                                          end) (#funcs st) vis

                                            (* The copies keep the abstractions
                                             * that were not applied here. *)
                                            val ks' = List.drop (ks, length cargs)

                                            val st = {funcs = foldl (fn (vi, funcs) =>
                                                                        IM.insert (funcs, #2 vi,
                                                                                   {kinds = ks',
                                                                                    defs = vis',
                                                                                    replacements = M.empty}))
                                                                    funcs vis',
                                                      decls = #decls st,
                                                      nextName = nextName}

                                            (* Specialize inside the new bodies too. *)
                                            val (vis', st) =
                                                ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
                                                                      let
                                                                          val (e, st) = polyExp (e, st)
                                                                      in
                                                                          ((x, n, t, e, s), st)
                                                                      end)
                                                                  st vis'
                                        in
                                            (ENamed thisName,
                                             {funcs = #funcs st,
                                              decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
                                              nextName = #nextName st})
                                        end
                                end
                        end
        end
      | _ => (e, st)

(* Fold the callbacks above over an expression, threading the state. *)
and polyExp (x, st) = U.Exp.foldMap {kind = kind, con = con, exp = exp} st x

(* Declaration callback: records functions that may later be specialized.
 * A [DVal] is recorded under its name, with the kinds of its leading
 * constructor abstractions.  A [DValRec] group is recorded (every member,
 * sharing the whole group as [defs]) only if:
 *   - each other member starts with abstractions over the same kinds as the
 *     first member, and
 *   - no member body applies a group member to constructor arguments that fail
 *     the position check in [irregularApp] (polymorphic recursion).
 * Other declarations are passed through unchanged. *)
fun decl (d, st : state) =
    let
        (* Kinds of the [ECAbs] abstractions at the head of [e], outermost first. *)
        fun leadingKinds (e, ks) =
            case e of
                (ECAbs (_, k, e), _) => leadingKinds (e, k :: ks)
              | _ => rev ks
    in
        case d of
            DVal (vi as (_, n, _, e, _)) =>
            (d, {funcs = IM.insert (#funcs st, n, {kinds = leadingKinds (e, []),
                                                   defs = [vi],
                                                   replacements = M.empty}),
                 decls = #decls st,
                 nextName = #nextName st})
          | DValRec (vis as ((_, _, _, e, _) :: rest)) =>
            let
                val ks = leadingKinds (e, [])

                (* Does [e] start with abstractions over exactly the kinds in
                 * [ks], in order (possibly followed by more)? *)
                fun hasKindPrefix (e, ks) =
                    case (e, ks) of
                        ((ECAbs (_, k, e), _), k' :: ks) =>
                        U.Kind.compare (k, k') = EQUAL
                        andalso hasKindPrefix (e, ks)
                      | (_, []) => true
                      | _ => false

                (* Strip one [ECAbs] per element of [ks]. *)
                fun deAbs (e, ks) =
                    case (e, ks) of
                        ((ECAbs (_, _, e), _), _ :: ks) => deAbs (e, ks)
                      | (_, []) => e
                      | _ => raise Fail "Unpoly: deAbs"
            in
                if List.exists (fn vi => not (hasKindPrefix (#4 vi, ks))) rest then
                    (d, st)
                else
                    let
                        val ns = IS.addList (IS.empty, map #2 vis)
                        val nargs = length ks

                        (** Verifying lack of polymorphic recursion *)

                        fun noKind _ = false
                        fun noCon _ = false

                        (* For an [ECApp] node whose head (through further
                         * [ECApp]s) is a group member, the node's constructor
                         * argument must be exactly [CRel (nargs - pos + cn)],
                         * where [pos] counts the [ECApp] nodes from this one down
                         * to the head.  That is presumably the member's own
                         * abstraction variable for that position.  Anything
                         * else is irregular. *)
                        fun irregularApp (cn, e) =
                            case e of
                                ECApp (e, c) =>
                                let
                                    fun isIrregular (e, pos) =
                                        case #1 e of
                                            ENamed n =>
                                            IS.member (ns, n)
                                            andalso
                                            (case #1 c of
                                                 CRel i => i <> nargs - pos + cn
                                               | _ => true)
                                          | ECApp (e, _) => isIrregular (e, pos + 1)
                                          | _ => false
                                in
                                    isIrregular (e, 1)
                                end
                              | _ => false

                        (* Track how many constructor binders we are under. *)
                        fun bindCon (cn, b) =
                            case b of
                                U.Exp.RelC _ => cn + 1
                              | _ => cn

                        val irregular = U.Exp.existsB {kind = noKind, con = noCon,
                                                       exp = irregularApp, bind = bindCon} 0
                    in
                        if List.exists (fn x => irregular (deAbs (#4 x, ks))) vis then
                            (d, st)
                        else
                            (d, {funcs = foldl (fn (vi, funcs) =>
                                                   IM.insert (funcs, #2 vi, {kinds = ks,
                                                                             defs = vis,
                                                                             replacements = M.empty}))
                                               (#funcs st) vis,
                                 decls = #decls st,
                                 nextName = #nextName st})
                    end
            end

          | _ => (d, st)
    end

(* Fold all callbacks over one declaration, threading the state. *)
val polyDecl = U.Decl.foldMap {kind = kind, con = con, exp = exp, decl = decl}

(* Main pass.  Declarations are processed in order.  For each one, the
 * declarations it generated (in the order they were pushed onto [decls]) are
 * emitted before it, and [decls] is reset.  The per-declaration lists are
 * combined by [ListUtil.foldlMapConcat] (presumably by concatenation).  Fresh
 * names start just above [U.File.maxName file]. *)
fun unpoly file =
    let
        fun doDecl (d : decl, st : state) =
            let
                val (d, st) = polyDecl st d
            in
                (rev (d :: #decls st),
                 {funcs = #funcs st,
                  decls = [],
                  nextName = #nextName st})
            end

        val (ds, _) = ListUtil.foldlMapConcat doDecl
                                              {funcs = IM.empty,
                                               decls = [],
                                               nextName = U.File.maxName file + 1}
                                              file
    in
        ds
    end

end