comparison src/elaborate.sml @ 819:cb30dd2ba353

Switch to Maranget's pattern exhaustiveness algorithm
author Adam Chlipala <adamc@hcoop.net>
date Sat, 23 May 2009 09:45:02 -0400
parents e2780d2f4afc
children 395a5d450cc0
comparison
equal deleted inserted replaced
818:066493f7f008 819:cb30dd2ba353
36 36
37 open Print 37 open Print
38 open ElabPrint 38 open ElabPrint
39 open ElabErr 39 open ElabErr
40 40
41 structure IS = IntBinarySet
41 structure IM = IntBinaryMap 42 structure IM = IntBinaryMap
42 43
43 structure SK = struct 44 structure SK = struct
44 type ord_key = string 45 type ord_key = string
45 val compare = String.compare 46 val compare = String.compare
1289 (env, bound)) 1290 (env, bound))
1290 end 1291 end
1291 1292
1292 end 1293 end
1293 1294
1294 datatype coverage = 1295 (* This exhaustiveness checking follows Luc Maranget's paper "Warnings for pattern matching." *)
(* Pattern-match exhaustiveness checking following Luc Maranget's paper
 * "Warnings for pattern matching" (the usefulness/"I" algorithm).
 *
 * [exhaustive (env, t, ps, loc)] checks whether the patterns [ps], all of
 * type [t], together cover every value of [t].
 *   - Returns NONE when they do.
 *   - Returns SOME p otherwise, where [p] is a witness pattern describing
 *     values that no pattern in [ps] matches.
 * [loc] is accepted for interface compatibility but is not used here.
 *
 * Internal invariant violations raise
 *   Fail "Elaborate.exhaustive: Impossible N"
 * with a distinct N for each raising site. *)
fun exhaustive (env, t, ps, loc) =
    let
        fun fail n = raise Fail ("Elaborate.exhaustive: Impossible " ^ Int.toString n)

        (* Numeric identity of the constructor named by a pattern head,
         * resolving module projections through the environment. *)
        fun patConNum pc =
            case pc of
                L'.PConVar n => n
              | L'.PConProj (m1, ms, x) =>
                let
                    val (str, sgn) = E.chaseMpath env (m1, ms)
                in
                    case E.projectConstructor env {str = str, sgn = sgn, field = x} of
                        NONE => raise Fail "exhaustive: Can't project constructor"
                      | SOME (_, n, _, _, _) => n
                end

        (* Inverse of [patConNum] used to build witnesses: a constructor
         * reference for constructor number [n] of the (head-normal) type [t]. *)
        fun nameOfNum (t, n) =
            case t of
                L'.CModProj (m1, ms, x) =>
                let
                    val (str, sgn) = E.chaseMpath env (m1, ms)
                in
                    case E.projectDatatype env {str = str, sgn = sgn, field = x} of
                        NONE => raise Fail "exhaustive: Can't project datatype"
                      | SOME (_, cons) =>
                        case ListUtil.search (fn (name, n', _) =>
                                                 if n' = n then
                                                     SOME name
                                                 else
                                                     NONE) cons of
                            NONE => fail 9
                          | SOME name => L'.PConProj (m1, ms, name)
                end
              | _ => L'.PConVar n

        (* Specialization S(c, P): keep the rows of matrix [P] whose first
         * pattern can match constructor [c]. Replace that first pattern with
         * its sub-patterns (one per element of [args]). Wildcards and
         * variables expand to one wildcard per sub-position. For records,
         * [args] lists the field names in order and absent fields become
         * wildcards. Rows headed by primitives never survive. *)
        fun S (args, c, P) =
            List.mapPartial
                (fn [] => fail 1
                  | p1 :: ps =>
                    let
                        val loc = #2 p1

                        fun wild () =
                            SOME (map (fn _ => (L'.PWild, loc)) args @ ps)
                    in
                        case #1 p1 of
                            L'.PPrim _ => NONE
                          | L'.PCon (_, c', _, NONE) =>
                            if patConNum c' = c then
                                SOME ps
                            else
                                NONE
                          | L'.PCon (_, c', _, SOME p) =>
                            if patConNum c' = c then
                                SOME (p :: ps)
                            else
                                NONE
                          | L'.PRecord xpts =>
                            SOME (map (fn x =>
                                          case ListUtil.search (fn (x', p, _) =>
                                                                   if x = x' then
                                                                       SOME p
                                                                   else
                                                                       NONE) xpts of
                                              NONE => (L'.PWild, loc)
                                            | SOME p => p) args @ ps)
                          | L'.PWild => wild ()
                          | L'.PVar _ => wild ()
                    end)
                P

        (* Default matrix D(P): the rows of [P] whose first pattern is a
         * wildcard or variable, with that first column removed. *)
        fun D P =
            List.mapPartial
                (fn [] => fail 2
                  | (p1, _) :: ps =>
                    case p1 of
                        L'.PWild => SOME ps
                      | L'.PVar _ => SOME ps
                      | L'.PPrim _ => NONE
                      | L'.PCon _ => NONE
                      | L'.PRecord _ => NONE)
                P

        (* I (P, q): [q] lists the column types of matrix [P].
         *   - NONE: every vector of values of types [q] is matched by some
         *     row of [P].
         *   - SOME ps: [ps] (one pattern per column) describes values that
         *     no row matches. *)
        fun I (P, q) =
            case q of
                [] => (case P of
                           [] => SOME []
                         | _ => NONE)
              | q1 :: qs =>
                let
                    val loc = #2 q1

                    fun unapp (t, acc) =
                        case t of
                            L'.CApp ((t, _), arg) => unapp (t, arg :: acc)
                          | _ => (t, rev acc)

                    (* Head type constructor of the first column and the
                     * type arguments it is applied to. *)
                    val (t1, args) = unapp (#1 (hnormCon env q1), [])
                    fun doSub t = foldl (fn (arg, t) => subConInCon (0, arg) t) t args

                    (* Constructors from [names] that head no row of [P]. *)
                    fun unusedCons names =
                        foldl (fn (row, nameSet) =>
                                  case row of
                                      [] => fail 4
                                    | (L'.PCon (_, pc, _, _), _) :: _ =>
                                      (IS.delete (nameSet, patConNum pc)
                                       handle NotFound => nameSet)
                                    | _ => nameSet)
                              (IS.addList (IS.empty, names)) P

                    (* Types we cannot enumerate (primitives, abstract types,
                     * ...): treat as having an unmatched "constructor" so that
                     * only wildcard/variable rows can cover the column. *)
                    fun default () = (IS.singleton 0, [])

                    (* [unused]: constructors not heading any row.
                     * [cons]: (constructor number, argument columns) for every
                     * constructor of the column type. Records are modelled as
                     * a single constructor 0 whose argument columns are the
                     * fields, sorted by name. *)
                    val (unused, cons) =
                        case t1 of
                            L'.CNamed n =>
                            let
                                val dt = E.lookupDatatype env n
                                val cons = E.constructors dt
                            in
                                (unusedCons (map #2 cons),
                                 map (fn (_, n, co) =>
                                         (n,
                                          case co of
                                              NONE => []
                                            | SOME t => [("", doSub t)])) cons)
                            end
                          | L'.CModProj (m1, ms, x) =>
                            let
                                val (str, sgn) = E.chaseMpath env (m1, ms)
                            in
                                case E.projectDatatype env {str = str, sgn = sgn, field = x} of
                                    NONE => default ()
                                  | SOME (_, cons) =>
                                    (unusedCons (map #2 cons),
                                     map (fn (s, _, co) =>
                                             (patConNum (L'.PConProj (m1, ms, s)),
                                              case co of
                                                  NONE => []
                                                | SOME t => [("", doSub t)])) cons)
                            end
                          | L'.TRecord c =>
                            (* Normalize the row (and each field name) first.
                             * Otherwise a record type that is not already a
                             * literal CRecord falls into [default ()]. [D] then
                             * discards every record pattern, and complete record
                             * matches are wrongly reported as inexhaustive. *)
                            (case #1 (hnormCon env c) of
                                 L'.CRecord (_, xts) =>
                                 let
                                     val xts = map (fn (x, co) =>
                                                       case #1 (hnormCon env x) of
                                                           L'.CName x => SOME (x, co)
                                                         | _ => NONE) xts
                                 in
                                     if List.all Option.isSome xts then
                                         let
                                             val xts = List.mapPartial (fn x => x) xts
                                             val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) =>
                                                                              String.compare (x1, x2) = GREATER) xts
                                         in
                                             (IS.empty, [(0, xts)])
                                         end
                                     else
                                         default ()
                                 end
                               | _ => default ())
                          | _ => default ()
                in
                    if IS.isEmpty unused then
                        (* Every constructor heads some row. The column is
                         * uncovered iff some constructor's specialized
                         * matrix is. *)
                        let
                            fun recurse cons =
                                case cons of
                                    [] => NONE
                                  | (name, args) :: cons =>
                                    case I (S (map #1 args, name, P),
                                            map #2 args @ qs) of
                                        NONE => recurse cons
                                      | SOME ps =>
                                        let
                                            val nargs = length args
                                            val argPs = List.take (ps, nargs)
                                            val restPs = List.drop (ps, nargs)

                                            val p = case name of
                                                        0 => L'.PRecord (ListPair.map
                                                                             (fn ((name, t), p) => (name, p, t))
                                                                             (args, argPs))
                                                      | _ => L'.PCon (L'.Default, nameOfNum (t1, name), [],
                                                                      case argPs of
                                                                          [] => NONE
                                                                        | [p] => SOME p
                                                                        | _ => fail 3)
                                        in
                                            SOME ((p, loc) :: restPs)
                                        end
                        in
                            recurse cons
                        end
                    else
                        (* Some constructor heads no row. Only the default
                         * matrix can cover it, and the witness uses that
                         * constructor (or a wildcard for opaque types). *)
                        case I (D P, qs) of
                            NONE => NONE
                          | SOME ps =>
                            let
                                val p = case cons of
                                            [] => L'.PWild
                                          | (0, _) :: _ => L'.PWild
                                          | _ =>
                                            case IS.find (fn _ => true) unused of
                                                NONE => fail 6
                                              | SOME name =>
                                                case ListUtil.search (fn (name', args) =>
                                                                         if name = name' then
                                                                             SOME (name', args)
                                                                         else
                                                                             NONE) cons of
                                                    SOME (n, []) =>
                                                    L'.PCon (L'.Default, nameOfNum (t1, n), [], NONE)
                                                  | SOME (n, [_]) =>
                                                    L'.PCon (L'.Default, nameOfNum (t1, n), [], SOME (L'.PWild, loc))
                                                  | _ => fail 7
                            in
                                SOME ((p, loc) :: ps)
                            end
                end
    in
        case I (map (fn x => [x]) ps, [t]) of
            NONE => NONE
          | SOME [p] => SOME p
          | _ => fail 8
    end
1525
1526 (*datatype coverage =
1295 Wild 1527 Wild
1296 | None 1528 | None
1297 | Datatype of coverage IM.map 1529 | Datatype of coverage IM.map
1298 | Record of coverage SM.map list 1530 | Record of coverage SM.map list
1299 1531
1358 [] => raise Fail "Empty pattern list for coverage checking" 1590 [] => raise Fail "Empty pattern list for coverage checking"
1359 | [p] => coverage p 1591 | [p] => coverage p
1360 | p :: ps => merge (coverage p, combinedCoverage ps) 1592 | p :: ps => merge (coverage p, combinedCoverage ps)
1361 1593
1362 fun enumerateCases depth t = 1594 fun enumerateCases depth t =
1363 if depth = 0 then 1595 (TextIO.print "enum'\n"; if depth <= 0 then
1364 [Wild] 1596 [Wild]
1365 else 1597 else
1366 let 1598 let
1367 fun dtype cons = 1599 val dtype =
1368 ListUtil.mapConcat (fn (_, n, to) => 1600 ListUtil.mapConcat (fn (_, n, to) =>
1369 case to of 1601 case to of
1370 NONE => [Datatype (IM.insert (IM.empty, n, Wild))] 1602 NONE => [Datatype (IM.insert (IM.empty, n, Wild))]
1371 | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c))) 1603 | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c)))
1372 (enumerateCases (depth-1) t)) cons 1604 (enumerateCases (depth-1) t))
1373 in 1605 in
1374 case #1 (hnormCon env t) of 1606 case #1 (hnormCon env t) of
1375 L'.CNamed n => 1607 L'.CNamed n =>
1376 (let 1608 (let
1377 val dt = E.lookupDatatype env n 1609 val dt = E.lookupDatatype env n
1391 | ((L'.CName x, _), t) :: rest => 1623 | ((L'.CName x, _), t) :: rest =>
1392 let 1624 let
1393 val this = enumerateCases depth t 1625 val this = enumerateCases depth t
1394 val rest = exponentiate rest 1626 val rest = exponentiate rest
1395 in 1627 in
1628 TextIO.print ("Before (" ^ Int.toString (length this)
1629 ^ ", " ^ Int.toString (length rest) ^ ")\n");
1396 ListUtil.mapConcat (fn fmap => 1630 ListUtil.mapConcat (fn fmap =>
1397 map (fn c => SM.insert (fmap, x, c)) this) rest 1631 map (fn c => SM.insert (fmap, x, c)) this) rest
1632 before TextIO.print "After\n"
1398 end 1633 end
1399 | _ => raise Fail "exponentiate: Not CName" 1634 | _ => raise Fail "exponentiate: Not CName"
1400 in 1635 in
1401 if List.exists (fn ((L'.CName _, _), _) => false 1636 if List.exists (fn ((L'.CName _, _), _) => false
1402 | (c, _) => true) xts then 1637 | (c, _) => true) xts then
1404 else 1639 else
1405 map (fn ls => Record [ls]) (exponentiate xts) 1640 map (fn ls => Record [ls]) (exponentiate xts)
1406 end 1641 end
1407 | _ => [Wild]) 1642 | _ => [Wild])
1408 | _ => [Wild] 1643 | _ => [Wild]
1409 end 1644 end before TextIO.print "/enum'\n")
1410 1645
1411 fun coverageImp (c1, c2) = 1646 fun coverageImp (c1, c2) =
1412 let 1647 let
1413 val r = 1648 val r =
1414 case (c1, c2) of 1649 case (c1, c2) of
1485 | c => 1720 | c =>
1486 (prefaces "Not a datatype" [("loc", PD.string (ErrorMsg.spanToString loc)), 1721 (prefaces "Not a datatype" [("loc", PD.string (ErrorMsg.spanToString loc)),
1487 ("c", p_con env (c, ErrorMsg.dummySpan))]; 1722 ("c", p_con env (c, ErrorMsg.dummySpan))];
1488 raise Fail "isTotal: Not a datatype") 1723 raise Fail "isTotal: Not a datatype")
1489 end 1724 end
1490 | Record _ => List.all (fn c2 => coverageImp (c, c2)) (enumerateCases depth t) 1725 | Record _ => List.all (fn c2 => coverageImp (c, c2))
1726 (TextIO.print "enum\n"; enumerateCases depth t before TextIO.print "/enum\n")
1491 in 1727 in
1492 isTotal (combinedCoverage ps, t) 1728 isTotal (combinedCoverage ps, t)
1493 end 1729 end*)
1494 1730
1495 fun unmodCon env (c, loc) = 1731 fun unmodCon env (c, loc) =
1496 case c of 1732 case c of
1497 L'.CNamed n => 1733 L'.CNamed n =>
1498 (case E.lookupCNamed env n of 1734 (case E.lookupCNamed env n of
1833 checkCon env e' et' result; 2069 checkCon env e' et' result;
1834 ((p', e'), gs1 @ gs) 2070 ((p', e'), gs1 @ gs)
1835 end) 2071 end)
1836 gs1 pes 2072 gs1 pes
1837 in 2073 in
1838 if exhaustive (env, et, map #1 pes', loc) then 2074 case exhaustive (env, et, map #1 pes', loc) of
1839 () 2075 NONE => ()
1840 else 2076 | SOME p => expError env (Inexhaustive (loc, p));
1841 expError env (Inexhaustive loc);
1842 2077
1843 ((L'.ECase (e', pes', {disc = et, result = result}), loc), result, gs) 2078 ((L'.ECase (e', pes', {disc = et, result = result}), loc), result, gs)
1844 end 2079 end
1845 2080
1846 | L.ELet (eds, e) => 2081 | L.ELet (eds, e) =>
1849 val (e, t, gs2) = elabExp (env, denv) e 2084 val (e, t, gs2) = elabExp (env, denv) e
1850 in 2085 in
1851 ((L'.ELet (eds, e), loc), t, gs1 @ gs2) 2086 ((L'.ELet (eds, e), loc), t, gs1 @ gs2)
1852 end 2087 end
1853 in 2088 in
1854 (*prefaces "elabExp" [("e", SourcePrint.p_exp eAll), 2089 (*prefaces "/elabExp" [("e", SourcePrint.p_exp eAll)];*)
1855 ("t", PD.string (LargeReal.toString (Time.toReal (Time.- (Time.now (), befor)))))];*)
1856 r 2090 r
1857 end 2091 end
1858 2092
1859 and elabEdecl denv (dAll as (d, loc), (env, gs)) = 2093 and elabEdecl denv (dAll as (d, loc), (env, gs)) =
1860 let 2094 let