# HG changeset patch # User Adam Chlipala # Date 1236558861 14400 # Node ID 4a125bbc602d7eb3e392b9fe7dad767b0a563933 # Parent b98f547a6a458bf20c63dafab1d77c9f84cf58c8 Conversion of functions to CPS, to facilitate ServerCall diff -r b98f547a6a45 -r 4a125bbc602d src/compiler.sig --- a/src/compiler.sig Sun Mar 08 13:41:55 2009 -0400 +++ b/src/compiler.sig Sun Mar 08 20:34:21 2009 -0400 @@ -94,11 +94,13 @@ val toCore_untangle : (string, Core.file) transform val toShake1 : (string, Core.file) transform val toRpcify : (string, Core.file) transform + val toCore_untangle2 : (string, Core.file) transform + val toShake2 : (string, Core.file) transform val toTag : (string, Core.file) transform val toReduce : (string, Core.file) transform val toUnpoly : (string, Core.file) transform val toSpecialize : (string, Core.file) transform - val toShake2 : (string, Core.file) transform + val toShake3 : (string, Core.file) transform val toMonoize : (string, Mono.file) transform val toMono_opt1 : (string, Mono.file) transform val toUntangle : (string, Mono.file) transform diff -r b98f547a6a45 -r 4a125bbc602d src/compiler.sml --- a/src/compiler.sml Sun Mar 08 13:41:55 2009 -0400 +++ b/src/compiler.sml Sun Mar 08 20:34:21 2009 -0400 @@ -453,12 +453,15 @@ val toRpcify = transform rpcify "rpcify" o toShake1 +val toCore_untangle2 = transform core_untangle "core_untangle2" o toRpcify +val toShake2 = transform shake "shake2" o toCore_untangle2 + val tag = { func = Tag.tag, print = CorePrint.p_file CoreEnv.empty } -val toTag = transform tag "tag" o toRpcify +val toTag = transform tag "tag" o toCore_untangle2 val reduce = { func = Reduce.reduce, @@ -481,14 +484,14 @@ val toSpecialize = transform specialize "specialize" o toUnpoly -val toShake2 = transform shake "shake2" o toSpecialize +val toShake3 = transform shake "shake3" o toSpecialize val monoize = { func = Monoize.monoize CoreEnv.empty, print = MonoPrint.p_file MonoEnv.empty } -val toMonoize = transform monoize "monoize" o toShake2 +val toMonoize = transform monoize "monoize" o toShake3 val mono_opt = { func = MonoOpt.optimize, diff -r b98f547a6a45 -r 4a125bbc602d src/core_env.sig --- a/src/core_env.sig Sun Mar 08 13:41:55 2009 -0400 +++ b/src/core_env.sig Sun Mar 08 20:34:21 2009 -0400 @@ -65,5 +65,7 @@ val declBinds : env -> Core.decl -> env val patBinds : env -> Core.pat -> env + + val patBindsN : Core.pat -> int end diff -r b98f547a6a45 -r 4a125bbc602d src/core_env.sml --- a/src/core_env.sml Sun Mar 08 13:41:55 2009 -0400 +++ b/src/core_env.sml Sun Mar 08 20:34:21 2009 -0400 @@ -342,4 +342,13 @@ | PCon (_, _, _, SOME p) => patBinds env p | PRecord xps => foldl (fn ((_, p, _), env) => patBinds env p) env xps +fun patBindsN (p, loc) = + case p of + PWild => 0 + | PVar _ => 1 + | PPrim _ => 0 + | PCon (_, _, _, NONE) => 0 + | PCon (_, _, _, SOME p) => patBindsN p + | PRecord xps => foldl (fn ((_, p, _), count) => count + patBindsN p) 0 xps + end diff -r b98f547a6a45 -r 4a125bbc602d src/mono_util.sml --- a/src/mono_util.sml Sun Mar 08 13:41:55 2009 -0400 +++ b/src/mono_util.sml Sun Mar 08 20:34:21 2009 -0400 @@ -350,12 +350,14 @@ fn e' => (ESignalSource e', loc)) - | EServerCall (n, ek, t) => - S.bind2 (mfe ctx ek, - fn ek' => - S.map2 (mft t, - fn t' => - (EServerCall (n, ek', t'), loc))) + | EServerCall (s, ek, t) => + S.bind2 (mfe ctx s, + fn s' => + S.bind2 (mfe ctx ek, + fn ek' => + S.map2 (mft t, + fn t' => + (EServerCall (s', ek', t'), loc)))) in mfe end diff -r b98f547a6a45 -r 4a125bbc602d src/reduce_local.sig --- a/src/reduce_local.sig Sun Mar 08 13:41:55 2009 -0400 +++ b/src/reduce_local.sig Sun Mar 08 20:34:21 2009 -0400 @@ -30,5 +30,6 @@ signature REDUCE_LOCAL = sig val reduce : Core.file -> Core.file + val reduceExp : Core.exp -> Core.exp end diff -r b98f547a6a45 -r 4a125bbc602d src/reduce_local.sml --- a/src/reduce_local.sml Sun Mar 08 13:41:55 2009 -0400 +++ b/src/reduce_local.sml Sun Mar 08 20:34:21 2009 -0400 @@ -51,7 +51,7 @@ let fun find (n', env, nudge, lift) = case env of - [] => raise Fail "ReduceLocal.exp: ERel" + [] => (ERel (n + nudge), loc) | Lift lift' :: rest => find (n', rest, nudge + lift', lift + lift') | Unknown :: rest => if n' = 0 then @@ -156,4 +156,6 @@ map doDecl file end +val reduceExp = exp [] + end diff -r b98f547a6a45 -r 4a125bbc602d src/rpcify.sml --- a/src/rpcify.sml Sun Mar 08 13:41:55 2009 -0400 +++ b/src/rpcify.sml Sun Mar 08 20:34:21 2009 -0400 @@ -40,6 +40,12 @@ val compare = String.compare end) +fun multiLiftExpInExp n e = + if n = 0 then + e + else + multiLiftExpInExp (n - 1) (E.liftExpInExp 0 e) + val ssBasis = SS.addList (SS.empty, ["requestHeader", "query", @@ -54,10 +60,13 @@ type state = { cpsed : int IM.map, + cpsed_range : con IM.map, cps_decls : (string * int * con * exp * string) list, exported : IS.set, - export_decls : decl list + export_decls : decl list, + + maxName : int } fun frob file = @@ -95,21 +104,30 @@ val ssids = whichIds ssBasis val csids = whichIds csBasis - val serverSide = sideish (ssBasis, ssids) - val clientSide = sideish (csBasis, csids) + fun sideish' (basis, ids) extra = + sideish (basis, IM.foldli (fn (id, _, ids) => IS.add (ids, id)) ids extra) + + val serverSide = sideish' (ssBasis, ssids) + val clientSide = sideish' (csBasis, csids) val tfuncs = foldl (fn ((d, _), tfuncs) => let - fun doOne ((_, n, t, _, _), tfuncs) = + fun doOne ((x, n, t, e, _), tfuncs) = let - fun crawl (t, args) = - case #1 t of - CApp ((CFfi ("Basis", "transaction"), _), ran) => SOME (rev args, ran) - | TFun (arg, rest) => crawl (rest, arg :: args) + val loc = #2 e + + fun crawl (t, e, args) = + case (#1 t, #1 e) of + (CApp (_, ran), _) => + SOME (x, rev args, ran, e) + | (TFun (arg, rest), EAbs (x, _, _, e)) => + crawl (rest, e, (x, arg) :: args) + | (TFun (arg, rest), _) => + crawl (rest, (EApp (e, (ERel (length args), loc)), loc), ("x", arg) :: args) | _ => NONE in - case crawl (t, []) of + case crawl (t, e, []) of NONE => tfuncs | SOME sg => IM.insert (tfuncs, n, sg) end @@ -127,44 +145,242 @@ (EApp ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), loc), _), _), t1), _), t2), _), (EFfi ("Basis", "transaction_monad"), _)), _), - trans1), _), + (ECase (ed, pes, {disc, ...}), _)), _), trans2) => - (case (serverSide trans1, clientSide trans1, serverSide trans2, clientSide trans2) of - (true, false, false, true) => - let - fun getApp (e, args) = - case #1 e of - ENamed n => (n, args) - | EApp (e1, e2) => getApp (e1, e2 :: args) - | _ => (ErrorMsg.errorAt loc "Mixed client/server code doesn't use a named function for server part"; - (0, [])) + let + val e' = (EFfi ("Basis", "bind"), loc) + val e' = (ECApp (e', (CFfi ("Basis", "transaction"), loc)), loc) + val e' = (ECApp (e', t1), loc) + val e' = (ECApp (e', t2), loc) + val e' = (EApp (e', (EFfi ("Basis", "transaction_monad"), loc)), loc) - val (n, args) = getApp (trans1, []) + val (pes, st) = ListUtil.foldlMap (fn ((p, e), st) => + let + val e' = (EApp (e', e), loc) + val e' = (EApp (e', + multiLiftExpInExp (E.patBindsN p) + trans2), loc) + val (e', st) = doExp (e', st) + in + ((p, e'), st) + end) st pes + in + (ECase (ed, pes, {disc = disc, + result = (CApp ((CFfi ("Basis", "transaction"), loc), t2), loc)}), + st) + end - val (exported, export_decls) = - if IS.member (#exported st, n) then - (#exported st, #export_decls st) - else - (IS.add (#exported st, n), - (DExport (Rpc, n), loc) :: #export_decls st) + | EApp ( + (EApp + ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), loc), _), _), t1), _), t2), _), + (EFfi ("Basis", "transaction_monad"), _)), _), + (EServerCall (n, es, ke, t), _)), _), + trans2) => + let + val e' = (EFfi ("Basis", "bind"), loc) + val e' = (ECApp (e', (CFfi ("Basis", "transaction"), loc)), loc) + val e' = (ECApp (e', t), loc) + val e' = (ECApp (e', t2), loc) + val e' = (EApp (e', (EFfi ("Basis", "transaction_monad"), loc)), loc) + val e' = (EApp (e', (EApp (E.liftExpInExp 0 ke, (ERel 0, loc)), loc)), loc) + val e' = (EApp (e', E.liftExpInExp 0 trans2), loc) + val e' = (EAbs ("x", t, t2, e'), loc) + val e' = (EServerCall (n, es, e', t), loc) + val (e', st) = doExp (e', st) + in + (#1 e', st) + end - val st = {cpsed = #cpsed st, - cps_decls = #cps_decls st, + | EApp ( + (EApp + ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), loc), _), _), _), _), t3), _), + (EFfi ("Basis", "transaction_monad"), _)), _), + (EApp ((EApp + ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), _), _), _), t1), _), t2), _), + (EFfi ("Basis", "transaction_monad"), _)), _), + trans1), _), trans2), _)), _), + trans3) => + let + val e'' = (EFfi ("Basis", "bind"), loc) + val e'' = (ECApp (e'', (CFfi ("Basis", "transaction"), loc)), loc) + val e'' = (ECApp (e'', t2), loc) + val e'' = (ECApp (e'', t3), loc) + val e'' = (EApp (e'', (EFfi ("Basis", "transaction_monad"), loc)), loc) + val e'' = (EApp (e'', (EApp (E.liftExpInExp 0 trans2, (ERel 0, loc)), loc)), loc) + val e'' = (EApp (e'', E.liftExpInExp 0 trans3), loc) + val e'' = (EAbs ("x", t1, (CApp ((CFfi ("Basis", "transaction"), loc), t3), loc), e''), loc) - exported = exported, - export_decls = export_decls} + val e' = (EFfi ("Basis", "bind"), loc) + val e' = (ECApp (e', (CFfi ("Basis", "transaction"), loc)), loc) + val e' = (ECApp (e', t1), loc) + val e' = (ECApp (e', t3), loc) + val e' = (EApp (e', (EFfi ("Basis", "transaction_monad"), loc)), loc) + val e' = (EApp (e', trans1), loc) + val e' = (EApp (e', e''), loc) + val (e', st) = doExp (e', st) + in + (#1 e', st) + end - val ran = - case IM.find (tfuncs, n) of - NONE => (Print.prefaces "BAD" [("e", CorePrint.p_exp CoreEnv.empty (e, loc))]; - raise Fail "Rpcify: Undetected transaction function") - | SOME (_, ran) => ran - in - (EServerCall (n, args, trans2, ran), st) - end - | _ => (e, st)) + | EApp ( + (EApp + ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), _), _), _), _), _), _), _), + (EFfi ("Basis", "transaction_monad"), _)), _), + _), loc), + (EAbs (_, _, _, (EWrite _, _)), _)) => (e, st) + + | EApp ( + (EApp + ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), _), _), _), t1), _), t2), _), + (EFfi ("Basis", "transaction_monad"), _)), _), + trans1), loc), + trans2) => + let + (*val () = Print.prefaces "Default" + [("e", CorePrint.p_exp CoreEnv.empty (e, ErrorMsg.dummySpan))]*) + + fun getApp (e', args) = + case #1 e' of + ENamed n => (n, args) + | EApp (e1, e2) => getApp (e1, e2 :: args) + | _ => (ErrorMsg.errorAt loc "Mixed client/server code doesn't use a named function for server part"; + Print.prefaces "Bad" [("e", CorePrint.p_exp CoreEnv.empty (e, ErrorMsg.dummySpan))]; + (0, [])) + in + case (serverSide (#cpsed_range st) trans1, clientSide (#cpsed_range st) trans1, + serverSide (#cpsed_range st) trans2, clientSide (#cpsed_range st) trans2) of + (true, false, _, true) => + let + val (n, args) = getApp (trans1, []) + + val (exported, export_decls) = + if IS.member (#exported st, n) then + (#exported st, #export_decls st) + else + (IS.add (#exported st, n), + (DExport (Rpc, n), loc) :: #export_decls st) + + val st = {cpsed = #cpsed st, + cpsed_range = #cpsed_range st, + cps_decls = #cps_decls st, + + exported = exported, + export_decls = export_decls, + + maxName = #maxName st} + + val ran = + case IM.find (tfuncs, n) of + NONE => (Print.prefaces "BAD" [("e", CorePrint.p_exp CoreEnv.empty (e, loc))]; + raise Fail ("Rpcify: Undetected transaction function " ^ Int.toString n)) + | SOME (_, _, ran, _) => ran + + val e' = EServerCall (n, args, trans2, ran) + in + (EServerCall (n, args, trans2, ran), st) + end + | (true, true, _, _) => + let + val (n, args) = getApp (trans1, []) + + fun makeCall n' = + let + val e = (ENamed n', loc) + val e = (EApp (e, trans2), loc) + in + #1 (foldl (fn (arg, e) => (EApp (e, arg), loc)) e args) + end + in + case IM.find (#cpsed_range st, n) of + SOME kdom => + (case args of + [] => raise Fail "Rpcify: cps'd function lacks first argument" + | ke :: args => + let + val ke' = (EFfi ("Basis", "bind"), loc) + val ke' = (ECApp (ke', (CFfi ("Basis", "transaction"), loc)), loc) + val ke' = (ECApp (ke', kdom), loc) + val ke' = (ECApp (ke', t2), loc) + val ke' = (EApp (ke', (EFfi ("Basis", "transaction_monad"), loc)), loc) + val ke' = (EApp (ke', (EApp (E.liftExpInExp 0 ke, (ERel 0, loc)), loc)), loc) + val ke' = (EApp (ke', E.liftExpInExp 0 trans2), loc) + val ke' = (EAbs ("x", kdom, + (CApp ((CFfi ("Basis", "transaction"), loc), t2), loc), + ke'), loc) + + val e' = (ENamed n, loc) + val e' = (EApp (e', ke'), loc) + val e' = foldl (fn (arg, e') => (EApp (e', arg), loc)) e' args + val (e', st) = doExp (e', st) + in + (#1 e', st) + end) + | NONE => + case IM.find (#cpsed st, n) of + SOME n' => (makeCall n', st) + | NONE => + let + val (name, fargs, ran, e) = + case IM.find (tfuncs, n) of + NONE => (Print.prefaces "BAD" [("e", + CorePrint.p_exp CoreEnv.empty (e, loc))]; + raise Fail "Rpcify: Undetected transaction function [2]") + | SOME x => x + + val n' = #maxName st + + val st = {cpsed = IM.insert (#cpsed st, n, n'), + cpsed_range = IM.insert (#cpsed_range st, n', ran), + cps_decls = #cps_decls st, + exported = #exported st, + export_decls = #export_decls st, + maxName = n' + 1} + + val unit = (TRecord (CRecord ((KType, loc), []), loc), loc) + val body = (EFfi ("Basis", "bind"), loc) + val body = (ECApp (body, (CFfi ("Basis", "transaction"), loc)), loc) + val body = (ECApp (body, t1), loc) + val body = (ECApp (body, unit), loc) + val body = (EApp (body, (EFfi ("Basis", "transaction_monad"), loc)), loc) + val body = (EApp (body, e), loc) + val body = (EApp (body, (ERel (length args), loc)), loc) + val bt = (CApp ((CFfi ("Basis", "transaction"), loc), unit), loc) + val (body, bt) = foldr (fn ((x, t), (body, bt)) => + ((EAbs (x, t, bt, body), loc), + (TFun (t, bt), loc))) + (body, bt) fargs + val kt = (TFun (ran, (CApp ((CFfi ("Basis", "transaction"), loc), + unit), + loc)), loc) + val body = (EAbs ("k", kt, bt, body), loc) + val bt = (TFun (kt, bt), loc) + + val (body, st) = doExp (body, st) + + val vi = (name ^ "_cps", + n', + bt, + body, + "") + + val st = {cpsed = #cpsed st, + cpsed_range = #cpsed_range st, + cps_decls = vi :: #cps_decls st, + exported = #exported st, + export_decls = #export_decls st, + maxName = #maxName st} + in + (makeCall n', st) + end + end + | _ => (e, st) + end | _ => (e, st) + and doExp (e, st) = U.Exp.foldMap {kind = fn x => x, + con = fn x => x, + exp = exp} st (ReduceLocal.reduceExp e) + fun decl (d, st : state) = let val (d, st) = U.Decl.foldMap {kind = fn x => x, @@ -181,18 +397,24 @@ | (_, loc) => [d, (DValRec ds, loc)], #export_decls st), {cpsed = #cpsed st, + cpsed_range = #cpsed_range st, cps_decls = [], exported = #exported st, - export_decls = []}) + export_decls = [], + + maxName = #maxName st}) end val (file, _) = ListUtil.foldlMapConcat decl {cpsed = IM.empty, + cpsed_range = IM.empty, cps_decls = [], exported = IS.empty, - export_decls = []} + export_decls = [], + + maxName = U.File.maxName file + 1} file in file diff -r b98f547a6a45 -r 4a125bbc602d tests/rpcM.ur --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tests/rpcM.ur Sun Mar 08 20:34:21 2009 -0400 @@ -0,0 +1,33 @@ +datatype list t = Nil | Cons of t * list t + +sequence s + +fun main () : transaction page = + let + fun getIndices srcs = + case srcs of + Nil => return Nil + | Cons (src, srcs') => + i <- nextval s; + set src i; + ls <- getIndices srcs'; + return (Cons (i, ls)) + + fun show ls = + case ls of + Nil => + | Cons (x, ls') => {[x]}
{show ls'}
+ in + src1 <- source 0; + src2 <- source 1; + s <- source Nil; + return +