comparison src/rpcify.sml @ 954:2a50da66ffd8

Basic tail recursion introduction seems to be working
author Adam Chlipala <adamc@hcoop.net>
date Thu, 17 Sep 2009 16:35:11 -0400
parents ed06e25c70ef
children 01a4d936395a
comparison
equal deleted inserted replaced
953:301530da2062 954:2a50da66ffd8
30 open Core 30 open Core
31 31
32 structure U = CoreUtil 32 structure U = CoreUtil
33 structure E = CoreEnv 33 structure E = CoreEnv
34 34
35 fun multiLiftExpInExp n e =
36 if n = 0 then
37 e
38 else
39 multiLiftExpInExp (n - 1) (E.liftExpInExp 0 e)
40
35 structure IS = IntBinarySet 41 structure IS = IntBinarySet
36 structure IM = IntBinaryMap 42 structure IM = IntBinaryMap
37 43
38 structure SS = BinarySetFn(struct 44 structure SS = BinarySetFn(struct
39 type ord_key = string 45 type ord_key = string
40 val compare = String.compare 46 val compare = String.compare
41 end) 47 end)
42 48
43 type state = { 49 type state = {
44 exported : IS.set, 50 exported : IS.set,
45 export_decls : decl list 51 export_decls : decl list,
52
53 cpsed : exp' IM.map,
54 rpc : IS.set
46 } 55 }
47 56
48 fun frob file = 57 fun frob file =
49 let 58 let
50 val rpcBaseIds = foldl (fn ((d, _), rpcIds) => 59 val rpcBaseIds = foldl (fn ((d, _), rpcIds) =>
113 else 122 else
114 (IS.add (#exported st, n), 123 (IS.add (#exported st, n),
115 (DExport (Rpc ReadWrite, n), loc) :: #export_decls st) 124 (DExport (Rpc ReadWrite, n), loc) :: #export_decls st)
116 125
117 val st = {exported = exported, 126 val st = {exported = exported,
118 export_decls = export_decls} 127 export_decls = export_decls,
128 cpsed = #cpsed st,
129 rpc = #rpc st}
119 130
120 val k = (ECApp ((EFfi ("Basis", "return"), loc), 131 val k = (ECApp ((EFfi ("Basis", "return"), loc),
121 (CFfi ("Basis", "transaction"), loc)), loc) 132 (CFfi ("Basis", "transaction"), loc)), loc)
122 val k = (ECApp (k, ran), loc) 133 val k = (ECApp (k, ran), loc)
123 val k = (EApp (k, (EFfi ("Basis", "transaction_monad"), loc)), loc) 134 val k = (EApp (k, (EFfi ("Basis", "transaction_monad"), loc)), loc)
132 if IS.member (rpcBaseIds, n) then 143 if IS.member (rpcBaseIds, n) then
133 newRpc (trans, st) 144 newRpc (trans, st)
134 else 145 else
135 (e, st) 146 (e, st)
136 147
148 | ENamed n =>
149 (case IM.find (#cpsed st, n) of
150 NONE => (e, st)
151 | SOME re => (re, st))
152
137 | _ => (e, st) 153 | _ => (e, st)
138 end 154 end
139 155
140 and doExp (e, st) = U.Exp.foldMap {kind = fn x => x, 156 and doExp (e, st) = U.Exp.foldMap {kind = fn x => x,
141 con = fn x => x, 157 con = fn x => x,
142 exp = exp} st (ReduceLocal.reduceExp e) 158 exp = exp} st (ReduceLocal.reduceExp e)
143 159
144 fun decl (d, st : state) = 160 fun decl (d, st : state) =
145 let 161 let
162 val makesServerCall = U.Exp.exists {kind = fn _ => false,
163 con = fn _ => false,
164 exp = fn EFfi ("Basis", "rpc") => true
165 | ENamed n => IS.member (#rpc st, n)
166 | _ => false}
167
168 val (d, st) =
169 case #1 d of
170 DValRec vis =>
171 if List.exists (fn (_, _, _, e, _) => makesServerCall e) vis then
172 let
173 val all = foldl (fn ((_, n, _, _, _), all) => IS.add (all, n)) IS.empty vis
174
175 val usesRec = U.Exp.exists {kind = fn _ => false,
176 con = fn _ => false,
177 exp = fn ENamed n => IS.member (all, n)
178 | _ => false}
179
180 val noRec = not o usesRec
181
182 fun tailOnly (e, _) =
183 case e of
184 EPrim _ => true
185 | ERel _ => true
186 | ENamed _ => true
187 | ECon (_, _, _, SOME e) => noRec e
188 | ECon _ => true
189 | EFfi _ => true
190 | EFfiApp (_, _, es) => List.all noRec es
191 | EApp (e1, e2) => noRec e2 andalso tailOnly e1
192 | EAbs (_, _, _, e) => noRec e
193 | ECApp (e1, _) => tailOnly e1
194 | ECAbs (_, _, e) => noRec e
195
196 | EKAbs (_, e) => noRec e
197 | EKApp (e1, _) => tailOnly e1
198
199 | ERecord xes => List.all (noRec o #2) xes
200 | EField (e1, _, _) => noRec e1
201 | EConcat (e1, _, e2, _) => noRec e1 andalso noRec e2
202 | ECut (e1, _, _) => noRec e1
203 | ECutMulti (e1, _, _) => noRec e1
204
205 | ECase (e1, pes, _) => noRec e1 andalso List.all (tailOnly o #2) pes
206
207 | EWrite e1 => noRec e1
208
209 | EClosure (_, es) => List.all noRec es
210
211 | ELet (_, _, e1, e2) => noRec e1 andalso tailOnly e2
212
213 | EServerCall (_, es, (EAbs (_, _, _, e), _), _, _) =>
214 List.all noRec es andalso tailOnly e
215 | EServerCall (_, es, e, _, _) => List.all noRec es andalso noRec e
216
217 | ETailCall _ => raise Fail "Rpcify: ETailCall too early"
218
219 fun tailOnlyF e =
220 case #1 e of
221 EAbs (_, _, _, e) => tailOnlyF e
222 | ECAbs (_, _, e) => tailOnlyF e
223 | EKAbs (_, e) => tailOnlyF e
224 | _ => tailOnly e
225
226 val nonTail = foldl (fn ((_, n, _, e, _), nonTail) =>
227 if tailOnlyF e then
228 nonTail
229 else
230 IS.add (nonTail, n)) IS.empty vis
231 in
232 if IS.isEmpty nonTail then
233 (d, {exported = #exported st,
234 export_decls = #export_decls st,
235 cpsed = #cpsed st,
236 rpc = IS.union (#rpc st, all)})
237 else
238 let
239 val rpc = foldl (fn ((_, n, _, _, _), rpc) =>
240 IS.add (rpc, n)) (#rpc st) vis
241
242 val (cpsed, vis') =
243 foldl (fn (vi as (x, n, t, e, s), (cpsed, vis')) =>
244 if IS.member (nonTail, n) then
245 let
246 fun getArgs (t, acc) =
247 case #1 t of
248 TFun (dom, ran) =>
249 getArgs (ran, dom :: acc)
250 | _ => (rev acc, t)
251 val (ts, ran) = getArgs (t, [])
252 val ran = case #1 ran of
253 CApp (_, ran) => ran
254 | _ => raise Fail "Rpcify: Tail function not transactional"
255 val len = length ts
256
257 val loc = #2 e
258 val args = ListUtil.mapi
259 (fn (i, _) =>
260 (ERel (len - i - 1), loc))
261 ts
262 val k = (EAbs ("x", ran, ran, (ERel 0, loc)), loc)
263 val re = (ETailCall (n, args, k, ran, ran), loc)
264 val (re, _) = foldr (fn (dom, (re, ran)) =>
265 ((EAbs ("x", dom, ran, re),
266 loc),
267 (TFun (dom, ran), loc)))
268 (re, ran) ts
269
270 val be = multiLiftExpInExp (len + 1) e
271 val be = ListUtil.foldli
272 (fn (i, _, be) =>
273 (EApp (be, (ERel (len - i), loc)), loc))
274 be ts
275 val ne = (EFfi ("Basis", "bind"), loc)
276 val trans = (CFfi ("Basis", "transaction"), loc)
277 val ne = (ECApp (ne, trans), loc)
278 val ne = (ECApp (ne, ran), loc)
279 val unit = (TRecord (CRecord ((KType, loc), []),
280 loc), loc)
281 val ne = (ECApp (ne, unit), loc)
282 val ne = (EApp (ne, (EFfi ("Basis", "transaction_monad"),
283 loc)), loc)
284 val ne = (EApp (ne, be), loc)
285 val ne = (EApp (ne, (ERel 0, loc)), loc)
286 val tunit = (CApp (trans, unit), loc)
287 val kt = (TFun (ran, tunit), loc)
288 val ne = (EAbs ("k", kt, tunit, ne), loc)
289 val (ne, res) = foldr (fn (dom, (ne, ran)) =>
290 ((EAbs ("x", dom, ran, ne), loc),
291 (TFun (dom, ran), loc)))
292 (ne, (TFun (kt, tunit), loc)) ts
293 in
294 (IM.insert (cpsed, n, #1 re),
295 (x, n, res, ne, s) :: vis')
296 end
297 else
298 (cpsed, vi :: vis'))
299 (#cpsed st, []) vis
300 in
301 ((DValRec (rev vis'), ErrorMsg.dummySpan),
302 {exported = #exported st,
303 export_decls = #export_decls st,
304 cpsed = cpsed,
305 rpc = rpc})
306 end
307 end
308 else
309 (d, st)
310 | DVal (x, n, t, e, s) =>
311 (d,
312 {exported = #exported st,
313 export_decls = #export_decls st,
314 cpsed = #cpsed st,
315 rpc = if makesServerCall e then
316 IS.add (#rpc st, n)
317 else
318 #rpc st})
319 | _ => (d, st)
320
146 val (d, st) = U.Decl.foldMap {kind = fn x => x, 321 val (d, st) = U.Decl.foldMap {kind = fn x => x,
147 con = fn x => x, 322 con = fn x => x,
148 exp = exp, 323 exp = exp,
149 decl = fn x => x} 324 decl = fn x => x}
150 st d 325 st d
151 in 326 in
152 (#export_decls st @ [d], 327 (#export_decls st @ [d],
153 {exported = #exported st, 328 {exported = #exported st,
154 export_decls = []}) 329 export_decls = [],
330 cpsed = #cpsed st,
331 rpc = #rpc st})
155 end 332 end
156 333
157 val (file, _) = ListUtil.foldlMapConcat decl 334 val (file, _) = ListUtil.foldlMapConcat decl
158 {exported = IS.empty, 335 {exported = IS.empty,
159 export_decls = []} 336 export_decls = [],
337 cpsed = IM.empty,
338 rpc = rpcBaseIds}
160 file 339 file
161 in 340 in
162 file 341 file
163 end 342 end
164 343