comparison src/rpcify.sml @ 642:4a125bbc602d

Conversion of functions to CPS, to facilitate ServerCall
author Adam Chlipala <adamc@hcoop.net>
date Sun, 08 Mar 2009 20:34:21 -0400
parents c5991cdb0c4b
children 96ebc6bdb5a0
comparison
equal deleted inserted replaced
641:b98f547a6a45 642:4a125bbc602d
(* SS: sets of strings, built with BinarySetFn and ordered by String.compare.
 * This definition is identical in both columns of the comparison. *)
38 structure SS = BinarySetFn(struct 38 structure SS = BinarySetFn(struct
39 type ord_key = string 39 type ord_key = string
40 val compare = String.compare 40 val compare = String.compare
41 end) 41 end)
42 42
(* multiLiftExpInExp n e: applies E.liftExpInExp 0 to e, n times in a row, and
 * returns e unchanged when n = 0. Added in the new revision; the case-branch
 * rewrite further down calls it with E.patBindsN p as n.
 * NOTE(review): the base case tests n = 0 only, so a negative n would never
 * reach it -- callers are assumed to pass n >= 0. *)
43 fun multiLiftExpInExp n e =
44 if n = 0 then
45 e
46 else
47 multiLiftExpInExp (n - 1) (E.liftExpInExp 0 e)
48
(* ssBasis: string set seeded with the four names listed below. Later in this
 * file it is passed to whichIds and, paired with the resulting id set, to
 * sideish / sideish' to define serverSide.
 * NOTE(review): presumably these are Basis functions that may only run on the
 * server -- confirm against sideish, whose body is not shown in this view. *)
