comparison src/especialize.sml @ 1079:d069b193ed6b

Especialize uses a termination measure based on number of arguments introduced
author Adam Chlipala <adamc@hcoop.net>
date Tue, 15 Dec 2009 19:26:52 -0500
parents b9321bcefb42
children a4979e31e4bf
comparison
equal deleted inserted replaced
1078:b9321bcefb42 1079:d069b193ed6b
77 pof (pos + 1, ls') 77 pof (pos + 1, ls')
78 in 78 in
79 pof (0, ls) 79 pof (0, ls)
80 end 80 end
81 81
82 fun squish (untouched, fvs) = 82 fun squish fvs =
83 U.Exp.mapB {kind = fn _ => fn k => k, 83 U.Exp.mapB {kind = fn _ => fn k => k,
84 con = fn _ => fn c => c, 84 con = fn _ => fn c => c,
85 exp = fn bound => fn e => 85 exp = fn bound => fn e =>
86 case e of 86 case e of
87 ERel x => 87 ERel x =>
88 if x >= bound then 88 if x >= bound then
89 ERel (positionOf (x - bound, fvs) + bound + untouched) 89 ERel (positionOf (x - bound, fvs) + bound)
90 else 90 else
91 e 91 e
92 | _ => e, 92 | _ => e,
93 bind = fn (bound, b) => 93 bind = fn (bound, b) =>
94 case b of 94 case b of
105 } 105 }
106 106
107 type state = { 107 type state = {
108 maxName : int, 108 maxName : int,
109 funcs : func IM.map, 109 funcs : func IM.map,
110 decls : (string * int * con * exp * string) list 110 decls : (string * int * con * exp * string) list,
111 specialized : bool IM.map
111 } 112 }
112 113
113 fun default (_, x, st) = (x, st) 114 fun default (_, x, st) = (x, st)
114 115
115 structure SS = BinarySetFn(struct 116 structure SS = BinarySetFn(struct
117 val compare = String.compare 118 val compare = String.compare
118 end) 119 end)
119 120
120 val mayNotSpec = ref SS.empty 121 val mayNotSpec = ref SS.empty
121 122
122 fun specialize' file = 123 fun specialize' specialized file =
123 let 124 let
124 fun bind (env, b) = 125 fun bind (env, b) =
125 case b of 126 case b of
126 U.Decl.RelE xt => xt :: env 127 U.Decl.RelE xt => xt :: env
127 | _ => env 128 | _ => env
163 | CFfi ("Basis", "sql_injectable_prim") => true 164 | CFfi ("Basis", "sql_injectable_prim") => true
164 | CFfi ("Basis", "sql_injectable") => true 165 | CFfi ("Basis", "sql_injectable") => true
165 | _ => false} 166 | _ => false}
166 val loc = ErrorMsg.dummySpan 167 val loc = ErrorMsg.dummySpan
167 168
168 fun hasFuncArg t = 169 fun findSplit av (xs, typ, fxs, fvs) =
169 case #1 t of
170 TFun (dom, ran) => functionInside dom orelse hasFuncArg ran
171 | _ => false
172
173 fun findSplit hfa (xs, typ, fxs, fvs, ts) =
174 case (#1 typ, xs) of 170 case (#1 typ, xs) of
175 (TFun (dom, ran), e :: xs') => 171 (TFun (dom, ran), e :: xs') =>
176 let 172 let
177 val isVar = case #1 e of 173 val av = case #1 e of
178 ERel _ => true 174 ERel _ => av
179 | _ => false 175 | _ => false
180 val hfa = hfa andalso isVar
181 in 176 in
182 if hfa orelse functionInside dom then 177 if functionInside dom orelse (av andalso case #1 e of
183 findSplit hfa (xs', 178 ERel _ => true
184 ran, 179 | _ => false) then
185 (true, e) :: fxs, 180 findSplit av (xs',
186 IS.union (fvs, freeVars e), 181 ran,
187 ts) 182 e :: fxs,
183 IS.union (fvs, freeVars e))
188 else 184 else
189 findSplit hfa (xs', ran, (false, e) :: fxs, fvs, dom :: ts) 185 (rev fxs, xs, fvs)
190 end 186 end
191 | _ => (List.revAppend (fxs, map (fn e => (false, e)) xs), fvs, rev ts) 187 | _ => (rev fxs, xs, fvs)
192 188
193 val (xs, fvs, ts) = findSplit (hasFuncArg typ) (xs, typ, [], IS.empty, []) 189 val (fxs, xs, fvs) = findSplit true (xs, typ, [], IS.empty)
194 val fxs = List.mapPartial (fn (true, e) => SOME e | _ => NONE) xs 190
195 val untouched = length (List.filter (fn (false, _) => true | _ => false) xs) 191 val fxs' = map (squish (IS.listItems fvs)) fxs
196 val squish = squish (untouched, IS.listItems fvs)
197 val fxs' = map squish fxs
198 in 192 in
199 (*Print.preface ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs');*) 193 (*Print.preface ("fxs'", Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs');*)
200 if List.all (fn (false, _) => true 194 if List.all (fn (ERel _, _) => true
201 | (true, (ERel _, _)) => true 195 | _ => false) fxs'
202 | _ => false) xs then 196 orelse (IS.numItems fvs >= length fxs
197 andalso IS.exists (fn n => functionInside (#2 (List.nth (env, n)))) fvs) then
203 (e, st) 198 (e, st)
204 else 199 else
205 case (KM.find (args, fxs'), SS.member (!mayNotSpec, name)) of 200 case (KM.find (args, fxs'),
201 SS.member (!mayNotSpec, name) orelse IM.find (#specialized st, f) = SOME true) of
206 (SOME f', _) => 202 (SOME f', _) =>
207 let 203 let
208 val e = (ENamed f', loc) 204 val e = (ENamed f', loc)
209 val e = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc)) 205 val e = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
210 e fvs 206 e fvs
211 val e = foldl (fn ((false, arg), e) => (EApp (e, arg), loc) 207 val e = foldl (fn (arg, e) => (EApp (e, arg), loc))
212 | (_, e) => e)
213 e xs 208 e xs
214 in 209 in
215 (*Print.prefaces "Brand new (reuse)" 210 (*Print.prefaces "Brand new (reuse)"
216 [("e'", CorePrint.p_exp CoreEnv.empty e)];*) 211 [("e'", CorePrint.p_exp CoreEnv.empty e)];*)
217 (#1 e, st) 212 (#1 e, st)
229 224
230 (*val () = Print.prefaces ("Yes(" ^ name ^ ")") 225 (*val () = Print.prefaces ("Yes(" ^ name ^ ")")
231 [("fxs'", 226 [("fxs'",
232 Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*) 227 Print.p_list (CorePrint.p_exp CoreEnv.empty) fxs')]*)
233 228
234 fun subBody (body, typ, xs) = 229 fun subBody (body, typ, fxs') =
235 case (#1 body, #1 typ, xs) of 230 case (#1 body, #1 typ, fxs') of
236 (_, _, []) => SOME (body, typ) 231 (_, _, []) => SOME (body, typ)
237 | (EAbs (_, _, _, body'), TFun (_, typ'), (b, x) :: xs) => 232 | (EAbs (_, _, _, body'), TFun (_, typ'), x :: fxs'') =>
238 let 233 let
239 val body'' = 234 val body'' = E.subExpInExp (0, x) body'
240 if b then
241 E.subExpInExp (0, squish x) body'
242 else
243 body'
244 in 235 in
245 subBody (body'', 236 subBody (body'',
246 typ', 237 typ',
247 xs) 238 fxs'')
248 end 239 end
249 | _ => NONE 240 | _ => NONE
250 in 241 in
251 case subBody (body, typ, xs) of 242 case subBody (body, typ, fxs') of
252 NONE => (e, st) 243 NONE => (e, st)
253 | SOME (body', typ') => 244 | SOME (body', typ') =>
254 let 245 let
255 val f' = #maxName st 246 val f' = #maxName st
256 val args = KM.insert (args, fxs', f') 247 val args = KM.insert (args, fxs', f')
257 val funcs = IM.insert (#funcs st, f, {name = name, 248 val funcs = IM.insert (#funcs st, f, {name = name,
258 args = args, 249 args = args,
259 body = body, 250 body = body,
260 typ = typ, 251 typ = typ,
261 tag = tag}) 252 tag = tag})
253
254 val specialized = IM.insert (#specialized st, f', false)
255 val specialized = case IM.find (specialized, f) of
256 NONE => specialized
257 | SOME _ => IM.insert (specialized, f, true)
258
262 val st = { 259 val st = {
263 maxName = f' + 1, 260 maxName = f' + 1,
264 funcs = funcs, 261 funcs = funcs,
265 decls = #decls st 262 decls = #decls st,
263 specialized = specialized
266 } 264 }
267 265
268 (*val () = Print.prefaces "specExp" 266 (*val () = Print.prefaces "specExp"
269 [("f", CorePrint.p_exp env (ENamed f, loc)), 267 [("f", CorePrint.p_exp env (ENamed f, loc)),
270 ("f'", CorePrint.p_exp env (ENamed f', loc)), 268 ("f'", CorePrint.p_exp env (ENamed f', loc)),
271 ("xs", Print.p_list (CorePrint.p_exp env) xs), 269 ("xs", Print.p_list (CorePrint.p_exp env) xs),
272 ("fxs'", Print.p_list 270 ("fxs'", Print.p_list
273 (CorePrint.p_exp E.empty) fxs'), 271 (CorePrint.p_exp E.empty) fxs'),
274 ("e", CorePrint.p_exp env (e, loc))]*) 272 ("e", CorePrint.p_exp env (e, loc))]*)
275
276 val (body', typ') = foldr (fn (t, (body', typ')) =>
277 ((EAbs ("x", t, typ', body'), loc),
278 (TFun (t, typ'), loc)))
279 (body', typ') ts
280
281 val (body', typ') = IS.foldl (fn (n, (body', typ')) => 273 val (body', typ') = IS.foldl (fn (n, (body', typ')) =>
282 let 274 let
283 val (x, xt) = List.nth (env, n) 275 val (x, xt) = List.nth (env, n)
284 in 276 in
285 ((EAbs (x, xt, typ', body'), 277 ((EAbs (x, xt, typ', body'),
294 val () = mayNotSpec := mns 286 val () = mayNotSpec := mns
295 287
296 val e' = (ENamed f', loc) 288 val e' = (ENamed f', loc)
297 val e' = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc)) 289 val e' = IS.foldr (fn (arg, e) => (EApp (e, (ERel arg, loc)), loc))
298 e' fvs 290 e' fvs
299 val e' = foldl (fn ((false, arg), e) => (EApp (e, arg), loc) 291 val e' = foldl (fn (arg, e) => (EApp (e, arg), loc))
300 | (_, e) => e)
301 e' xs 292 e' xs
302 (*val () = Print.prefaces "Brand new" 293 (*val () = Print.prefaces "Brand new"
303 [("e'", CorePrint.p_exp CoreEnv.empty e'), 294 [("e'", CorePrint.p_exp CoreEnv.empty e'),
304 ("e", CorePrint.p_exp CoreEnv.empty (e, loc)), 295 ("e", CorePrint.p_exp CoreEnv.empty (e, loc)),
305 ("body'", CorePrint.p_exp CoreEnv.empty body')]*) 296 ("body'", CorePrint.p_exp CoreEnv.empty body')]*)
306 in 297 in
307 (#1 e', 298 (#1 e',
308 {maxName = #maxName st, 299 {maxName = #maxName st,
309 funcs = #funcs st, 300 funcs = #funcs st,
310 decls = (name, f', typ', body', tag) :: #decls st}) 301 decls = (name, f', typ', body', tag) :: #decls st,
302 specialized = #specialized st})
311 end 303 end
312 end 304 end
313 end 305 end
314 end 306 end
315 307
334 funcs vis 326 funcs vis
335 | _ => funcs 327 | _ => funcs
336 328
337 val st = {maxName = #maxName st, 329 val st = {maxName = #maxName st,
338 funcs = funcs, 330 funcs = funcs,
339 decls = []} 331 decls = [],
332 specialized = #specialized st}
340 333
341 (*val () = Print.prefaces "decl" [("d", CorePrint.p_decl CoreEnv.empty d)]*) 334 (*val () = Print.prefaces "decl" [("d", CorePrint.p_decl CoreEnv.empty d)]*)
342 335
343 val (d', st) = 336 val (d', st) =
344 if isPoly d then 337 if isPoly d then
379 in 372 in
380 (*Print.prefaces "doDecl" [("d", CorePrint.p_decl E.empty d), 373 (*Print.prefaces "doDecl" [("d", CorePrint.p_decl E.empty d),
381 ("d'", CorePrint.p_decl E.empty d')];*) 374 ("d'", CorePrint.p_decl E.empty d')];*)
382 (ds, ({maxName = #maxName st, 375 (ds, ({maxName = #maxName st,
383 funcs = funcs, 376 funcs = funcs,
384 decls = []}, changed)) 377 decls = [],
378 specialized = #specialized st}, changed))
385 end 379 end
386 380
387 val (ds, (_, changed)) = ListUtil.foldlMapConcat doDecl 381 val (ds, (st, changed)) = ListUtil.foldlMapConcat doDecl
388 ({maxName = U.File.maxName file + 1, 382 ({maxName = U.File.maxName file + 1,
389 funcs = IM.empty, 383 funcs = IM.empty,
390 decls = []}, 384 decls = [],
385 specialized = specialized},
391 false) 386 false)
392 file 387 file
393 in 388 in
394 (changed, ds) 389 (changed, ds, #specialized st)
395 end 390 end
396 391
397 fun specialize file = 392 fun specializeL specialized file =
398 let 393 let
399 val file = ReduceLocal.reduce file 394 val file = ReduceLocal.reduce file
400 (*val () = Print.prefaces "Intermediate" [("file", CorePrint.p_file CoreEnv.empty file)]*) 395 (*val () = Print.prefaces "Intermediate" [("file", CorePrint.p_file CoreEnv.empty file)]*)
401 (*val file = ReduceLocal.reduce file*) 396 (*val file = ReduceLocal.reduce file*)
402 val (changed, file) = specialize' file 397 val (changed, file, specialized) = specialize' specialized file
403 (*val file = ReduceLocal.reduce file 398 (*val file = ReduceLocal.reduce file
404 val file = CoreUntangle.untangle file 399 val file = CoreUntangle.untangle file
405 val file = Shake.shake file*) 400 val file = Shake.shake file*)
406 in 401 in
407 (*print "Round over\n";*) 402 (*print "Round over\n";*)
412 val file = CoreUntangle.untangle file 407 val file = CoreUntangle.untangle file
413 (*val () = Print.prefaces "Post-untangle" [("file", CorePrint.p_file CoreEnv.empty file)]*) 408 (*val () = Print.prefaces "Post-untangle" [("file", CorePrint.p_file CoreEnv.empty file)]*)
414 val file = Shake.shake file 409 val file = Shake.shake file
415 in 410 in
416 (*print "Again!\n";*) 411 (*print "Again!\n";*)
417 specialize file 412 specializeL specialized file
418 end 413 end
419 else 414 else
420 file 415 file
421 end 416 end
422 417
418 val specialize = specializeL IM.empty
419
423 end 420 end