# HG changeset patch # User Adam Chlipala # Date 1253222269 14400 # Node ID d80734855790421e6519d966878ac30ab3f98541 # Parent 01a4d936395a1b556f89bfab1dd0a5059e86a167 Don't try to check if functions are already tail-recursive diff -r 01a4d936395a -r d80734855790 src/rpcify.sml --- a/src/rpcify.sml Thu Sep 17 17:11:23 2009 -0400 +++ b/src/rpcify.sml Thu Sep 17 17:17:49 2009 -0400 @@ -170,144 +170,74 @@ DValRec vis => if List.exists (fn (_, _, _, e, _) => makesServerCall e) vis then let - val all = foldl (fn ((_, n, _, _, _), all) => IS.add (all, n)) IS.empty vis + val rpc = foldl (fn ((_, n, _, _, _), rpc) => + IS.add (rpc, n)) (#rpc st) vis - val usesRec = U.Exp.exists {kind = fn _ => false, - con = fn _ => false, - exp = fn ENamed n => IS.member (all, n) - | _ => false} + val (cpsed, vis') = + foldl (fn (vi as (x, n, t, e, s), (cpsed, vis')) => + let + fun getArgs (t, acc) = + case #1 t of + TFun (dom, ran) => + getArgs (ran, dom :: acc) + | _ => (rev acc, t) + val (ts, ran) = getArgs (t, []) + val ran = case #1 ran of + CApp (_, ran) => ran + | _ => raise Fail "Rpcify: Tail function not transactional" + val len = length ts - val noRec = not o usesRec + val loc = #2 e + val args = ListUtil.mapi + (fn (i, _) => + (ERel (len - i - 1), loc)) + ts + val k = (EFfi ("Basis", "return"), loc) + val trans = (CFfi ("Basis", "transaction"), loc) + val k = (ECApp (k, trans), loc) + val k = (ECApp (k, ran), loc) + val k = (EApp (k, (EFfi ("Basis", "transaction_monad"), + loc)), loc) + val re = (ETailCall (n, args, k, ran, ran), loc) + val (re, _) = foldr (fn (dom, (re, ran)) => + ((EAbs ("x", dom, ran, re), + loc), + (TFun (dom, ran), loc))) + (re, ran) ts - fun tailOnly (e, _) = - case e of - EPrim _ => true - | ERel _ => true - | ENamed _ => true - | ECon (_, _, _, SOME e) => noRec e - | ECon _ => true - | EFfi _ => true - | EFfiApp (_, _, es) => List.all noRec es - | EApp (e1, e2) => noRec e2 andalso tailOnly e1 - | EAbs (_, _, _, e) => noRec e - | ECApp (e1, _) => tailOnly e1 - | ECAbs (_, _, e) => noRec e - - | EKAbs (_, e) => noRec e - | EKApp (e1, _) => tailOnly e1 - - | ERecord xes => List.all (noRec o #2) xes - | EField (e1, _, _) => noRec e1 - | EConcat (e1, _, e2, _) => noRec e1 andalso noRec e2 - | ECut (e1, _, _) => noRec e1 - | ECutMulti (e1, _, _) => noRec e1 - - | ECase (e1, pes, _) => noRec e1 andalso List.all (tailOnly o #2) pes - - | EWrite e1 => noRec e1 - - | EClosure (_, es) => List.all noRec es - - | ELet (_, _, e1, e2) => noRec e1 andalso tailOnly e2 - - | EServerCall (_, es, (EAbs (_, _, _, e), _), _, _) => - List.all noRec es andalso tailOnly e - | EServerCall (_, es, e, _, _) => List.all noRec es andalso noRec e - - | ETailCall _ => raise Fail "Rpcify: ETailCall too early" - - fun tailOnlyF e = - case #1 e of - EAbs (_, _, _, e) => tailOnlyF e - | ECAbs (_, _, e) => tailOnlyF e - | EKAbs (_, e) => tailOnlyF e - | _ => tailOnly e - - val nonTail = foldl (fn ((_, n, _, e, _), nonTail) => - if tailOnlyF e then - nonTail - else - IS.add (nonTail, n)) IS.empty vis + val be = multiLiftExpInExp (len + 1) e + val be = ListUtil.foldli + (fn (i, _, be) => + (EApp (be, (ERel (len - i), loc)), loc)) + be ts + val ne = (EFfi ("Basis", "bind"), loc) + val ne = (ECApp (ne, trans), loc) + val ne = (ECApp (ne, ran), loc) + val unit = (TRecord (CRecord ((KType, loc), []), + loc), loc) + val ne = (ECApp (ne, unit), loc) + val ne = (EApp (ne, (EFfi ("Basis", "transaction_monad"), + loc)), loc) + val ne = (EApp (ne, be), loc) + val ne = (EApp (ne, (ERel 0, loc)), loc) + val tunit = (CApp (trans, unit), loc) + val kt = (TFun (ran, tunit), loc) + val ne = (EAbs ("k", kt, tunit, ne), loc) + val (ne, res) = foldr (fn (dom, (ne, ran)) => + ((EAbs ("x", dom, ran, ne), loc), + (TFun (dom, ran), loc))) + (ne, (TFun (kt, tunit), loc)) ts + in + (IM.insert (cpsed, n, #1 re), + (x, n, res, ne, s) :: vis') + end) + (#cpsed st, []) vis in - if IS.isEmpty nonTail then - (d, {exported = #exported st, - export_decls = #export_decls st, - cpsed = #cpsed st, - rpc = IS.union (#rpc st, all)}) - else - let - val rpc = foldl (fn ((_, n, _, _, _), rpc) => - IS.add (rpc, n)) (#rpc st) vis - - val (cpsed, vis') = - foldl (fn (vi as (x, n, t, e, s), (cpsed, vis')) => - if IS.member (nonTail, n) then - let - fun getArgs (t, acc) = - case #1 t of - TFun (dom, ran) => - getArgs (ran, dom :: acc) - | _ => (rev acc, t) - val (ts, ran) = getArgs (t, []) - val ran = case #1 ran of - CApp (_, ran) => ran - | _ => raise Fail "Rpcify: Tail function not transactional" - val len = length ts - - val loc = #2 e - val args = ListUtil.mapi - (fn (i, _) => - (ERel (len - i - 1), loc)) - ts - val k = (EFfi ("Basis", "return"), loc) - val trans = (CFfi ("Basis", "transaction"), loc) - val k = (ECApp (k, trans), loc) - val k = (ECApp (k, ran), loc) - val k = (EApp (k, (EFfi ("Basis", "transaction_monad"), - loc)), loc) - val re = (ETailCall (n, args, k, ran, ran), loc) - val (re, _) = foldr (fn (dom, (re, ran)) => - ((EAbs ("x", dom, ran, re), - loc), - (TFun (dom, ran), loc))) - (re, ran) ts - - val be = multiLiftExpInExp (len + 1) e - val be = ListUtil.foldli - (fn (i, _, be) => - (EApp (be, (ERel (len - i), loc)), loc)) - be ts - val ne = (EFfi ("Basis", "bind"), loc) - val ne = (ECApp (ne, trans), loc) - val ne = (ECApp (ne, ran), loc) - val unit = (TRecord (CRecord ((KType, loc), []), - loc), loc) - val ne = (ECApp (ne, unit), loc) - val ne = (EApp (ne, (EFfi ("Basis", "transaction_monad"), - loc)), loc) - val ne = (EApp (ne, be), loc) - val ne = (EApp (ne, (ERel 0, loc)), loc) - val tunit = (CApp (trans, unit), loc) - val kt = (TFun (ran, tunit), loc) - val ne = (EAbs ("k", kt, tunit, ne), loc) - val (ne, res) = foldr (fn (dom, (ne, ran)) => - ((EAbs ("x", dom, ran, ne), loc), - (TFun (dom, ran), loc))) - (ne, (TFun (kt, tunit), loc)) ts - in - (IM.insert (cpsed, n, #1 re), - (x, n, res, ne, s) :: vis') - end - else - (cpsed, vi :: vis')) - (#cpsed st, []) vis - in - ((DValRec (rev vis'), ErrorMsg.dummySpan), - {exported = #exported st, - export_decls = #export_decls st, - cpsed = cpsed, - rpc = rpc}) - end + ((DValRec (rev vis'), ErrorMsg.dummySpan), + {exported = #exported st, + export_decls = #export_decls st, + cpsed = cpsed, + rpc = rpc}) end else (d, st) diff -r 01a4d936395a -r d80734855790 tests/tail.ur --- a/tests/tail.ur Thu Sep 17 17:11:23 2009 -0400 +++ b/tests/tail.ur Thu Sep 17 17:17:49 2009 -0400 @@ -8,8 +8,17 @@ n2 <- addEm (n - 1); return (n1 + n2) +fun addEm' n acc = + if n = 0 then + return acc + else + n1 <- rpc (one ()); + addEm' (n - 1) (n1 + acc) + fun main () = s <- source 0; - return + s' <- source 0; + return +