comparison src/especialize.sml @ 1080:a4979e31e4bf

Another try at reasonable Especialize, this time with a custom traversal
author Adam Chlipala <adamc@hcoop.net>
date Sun, 20 Dec 2009 15:17:43 -0500
parents d069b193ed6b
children 2eb585274501
comparison
equal deleted inserted replaced
1079:d069b193ed6b 1080:a4979e31e4bf
1 (* Copyright (c) 2008, Adam Chlipala 1 (* Copyright (c) 2008-2009, Adam Chlipala
2 * All rights reserved. 2 * All rights reserved.
3 * 3 *
4 * Redistribution and use in source and binary forms, with or without 4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions are met: 5 * modification, are permitted provided that the following conditions are met:
6 * 6 *
60 0 IS.empty 60 0 IS.empty
61 61
62 val isPoly = U.Decl.exists {kind = fn _ => false, 62 val isPoly = U.Decl.exists {kind = fn _ => false,
63 con = fn _ => false, 63 con = fn _ => false,
64 exp = fn ECAbs _ => true 64 exp = fn ECAbs _ => true
65 | EKAbs _ => true
65 | _ => false, 66 | _ => false,
66 decl = fn _ => false} 67 decl = fn _ => false}
67 68
68 fun positionOf (v : int, ls) = 69 fun positionOf (v : int, ls) =
69 let 70 let
106 107
107 type state = { 108 type state = {
108 maxName : int, 109 maxName : int,
109 funcs : func IM.map, 110 funcs : func IM.map,
110 decls : (string * int * con * exp * string) list, 111 decls : (string * int * con * exp * string) list,
111 specialized : bool IM.map 112 specialized : IS.set
112 } 113 }
113 114
114 fun default (_, x, st) = (x, st) 115 fun default (_, x, st) = (x, st)
115 116
116 structure SS = BinarySetFn(struct 117 structure SS = BinarySetFn(struct
118 val compare = String.compare 119 val compare = String.compare
119 end) 120 end)
120 121
121 val mayNotSpec = ref SS.empty 122 val mayNotSpec = ref SS.empty
122 123
123 fun specialize' specialized file = 124 fun specialize' (funcs, specialized) file =
124 let 125 let
125 fun bind (env, b) = 126 fun bind (env, b) =
126 case b of 127 case b of
127 U.Decl.RelE xt => xt :: env 128 U.Decl.RelE xt => xt :: env
128 | _ => env 129 | _ => env
129 130
130 fun exp (env, e, st : state) = 131 fun exp (env, e as (_, loc), st : state) =
131 let 132 let
132 (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty 133 (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty
133 (e, ErrorMsg.dummySpan))]*) 134 (e, ErrorMsg.dummySpan))]*)
134 135
135 fun getApp e = 136 fun getApp (e, _) =
136 case e of 137 case e of
137 ENamed f => SOME (f, []) 138 ENamed f => SOME (f, [])
138 | EApp (e1, e2) => 139 | EApp (e1, e2) =>
139 (case getApp (#1 e1) of 140 (case getApp e1 of
140 NONE => NONE 141 NONE => NONE
141 | SOME (f, xs) => SOME (f, xs @ [e2])) 142 | SOME (f, xs) => SOME (f, xs @ [e2]))
142 | _ => NONE 143 | _ => NONE
144
145 val getApp = fn e => case getApp e of
146 v as SOME (_, _ :: _) => v
147 | _ => NONE
148
149 fun default () =
150 case #1 e of
151 EPrim _ => (e, st)
152 | ERel _ => (e, st)
153 | ENamed _ => (e, st)
154 | ECon (_, _, _, NONE) => (e, st)
155 | ECon (dk, pc, cs, SOME e) =>
156 let
157 val (e, st) = exp (env, e, st)
158 in
159 ((ECon (dk, pc, cs, SOME e), loc), st)
160 end
161 | EFfi _ => (e, st)
162 | EFfiApp (m, x, es) =>
163 let
164 val (es, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st es
165 in
166 ((EFfiApp (m, x, es), loc), st)
167 end
168 | EApp (e1, e2) =>
169 let
170 val (e1, st) = exp (env, e1, st)
171 val (e2, st) = exp (env, e2, st)
172 in
173 ((EApp (e1, e2), loc), st)
174 end
175 | EAbs (x, d, r, e) =>
176 let
177 val (e, st) = exp ((x, d) :: env, e, st)
178 in
179 ((EAbs (x, d, r, e), loc), st)
180 end
181 | ECApp (e, c) =>
182 let
183 val (e, st) = exp (env, e, st)
184 in
185 ((ECApp (e, c), loc), st)
186 end
187 | ECAbs _ => raise Fail "Especialize: Impossible ECAbs"
188 | EKAbs _ => raise Fail "Especialize: Impossible EKAbs"
189 | EKApp (e, k) =>
190 let
191 val (e, st) = exp (env, e, st)
192 in
193 ((EKApp (e, k), loc), st)
194 end
195 | ERecord fs =>
196 let
197 val (fs, st) = ListUtil.foldlMap (fn ((c1, e, c2), st) =>
198 let
199 val (e, st) = exp (env, e, st)
200 in
201 ((c1, e, c2), st)
202 end) st fs
203 in
204 ((ERecord fs, loc), st)
205 end
206 | EField (e, c, cs) =>
207 let
208 val (e, st) = exp (env, e, st)
209 in
210 ((EField (e, c, cs), loc), st)
211 end
212 | EConcat (e1, c1, e2, c2) =>
213 let
214 val (e1, st) = exp (env, e1, st)
215 val (e2, st) = exp (env, e2, st)
216 in
217 ((EConcat (e1, c1, e2, c2), loc), st)
218 end
219 | ECut (e, c, cs) =>
220 let
221 val (e, st) = exp (env, e, st)
222 in
223 ((ECut (e, c, cs), loc), st)
224 end
225 | ECutMulti (e, c, cs) =>
226 let
227 val (e, st) = exp (env, e, st)
228 in
229 ((ECutMulti (e, c, cs), loc), st)
230 end
231
232 | ECase (e, pes, cs) =>
233 let
234 val (e, st) = exp (env, e, st)
235 val (pes, st) = ListUtil.foldlMap (fn ((p, e), st) =>
236 let
237 val (e, st) = exp (E.patBindsL p @ env, e, st)
238 in
239 ((p, e), st)
240 end) st pes
241 in
242 ((ECase (e, pes, cs), loc), st)
243 end
244
245 | EWrite e =>
246 let
247 val (e, st) = exp (env, e, st)
248 in
249 ((EWrite e, loc), st)
250 end
251 | EClosure (n, es) =>
252 let
253 val (es, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st es
254 in
255 ((EClosure (n, es), loc), st)
256 end
257 | ELet (x, t, e1, e2) =>
258 let
259 val (e1, st) = exp (env, e1, st)
260 val (e2, st) = exp ((x, t) :: env, e2, st)
261 in
262 ((ELet (x, t, e1, e2), loc), st)
263 end
264 | EServerCall (n, es, t) =>
265 let
266 val (es, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st es
267 in
268 ((EServerCall (n, es, t), loc), st)
269 end
143 in 270 in
144 case getApp e of 271 case getApp e of
145 NONE => ((*Print.prefaces "No" [("e", CorePrint.p_exp CoreEnv.empty 272 NONE => default ()
146 (e, ErrorMsg.dummySpan))];*)
147 (e, st))
148 | SOME (f, xs) => 273 | SOME (f, xs) =>
149 case IM.find (#funcs st, f) of 274 case IM.find (#funcs st, f) of
150 NONE => (e, st) 275 NONE => default ()
151 | SOME {name, args, body, typ, tag} => 276 | SOME {name, args, body, typ, tag} =>
152 let 277 let
278 val (xs, st) = ListUtil.foldlMap (fn (e, st) => exp (env, e, st)) st xs
279
153 (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty 280 (*val () = Print.prefaces "Consider" [("e", CorePrint.p_exp CoreEnv.empty
154 (e, ErrorMsg.dummySpan))]*) 281 (e, ErrorMsg.dummySpan))]*)
155 282
156 val functionInside = U.Con.exists {kind = fn _ => false, 283 val functionInside = U.Con.exists {kind = fn _ => false,
157 con = fn TFun _ => true 284 con = fn TFun _ => true
164 | CFfi ("Basis", "sql_injectable_prim") => true 291 | CFfi ("Basis", "sql_injectable_prim") => true
165 | CFfi ("Basis", "sql_injectable") => true 292 | CFfi ("Basis", "sql_injectable") => true
166 | _ => false} 293 | _ => false}
167 val loc = ErrorMsg.dummySpan 294 val loc = ErrorMsg.dummySpan
168 295
169 fun findSplit av (xs, typ, fxs, fvs) = 296 fun findSplit av (xs, typ, fxs, fvs, fin) =
170 case (#1 typ, xs) of 297 case (#1 typ, xs) of
171 (TFun (dom, ran), e :: xs') => 298 (TFun (dom, ran), e :: xs') =>
172 let 299 let
173 val av = case #1 e of 300 val av = case #1 e of
174 ERel _ => av 301 ERel _ => av
178 ERel _ => true 305 ERel _ => true
179 | _ => false) then 306 | _ => false) then
180 findSplit av (xs', 307 findSplit av (xs',
181 ran, 308 ran,
182 e :: fxs, 309 e :: fxs,
183 IS.union (fvs, freeVars e)) 310 IS.union (fvs, freeVars e),
311 fin orelse functionInside dom)
184 else 312 else
185 (rev fxs, xs, fvs) 313 (rev fxs, xs, fvs, fin)
186 end 314 end
187 | _ => (rev fxs, xs, fvs) 315 | _ => (rev fxs, xs, fvs, fin)
188 316
189 val (fxs, xs, fvs) = findSplit true (xs, typ, [], IS.empty) 317 val (fxs, xs, fvs, fin) = findSplit true (xs, typ, [], IS.empty, false)
190 318
191 val fxs' = map (squish (IS.listItems fvs)) fxs 319 val fxs' = map (squish (IS.listItems fvs)) fxs
192 in 320 in
193 (*Print.preface ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs');*) 321 (*Print.preface ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs');*)
194 if List.all (fn (ERel _, _) => true 322 if not fin
195 | _ => false) fxs' 323 orelse List.all (fn (ERel _, _) => true
324 | _ => false) fxs'
196 orelse (IS.numItems fvs >= length fxs 325 orelse (IS.numItems fvs >= length fxs
197 andalso IS.exists (fn n => functionInside (#2 (List.nth (env, n)))) fvs) then 326 andalso IS.exists (fn n => functionInside (#2 (List.nth (env, n)))) fvs) then
198 (e, st) 327 default ()
199 else 328 else
200 case (KM.find (args, fxs'), 329 case (KM.find (args, fxs'),
201 SS.member (!mayNotSpec, name) orelse IM.find (#specialized st, f) = SOME true) of 330 SS.member (!mayNotSpec, name) orelse IS.member (#specialized st, f)) of
202 (SOME f', _) => 331 (SOME f', _) =>
203 let 332 let
204 val e = (ENamed f', loc) 333 val e = (ENamed f', loc)
205 val e = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc)) 334 val e = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
206 e fvs 335 e fvs
207 val e = foldl (fn (arg, e) => (EApp (e, arg), loc)) 336 val e = foldl (fn (arg, e) => (EApp (e, arg), loc))
208 e xs 337 e xs
209 in 338 in
210 (*Print.prefaces "Brand new (reuse)" 339 (*Print.prefaces "Brand new (reuse)"
211 [("e'", CorePrint.p_exp CoreEnv.empty e)];*) 340 [("e'", CorePrint.p_exp CoreEnv.empty e)];*)
212 (#1 e, st) 341 (e, st)
213 end 342 end
214 | (_, true) => ((*Print.prefaces ("No(" ^ name ^ ")") 343 | (_, true) => ((*Print.prefaces ("No(" ^ name ^ ")")
215 [("fxs'", 344 [("fxs'",
216 Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')];*) 345 Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')];*)
217 (e, st)) 346 default ())
218 | (NONE, false) => 347 | (NONE, false) =>
219 let 348 let
220 (*val () = Print.prefaces "New one" 349 (*val () = Print.prefaces "New one"
221 [("f", Print.PD.string (Int.toString f)), 350 [("f", Print.PD.string (Int.toString f)),
222 ("mns", Print.p_list Print.PD.string 351 ("mns", Print.p_list Print.PD.string
238 fxs'') 367 fxs'')
239 end 368 end
240 | _ => NONE 369 | _ => NONE
241 in 370 in
242 case subBody (body, typ, fxs') of 371 case subBody (body, typ, fxs') of
243 NONE => (e, st) 372 NONE => default ()
244 | SOME (body', typ') => 373 | SOME (body', typ') =>
245 let 374 let
246 val f' = #maxName st 375 val f' = #maxName st
247 val args = KM.insert (args, fxs', f') 376 val args = KM.insert (args, fxs', f')
248 val funcs = IM.insert (#funcs st, f, {name = name, 377 val funcs = IM.insert (#funcs st, f, {name = name,
249 args = args, 378 args = args,
250 body = body, 379 body = body,
251 typ = typ, 380 typ = typ,
252 tag = tag}) 381 tag = tag})
253 382
254 val specialized = IM.insert (#specialized st, f', false)
255 val specialized = case IM.find (specialized, f) of
256 NONE => specialized
257 | SOME _ => IM.insert (specialized, f, true)
258
259 val st = { 383 val st = {
260 maxName = f' + 1, 384 maxName = f' + 1,
261 funcs = funcs, 385 funcs = funcs,
262 decls = #decls st, 386 decls = #decls st,
263 specialized = specialized 387 specialized = IS.add (#specialized st, f')
264 } 388 }
265 389
266 (*val () = Print.prefaces "specExp" 390 (*val () = Print.prefaces "specExp"
267 [("f", CorePrint.p_exp env (ENamed f, loc)), 391 [("f", CorePrint.p_exp env (ENamed f, loc)),
268 ("f'", CorePrint.p_exp env (ENamed f', loc)), 392 ("f'", CorePrint.p_exp env (ENamed f', loc)),
278 loc), 402 loc),
279 (TFun (xt, typ'), loc)) 403 (TFun (xt, typ'), loc))
280 end) 404 end)
281 (body', typ') fvs 405 (body', typ') fvs
282 val mns = !mayNotSpec 406 val mns = !mayNotSpec
283 val () = mayNotSpec := SS.add (mns, name) 407 (*val () = mayNotSpec := SS.add (mns, name)*)
284 (*val () = Print.preface ("body'", CorePrint.p_exp CoreEnv.empty body')*) 408 (*val () = Print.preface ("PRE", CorePrint.p_exp CoreEnv.empty body')*)
285 val (body', st) = specExp env st body' 409 val (body', st) = exp (env, body', st)
286 val () = mayNotSpec := mns 410 val () = mayNotSpec := mns
287 411
288 val e' = (ENamed f', loc) 412 val e' = (ENamed f', loc)
289 val e' = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc)) 413 val e' = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
290 e' fvs 414 e' fvs
291 val e' = foldl (fn (arg, e) => (EApp (e, arg), loc)) 415 val e' = foldl (fn (arg, e) => (EApp (e, arg), loc))
292 e' xs 416 e' xs
293 (*val () = Print.prefaces "Brand new" 417 (*val () = Print.prefaces "Brand new"
294 [("e'", CorePrint.p_exp CoreEnv.empty e'), 418 [("e'", CorePrint.p_exp CoreEnv.empty e'),
295 ("e", CorePrint.p_exp CoreEnv.empty (e, loc)), 419 ("e", CorePrint.p_exp CoreEnv.empty e),
296 ("body'", CorePrint.p_exp CoreEnv.empty body')]*) 420 ("body'", CorePrint.p_exp CoreEnv.empty body')]*)
297 in 421 in
298 (#1 e', 422 (e',
299 {maxName = #maxName st, 423 {maxName = #maxName st,
300 funcs = #funcs st, 424 funcs = #funcs st,
301 decls = (name, f', typ', body', tag) :: #decls st, 425 decls = (name, f', typ', body', tag) :: #decls st,
302 specialized = #specialized st}) 426 specialized = #specialized st})
303 end 427 end
304 end 428 end
305 end 429 end
306 end 430 end
307
308 and specExp env = U.Exp.foldMapB {kind = default, con = default, exp = exp, bind = bind} env
309
310 val specDecl = U.Decl.foldMapB {kind = default, con = default, exp = exp, decl = default, bind = bind}
311 431
312 fun doDecl (d, (st : state, changed)) = 432 fun doDecl (d, (st : state, changed)) =
313 let 433 let
314 (*val befor = Time.now ()*) 434 (*val befor = Time.now ()*)
315 435
331 decls = [], 451 decls = [],
332 specialized = #specialized st} 452 specialized = #specialized st}
333 453
334 (*val () = Print.prefaces "decl" [("d", CorePrint.p_decl CoreEnv.empty d)]*) 454 (*val () = Print.prefaces "decl" [("d", CorePrint.p_decl CoreEnv.empty d)]*)
335 455
456 val () = mayNotSpec := SS.empty
457
336 val (d', st) = 458 val (d', st) =
337 if isPoly d then 459 if isPoly d then
338 (d, st) 460 (d, st)
339 else 461 else
340 (mayNotSpec := SS.empty(*(case #1 d of 462 case #1 d of
341 DValRec vis => foldl (fn ((x, _, _, _, _), mns) => 463 DVal (x, n, t, e, s) =>
342 SS.add (mns, x)) SS.empty vis 464 let
343 | DVal (x, _, _, _, _) => SS.singleton x 465 val (e, st) = exp ([], e, st)
344 | _ => SS.empty)*); 466 in
345 specDecl [] st d 467 ((DVal (x, n, t, e, s), #2 d), st)
346 before mayNotSpec := SS.empty) 468 end
469 | DValRec vis =>
470 let
471 val (vis, st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
472 let
473 val (e, st) = exp ([], e, st)
474 in
475 ((x, n, t, e, s), st)
476 end) st vis
477 in
478 ((DValRec vis, #2 d), st)
479 end
480 | DTable (s, n, t, s1, e1, t1, e2, t2) =>
481 let
482 val (e1, st) = exp ([], e1, st)
483 val (e2, st) = exp ([], e2, st)
484 in
485 ((DTable (s, n, t, s1, e1, t2, e2, t2), #2 d), st)
486 end
487 | DView (x, n, s, e, t) =>
488 let
489 val (e, st) = exp ([], e, st)
490 in
491 ((DView (x, n, s, e, t), #2 d), st)
492 end
493 | DTask (e1, e2) =>
494 let
495 val (e1, st) = exp ([], e1, st)
496 val (e2, st) = exp ([], e2, st)
497 in
498 ((DTask (e1, e2), #2 d), st)
499 end
500 | _ => (d, st)
501
502 val () = mayNotSpec := SS.empty
347 503
348 (*val () = print "/decl\n"*) 504 (*val () = print "/decl\n"*)
349 505
350 val funcs = #funcs st 506 val funcs = #funcs st
351 val funcs = 507 val funcs =
378 specialized = #specialized st}, changed)) 534 specialized = #specialized st}, changed))
379 end 535 end
380 536
381 val (ds, (st, changed)) = ListUtil.foldlMapConcat doDecl 537 val (ds, (st, changed)) = ListUtil.foldlMapConcat doDecl
382 ({maxName = U.File.maxName file + 1, 538 ({maxName = U.File.maxName file + 1,
383 funcs = IM.empty, 539 funcs = funcs,
384 decls = [], 540 decls = [],
385 specialized = specialized}, 541 specialized = specialized},
386 false) 542 false)
387 file 543 file
388 in 544 in
389 (changed, ds, #specialized st) 545 (changed, ds, #funcs st, #specialized st)
390 end 546 end
391 547
392 fun specializeL specialized file = 548 fun specializeL (funcs, specialized) file =
393 let 549 let
394 val file = ReduceLocal.reduce file 550 val file = ReduceLocal.reduce file
395 (*val () = Print.prefaces "Intermediate" [("file", CorePrint.p_file CoreEnv.empty file)]*)
396 (*val file = ReduceLocal.reduce file*) 551 (*val file = ReduceLocal.reduce file*)
397 val (changed, file, specialized) = specialize' specialized file 552 val (changed, file, funcs, specialized) = specialize' (funcs, specialized) file
398 (*val file = ReduceLocal.reduce file 553 (*val file = ReduceLocal.reduce file
399 val file = CoreUntangle.untangle file 554 val file = CoreUntangle.untangle file
400 val file = Shake.shake file*) 555 val file = Shake.shake file*)
401 in 556 in
402 (*print "Round over\n";*) 557 (*print "Round over\n";*)
407 val file = CoreUntangle.untangle file 562 val file = CoreUntangle.untangle file
408 (*val () = Print.prefaces "Post-untangle" [("file", CorePrint.p_file CoreEnv.empty file)]*) 563 (*val () = Print.prefaces "Post-untangle" [("file", CorePrint.p_file CoreEnv.empty file)]*)
409 val file = Shake.shake file 564 val file = Shake.shake file
410 in 565 in
411 (*print "Again!\n";*) 566 (*print "Again!\n";*)
412 specializeL specialized file 567 (*Print.prefaces "Again" [("file", CorePrint.p_file CoreEnv.empty file)];*)
568 specializeL (funcs, specialized) file
413 end 569 end
414 else 570 else
415 file 571 file
416 end 572 end
417 573
418 val specialize = specializeL IM.empty 574 val specialize = specializeL (IM.empty, IS.empty)
419 575
420 end 576 end