43 val ssBasis = SS.addList (SS.empty, 49 val ssBasis = SS.addList (SS.empty,
44 ["requestHeader", 50 ["requestHeader",
45 "query", 51 "query",
46 "dml", 52 "dml",
47 "nextval"]) 53 "nextval"])
52 "set", 58 "set",
53 "alert"]) 59 "alert"])
54 60
(* state: accumulator threaded through the expression and declaration passes.
 *   cpsed        - function name n |-> name n' allocated for its CPS version
 *   cpsed_range  - n' |-> the `ran` constructor recorded for n in tfuncs
 *   cps_decls    - generated bindings (name ^ "_cps", n', type, body, "")
 *                  waiting to be emitted; decl resets the list to []
 *   exported     - names for which a DExport (Rpc, n) has already been queued
 *   export_decls - queued DExport declarations; decl resets the list to []
 *   maxName      - next unused name: read to allocate n', stored back as n' + 1
 * The old (left-hand) column has neither cpsed_range nor maxName. *)
55 type state = { 61 type state = {
56 cpsed : int IM.map, 62 cpsed : int IM.map,
63 cpsed_range : con IM.map,
57 cps_decls : (string * int * con * exp * string) list, 64 cps_decls : (string * int * con * exp * string) list,
58 65
59 exported : IS.set, 66 exported : IS.set,
60 export_decls : decl list 67 export_decls : decl list,
68
69 maxName : int
61 } 70 }
62 71
63 fun frob file = 72 fun frob file =
64 let 73 let
65 fun sideish (basis, ssids) = 74 fun sideish (basis, ssids) =
93 end 102 end
94 103
95 val ssids = whichIds ssBasis 104 val ssids = whichIds ssBasis
96 val csids = whichIds csBasis 105 val csids = whichIds csBasis
97 106
(* sideish' (basis, ids) extra: adds every key of the map `extra` to the id set
 * `ids` and hands the enlarged set to sideish. In the new (right-hand) column
 * serverSide and clientSide are partial applications that still expect
 * `extra`; the call sites further down pass #cpsed_range st. The old
 * (left-hand) column bound serverSide and clientSide directly via sideish. *)
98 val serverSide = sideish (ssBasis, ssids) 107 fun sideish' (basis, ids) extra =
99 val clientSide = sideish (csBasis, csids) 108 sideish (basis, IM.foldli (fn (id, _, ids) => IS.add (ids, id)) ids extra)
109
110 val serverSide = sideish' (ssBasis, ssids)
111 val clientSide = sideish' (csBasis, csids)
100 112
101 val tfuncs = foldl 113 val tfuncs = foldl
102 (fn ((d, _), tfuncs) => 114 (fn ((d, _), tfuncs) =>
103 let 115 let
(* doOne ((x, n, t, e, _), tfuncs): inspects one value binding and, when crawl
 * succeeds, records the result in tfuncs under n; otherwise tfuncs is
 * returned unchanged.
 * New column: crawl walks the type t and the expression e together.
 *   - CApp (_, ran) ends the walk with SOME (x, rev args, ran, e);
 *   - a TFun matched by an EAbs descends into the lambda body and records
 *     (binder name, arg);
 *   - a TFun with any other expression applies e to ERel (length args) and
 *     records ("x", arg);
 *   - anything else yields NONE.
 * Old column: crawl looked only at t and succeeded only on
 * CApp (CFfi ("Basis", "transaction"), ran), returning (rev args, ran).
 * NOTE(review): the new CApp case no longer checks for Basis.transaction, and
 * the non-EAbs case uses ERel (length args) on an e that is not lifted --
 * confirm both are intended. *)
104 fun doOne ((_, n, t, _, _), tfuncs) = 116 fun doOne ((x, n, t, e, _), tfuncs) =
105 let 117 let
106 fun crawl (t, args) = 118 val loc = #2 e
107 case #1 t of 119
108 CApp ((CFfi ("Basis", "transaction"), _), ran) => SOME (rev args, ran) 120 fun crawl (t, e, args) =
109 | TFun (arg, rest) => crawl (rest, arg :: args) 121 case (#1 t, #1 e) of
122 (CApp (_, ran), _) =>
123 SOME (x, rev args, ran, e)
124 | (TFun (arg, rest), EAbs (x, _, _, e)) =>
125 crawl (rest, e, (x, arg) :: args)
126 | (TFun (arg, rest), _) =>
127 crawl (rest, (EApp (e, (ERel (length args), loc)), loc), ("x", arg) :: args)
110 | _ => NONE 128 | _ => NONE
111 in 129 in
112 case crawl (t, []) of 130 case crawl (t, e, []) of
113 NONE => tfuncs 131 NONE => tfuncs
114 | SOME sg => IM.insert (tfuncs, n, sg) 132 | SOME sg => IM.insert (tfuncs, n, sg)
115 end 133 end
116 in 134 in
117 case d of 135 case d of
125 case e of 143 case e of
126 EApp ( 144 EApp (
127 (EApp 145 (EApp
128 ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), loc), _), _), t1), _), t2), _), 146 ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), loc), _), _), t1), _), t2), _),
129 (EFfi ("Basis", "transaction_monad"), _)), _), 147 (EFfi ("Basis", "transaction_monad"), _)), _),
130 trans1), _), 148 (ECase (ed, pes, {disc, ...}), _)), _),
131 trans2) => 149 trans2) =>
132 (case (serverSide trans1, clientSide trans1, serverSide trans2, clientSide trans2) of 150 let
133 (true, false, false, true) => 151 val e' = (EFfi ("Basis", "bind"), loc)
134 let 152 val e' = (ECApp (e', (CFfi ("Basis", "transaction"), loc)), loc)
135 fun getApp (e, args) = 153 val e' = (ECApp (e', t1), loc)
136 case #1 e of 154 val e' = (ECApp (e', t2), loc)
137 ENamed n => (n, args) 155 val e' = (EApp (e', (EFfi ("Basis", "transaction_monad"), loc)), loc)
138 | EApp (e1, e2) => getApp (e1, e2 :: args) 156
139 | _ => (ErrorMsg.errorAt loc "Mixed client/server code doesn't use a named function for server part"; 157 val (pes, st) = ListUtil.foldlMap (fn ((p, e), st) =>
140 (0, [])) 158 let
141 159 val e' = (EApp (e', e), loc)
142 val (n, args) = getApp (trans1, []) 160 val e' = (EApp (e',
143 161 multiLiftExpInExp (E.patBindsN p)
144 val (exported, export_decls) = 162 trans2), loc)
145 if IS.member (#exported st, n) then 163 val (e', st) = doExp (e', st)
146 (#exported st, #export_decls st) 164 in
147 else 165 ((p, e'), st)
148 (IS.add (#exported st, n), 166 end) st pes
149 (DExport (Rpc, n), loc) :: #export_decls st) 167 in
150 168 (ECase (ed, pes, {disc = disc,
151 val st = {cpsed = #cpsed st, 169 result = (CApp ((CFfi ("Basis", "transaction"), loc), t2), loc)}),
152 cps_decls = #cps_decls st, 170 st)
153 171 end
154 exported = exported, 172
155 export_decls = export_decls} 173 | EApp (
156 174 (EApp
157 val ran = 175 ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), loc), _), _), t1), _), t2), _),
158 case IM.find (tfuncs, n) of 176 (EFfi ("Basis", "transaction_monad"), _)), _),
159 NONE => (Print.prefaces "BAD" [("e", CorePrint.p_exp CoreEnv.empty (e, loc))]; 177 (EServerCall (n, es, ke, t), _)), _),
160 raise Fail "Rpcify: Undetected transaction function") 178 trans2) =>
161 | SOME (_, ran) => ran 179 let
162 in 180 val e' = (EFfi ("Basis", "bind"), loc)
163 (EServerCall (n, args, trans2, ran), st) 181 val e' = (ECApp (e', (CFfi ("Basis", "transaction"), loc)), loc)
164 end 182 val e' = (ECApp (e', t), loc)
165 | _ => (e, st)) 183 val e' = (ECApp (e', t2), loc)
184 val e' = (EApp (e', (EFfi ("Basis", "transaction_monad"), loc)), loc)
185 val e' = (EApp (e', (EApp (E.liftExpInExp 0 ke, (ERel 0, loc)), loc)), loc)
186 val e' = (EApp (e', E.liftExpInExp 0 trans2), loc)
187 val e' = (EAbs ("x", t, t2, e'), loc)
188 val e' = (EServerCall (n, es, e', t), loc)
189 val (e', st) = doExp (e', st)
190 in
191 (#1 e', st)
192 end
193
194 | EApp (
195 (EApp
196 ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), loc), _), _), _), _), t3), _),
197 (EFfi ("Basis", "transaction_monad"), _)), _),
198 (EApp ((EApp
199 ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), _), _), _), t1), _), t2), _),
200 (EFfi ("Basis", "transaction_monad"), _)), _),
201 trans1), _), trans2), _)), _),
202 trans3) =>
203 let
204 val e'' = (EFfi ("Basis", "bind"), loc)
205 val e'' = (ECApp (e'', (CFfi ("Basis", "transaction"), loc)), loc)
206 val e'' = (ECApp (e'', t2), loc)
207 val e'' = (ECApp (e'', t3), loc)
208 val e'' = (EApp (e'', (EFfi ("Basis", "transaction_monad"), loc)), loc)
209 val e'' = (EApp (e'', (EApp (E.liftExpInExp 0 trans2, (ERel 0, loc)), loc)), loc)
210 val e'' = (EApp (e'', E.liftExpInExp 0 trans3), loc)
211 val e'' = (EAbs ("x", t1, (CApp ((CFfi ("Basis", "transaction"), loc), t3), loc), e''), loc)
212
213 val e' = (EFfi ("Basis", "bind"), loc)
214 val e' = (ECApp (e', (CFfi ("Basis", "transaction"), loc)), loc)
215 val e' = (ECApp (e', t1), loc)
216 val e' = (ECApp (e', t3), loc)
217 val e' = (EApp (e', (EFfi ("Basis", "transaction_monad"), loc)), loc)
218 val e' = (EApp (e', trans1), loc)
219 val e' = (EApp (e', e''), loc)
220 val (e', st) = doExp (e', st)
221 in
222 (#1 e', st)
223 end
224
225 | EApp (
226 (EApp
227 ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), _), _), _), _), _), _), _),
228 (EFfi ("Basis", "transaction_monad"), _)), _),
229 _), loc),
230 (EAbs (_, _, _, (EWrite _, _)), _)) => (e, st)
231
232 | EApp (
233 (EApp
234 ((EApp ((ECApp ((ECApp ((ECApp ((EFfi ("Basis", "bind"), _), _), _), t1), _), t2), _),
235 (EFfi ("Basis", "transaction_monad"), _)), _),
236 trans1), loc),
237 trans2) =>
238 let
239 (*val () = Print.prefaces "Default"
240 [("e", CorePrint.p_exp CoreEnv.empty (e, ErrorMsg.dummySpan))]*)
241
242 fun getApp (e', args) =
243 case #1 e' of
244 ENamed n => (n, args)
245 | EApp (e1, e2) => getApp (e1, e2 :: args)
246 | _ => (ErrorMsg.errorAt loc "Mixed client/server code doesn't use a named function for server part";
247 Print.prefaces "Bad" [("e", CorePrint.p_exp CoreEnv.empty (e, ErrorMsg.dummySpan))];
248 (0, []))
249 in
250 case (serverSide (#cpsed_range st) trans1, clientSide (#cpsed_range st) trans1,
251 serverSide (#cpsed_range st) trans2, clientSide (#cpsed_range st) trans2) of
252 (true, false, _, true) =>
253 let
254 val (n, args) = getApp (trans1, [])
255
256 val (exported, export_decls) =
257 if IS.member (#exported st, n) then
258 (#exported st, #export_decls st)
259 else
260 (IS.add (#exported st, n),
261 (DExport (Rpc, n), loc) :: #export_decls st)
262
263 val st = {cpsed = #cpsed st,
264 cpsed_range = #cpsed_range st,
265 cps_decls = #cps_decls st,
266
267 exported = exported,
268 export_decls = export_decls,
269
270 maxName = #maxName st}
271
272 val ran =
273 case IM.find (tfuncs, n) of
274 NONE => (Print.prefaces "BAD" [("e", CorePrint.p_exp CoreEnv.empty (e, loc))];
275 raise Fail ("Rpcify: Undetected transaction function " ^ Int.toString n))
276 | SOME (_, _, ran, _) => ran
277
278 val e' = EServerCall (n, args, trans2, ran)
279 in
280 (EServerCall (n, args, trans2, ran), st)
281 end
282 | (true, true, _, _) =>
283 let
284 val (n, args) = getApp (trans1, [])
285
286 fun makeCall n' =
287 let
288 val e = (ENamed n', loc)
289 val e = (EApp (e, trans2), loc)
290 in
291 #1 (foldl (fn (arg, e) => (EApp (e, arg), loc)) e args)
292 end
293 in
294 case IM.find (#cpsed_range st, n) of
295 SOME kdom =>
296 (case args of
297 [] => raise Fail "Rpcify: cps'd function lacks first argument"
298 | ke :: args =>
299 let
300 val ke' = (EFfi ("Basis", "bind"), loc)
301 val ke' = (ECApp (ke', (CFfi ("Basis", "transaction"), loc)), loc)
302 val ke' = (ECApp (ke', kdom), loc)
303 val ke' = (ECApp (ke', t2), loc)
304 val ke' = (EApp (ke', (EFfi ("Basis", "transaction_monad"), loc)), loc)
305 val ke' = (EApp (ke', (EApp (E.liftExpInExp 0 ke, (ERel 0, loc)), loc)), loc)
306 val ke' = (EApp (ke', E.liftExpInExp 0 trans2), loc)
307 val ke' = (EAbs ("x", kdom,
308 (CApp ((CFfi ("Basis", "transaction"), loc), t2), loc),
309 ke'), loc)
310
311 val e' = (ENamed n, loc)
312 val e' = (EApp (e', ke'), loc)
313 val e' = foldl (fn (arg, e') => (EApp (e', arg), loc)) e' args
314 val (e', st) = doExp (e', st)
315 in
316 (#1 e', st)
317 end)
318 | NONE =>
319 case IM.find (#cpsed st, n) of
320 SOME n' => (makeCall n', st)
321 | NONE =>
322 let
323 val (name, fargs, ran, e) =
324 case IM.find (tfuncs, n) of
325 NONE => (Print.prefaces "BAD" [("e",
326 CorePrint.p_exp CoreEnv.empty (e, loc))];
327 raise Fail "Rpcify: Undetected transaction function [2]")
328 | SOME x => x
329
330 val n' = #maxName st
331
332 val st = {cpsed = IM.insert (#cpsed st, n, n'),
333 cpsed_range = IM.insert (#cpsed_range st, n', ran),
334 cps_decls = #cps_decls st,
335 exported = #exported st,
336 export_decls = #export_decls st,
337 maxName = n' + 1}
338
339 val unit = (TRecord (CRecord ((KType, loc), []), loc), loc)
340 val body = (EFfi ("Basis", "bind"), loc)
341 val body = (ECApp (body, (CFfi ("Basis", "transaction"), loc)), loc)
342 val body = (ECApp (body, t1), loc)
343 val body = (ECApp (body, unit), loc)
344 val body = (EApp (body, (EFfi ("Basis", "transaction_monad"), loc)), loc)
345 val body = (EApp (body, e), loc)
346 val body = (EApp (body, (ERel (length args), loc)), loc)
347 val bt = (CApp ((CFfi ("Basis", "transaction"), loc), unit), loc)
348 val (body, bt) = foldr (fn ((x, t), (body, bt)) =>
349 ((EAbs (x, t, bt, body), loc),
350 (TFun (t, bt), loc)))
351 (body, bt) fargs
352 val kt = (TFun (ran, (CApp ((CFfi ("Basis", "transaction"), loc),
353 unit),
354 loc)), loc)
355 val body = (EAbs ("k", kt, bt, body), loc)
356 val bt = (TFun (kt, bt), loc)
357
358 val (body, st) = doExp (body, st)
359
360 val vi = (name ^ "_cps",
361 n',
362 bt,
363 body,
364 "")
365
366 val st = {cpsed = #cpsed st,
367 cpsed_range = #cpsed_range st,
368 cps_decls = vi :: #cps_decls st,
369 exported = #exported st,
370 export_decls = #export_decls st,
371 maxName = #maxName st}
372 in
373 (makeCall n', st)
374 end
375 end
376 | _ => (e, st)
377 end
166 | _ => (e, st) 378 | _ => (e, st)
379
(* doExp (e, st): passes e through ReduceLocal.reduceExp, then runs
 * U.Exp.foldMap over the result with identity handlers for kinds and
 * constructors and `exp` as the expression handler, threading st through.
 * It is joined to `exp` with `and`, so the two are mutually recursive: the
 * rewrite cases above call doExp on the expressions they build. *)
380 and doExp (e, st) = U.Exp.foldMap {kind = fn x => x,
381 con = fn x => x,
382 exp = exp} st (ReduceLocal.reduceExp e)
167 383
168 fun decl (d, st : state) = 384 fun decl (d, st : state) =
169 let 385 let
170 val (d, st) = U.Decl.foldMap {kind = fn x => x, 386 val (d, st) = U.Decl.foldMap {kind = fn x => x,
171 con = fn x => x, 387 con = fn x => x,
179 case d of 395 case d of
180 (DValRec vis, loc) => [(DValRec (ds @ vis), loc)] 396 (DValRec vis, loc) => [(DValRec (ds @ vis), loc)]
181 | (_, loc) => [d, (DValRec ds, loc)], 397 | (_, loc) => [d, (DValRec ds, loc)],
182 #export_decls st), 398 #export_decls st),
183 {cpsed = #cpsed st, 399 {cpsed = #cpsed st,
400 cpsed_range = #cpsed_range st,
184 cps_decls = [], 401 cps_decls = [],
185 402
186 exported = #exported st, 403 exported = #exported st,
187 export_decls = []}) 404 export_decls = [],
405
406 maxName = #maxName st})
188 end 407 end
189 408
190 val (file, _) = ListUtil.foldlMapConcat decl 409 val (file, _) = ListUtil.foldlMapConcat decl
191 {cpsed = IM.empty, 410 {cpsed = IM.empty,
411 cpsed_range = IM.empty,
192 cps_decls = [], 412 cps_decls = [],
193 413
194 exported = IS.empty, 414 exported = IS.empty,
195 export_decls = []} 415 export_decls = [],
416
417 maxName = U.File.maxName file + 1}
196 file 418 file
197 in 419 in
198 file 420 file
199 end 421 end
200 422