changeset 819:cb30dd2ba353

Switch to Maranget's pattern exhaustiveness algorithm
author Adam Chlipala <adamc@hcoop.net>
date Sat, 23 May 2009 09:45:02 -0400
parents 066493f7f008
children 91f465ded07e
files src/elab_err.sig src/elab_err.sml src/elaborate.sml
diffstat 3 files changed, 252 insertions(+), 17 deletions(-) [+]
line wrap: on
line diff
--- a/src/elab_err.sig	Thu May 21 11:45:04 2009 -0400
+++ b/src/elab_err.sig	Sat May 23 09:45:02 2009 -0400
@@ -71,7 +71,7 @@
            | UnboundConstructor of ErrorMsg.span * string list * string
            | PatHasArg of ErrorMsg.span
            | PatHasNoArg of ErrorMsg.span
-           | Inexhaustive of ErrorMsg.span
+           | Inexhaustive of ErrorMsg.span * Elab.pat
            | DuplicatePatField of ErrorMsg.span * string
            | Unresolvable of ErrorMsg.span * Elab.con
            | OutOfContext of ErrorMsg.span * (Elab.exp * Elab.con) option
--- a/src/elab_err.sml	Thu May 21 11:45:04 2009 -0400
+++ b/src/elab_err.sml	Sat May 23 09:45:02 2009 -0400
@@ -161,7 +161,7 @@
      | UnboundConstructor of ErrorMsg.span * string list * string
      | PatHasArg of ErrorMsg.span
      | PatHasNoArg of ErrorMsg.span
-     | Inexhaustive of ErrorMsg.span
+     | Inexhaustive of ErrorMsg.span * pat
      | DuplicatePatField of ErrorMsg.span * string
      | Unresolvable of ErrorMsg.span * con
      | OutOfContext of ErrorMsg.span * (exp * con) option
@@ -207,8 +207,9 @@
         ErrorMsg.errorAt loc "Constructor expects no argument but is used with argument"
       | PatHasNoArg loc =>
         ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument"
-      | Inexhaustive loc =>
-        ErrorMsg.errorAt loc "Inexhaustive 'case'"
+      | Inexhaustive (loc, p) =>
+        (ErrorMsg.errorAt loc "Inexhaustive 'case'";
+         eprefaces' [("Missed case", p_pat env p)])
       | DuplicatePatField (loc, s) =>
         ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern")
       | OutOfContext (loc, co) =>
--- a/src/elaborate.sml	Thu May 21 11:45:04 2009 -0400
+++ b/src/elaborate.sml	Sat May 23 09:45:02 2009 -0400
@@ -38,6 +38,7 @@
  open ElabPrint
  open ElabErr
 
+ structure IS = IntBinarySet
  structure IM = IntBinaryMap
 
  structure SK = struct
@@ -1291,7 +1292,238 @@
                                            
     end
 
-datatype coverage =
+(* This exhaustiveness checking follows Luc Maranget's paper "Warnings for pattern matching." *)
+fun exhaustive (env, t, ps, loc) =
+    let
+        fun fail n = raise Fail ("Elaborate.exhaustive: Impossible " ^ Int.toString n)
+
+        fun patConNum pc =
+            case pc of
+                L'.PConVar n => n
+              | L'.PConProj (m1, ms, x) =>
+                let
+                    val (str, sgn) = E.chaseMpath env (m1, ms)
+                in
+                    case E.projectConstructor env {str = str, sgn = sgn, field = x} of
+                        NONE => raise Fail "exhaustive: Can't project datatype"
+                      | SOME (_, n, _, _, _) => n
+                end
+
+        fun nameOfNum (t, n) =
+            case t of
+                L'.CModProj (m1, ms, x) =>
+                let
+                    val (str, sgn) = E.chaseMpath env (m1, ms)
+                in
+                    case E.projectDatatype env {str = str, sgn = sgn, field = x} of
+                        NONE => raise Fail "exhaustive: Can't project datatype"
+                      | SOME (_, cons) =>
+                        case ListUtil.search (fn (name, n', _) =>
+                                                 if n' = n then
+                                                     SOME name
+                                                 else
+                                                     NONE) cons of
+                            NONE => fail 9
+                          | SOME name => L'.PConProj (m1, ms, name)
+                end
+              | _ => L'.PConVar n
+
+        fun S (args, c, P) =
+            List.mapPartial
+            (fn [] => fail 1
+              | p1 :: ps =>
+                let
+                    val loc = #2 p1
+
+                    fun wild () =
+                        SOME (map (fn _ => (L'.PWild, loc)) args @ ps)
+                in
+                    case #1 p1 of
+                        L'.PPrim _ => NONE
+                      | L'.PCon (_, c', _, NONE) =>
+                        if patConNum c' = c then
+                            SOME ps
+                        else
+                            NONE
+                      | L'.PCon (_, c', _, SOME p) =>
+                        if patConNum c' = c then
+                            SOME (p :: ps)
+                        else
+                            NONE
+                      | L'.PRecord xpts =>
+                        SOME (map (fn x =>
+                                      case ListUtil.search (fn (x', p, _) =>
+                                                               if x = x' then
+                                                                   SOME p
+                                                               else
+                                                                   NONE) xpts of
+                                          NONE => (L'.PWild, loc)
+                                        | SOME p => p) args @ ps)
+                      | L'.PWild => wild ()
+                      | L'.PVar _ => wild ()
+                end)
+            P
+
+        fun D P =
+            List.mapPartial
+            (fn [] => fail 2
+              | (p1, _) :: ps =>
+                case p1 of
+                    L'.PWild => SOME ps
+                  | L'.PVar _ => SOME ps
+                  | L'.PPrim _ => NONE
+                  | L'.PCon _ => NONE
+                  | L'.PRecord _ => NONE)
+            P
+
+        fun I (P, q) =
+            (*(prefaces "I" [("P", p_list (fn P' => box [PD.string "[", p_list (p_pat env) P', PD.string "]"]) P),
+                           ("q", p_list (p_con env) q)];*)
+            case q of
+                [] => (case P of
+                           [] => SOME []
+                         | _ => NONE)
+              | q1 :: qs =>
+                let
+                    val loc = #2 q1
+
+                    fun unapp (t, acc) =
+                        case t of
+                            L'.CApp ((t, _), arg) => unapp (t, arg :: acc)
+                          | _ => (t, rev acc)
+
+                    val (t1, args) = unapp (#1 (hnormCon env q1), [])
+                    fun doSub t = foldl (fn (arg, t) => subConInCon (0, arg) t) t args
+
+                    fun dtype (dtO, names) =
+                        let
+                            val nameSet = IS.addList (IS.empty, names)
+                            val nameSet = foldl (fn (ps, nameSet) =>
+                                                    case ps of
+                                                        [] => fail 4
+                                                      | (L'.PCon (_, pc, _, _), _) :: _ =>
+                                                        (IS.delete (nameSet, patConNum pc)
+                                                         handle NotFound => nameSet)
+                                                      | _ => nameSet)
+                                                nameSet P
+                        in
+                            nameSet
+                        end
+
+                    fun default () = (NONE, IS.singleton 0, [])
+
+                    val (dtO, unused, cons) =
+                        case t1 of
+                            L'.CNamed n =>
+                            let
+                                val dt = E.lookupDatatype env n
+                                val cons = E.constructors dt
+                            in
+                                (SOME dt,
+                                 dtype (SOME dt, map #2 cons),
+                                 map (fn (_, n, co) =>
+                                         (n,
+                                          case co of
+                                              NONE => []
+                                            | SOME t => [("", doSub t)])) cons)
+                            end
+                          | L'.CModProj (m1, ms, x) =>
+                            let
+                                val (str, sgn) = E.chaseMpath env (m1, ms)
+                            in
+                                case E.projectDatatype env {str = str, sgn = sgn, field = x} of
+                                    NONE => default ()
+                                  | SOME (_, cons) =>
+                                    (NONE,
+                                     dtype (NONE, map #2 cons),
+                                     map (fn (s, _, co) =>
+                                             (patConNum (L'.PConProj (m1, ms, s)),
+                                              case co of
+                                                  NONE => []
+                                                | SOME t => [("", doSub t)])) cons)
+                            end
+                          | L'.TRecord (L'.CRecord (_, xts), _) =>
+                            let
+                                val xts = map (fn ((L'.CName x, _), co) => SOME (x, co)
+                                                | _ => NONE) xts
+                            in
+                                if List.all Option.isSome xts then
+                                    let
+                                        val xts = List.mapPartial (fn x => x) xts
+                                        val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) =>
+                                                                         String.compare (x1, x2) = GREATER) xts
+                                    in
+                                        (NONE, IS.empty, [(0, xts)])
+                                    end
+                                else
+                                    default ()
+                            end
+                          | _ => default ()
+                in
+                    if IS.isEmpty unused then
+                        let
+                            fun recurse cons =
+                                case cons of
+                                    [] => NONE
+                                  | (name, args) :: cons =>
+                                    case I (S (map #1 args, name, P),
+                                            map #2 args @ qs) of
+                                        NONE => recurse cons
+                                      | SOME ps =>
+                                        let
+                                            val nargs = length args
+                                            val argPs = List.take (ps, nargs)
+                                            val restPs = List.drop (ps, nargs)
+
+                                            val p = case name of
+                                                        0 => L'.PRecord (ListPair.map
+                                                                                  (fn ((name, t), p) => (name, p, t))
+                                                                                  (args, argPs))
+                                                      | _  => L'.PCon (L'.Default, nameOfNum (t1, name), [],
+                                                                       case argPs of
+                                                                           [] => NONE
+                                                                         | [p] => SOME p
+                                                                         | _ => fail 3)
+                                        in
+                                            SOME ((p, loc) :: restPs)
+                                        end
+                        in
+                            recurse cons
+                        end
+                    else
+                        case I (D P, qs) of
+                            NONE => NONE
+                          | SOME ps =>
+                            let
+                                val p = case cons of
+                                            [] => L'.PWild
+                                          | (0, _) :: _ => L'.PWild
+                                          | _ =>
+                                            case IS.find (fn _ => true) unused of
+                                                NONE => fail 6
+                                              | SOME name =>
+                                                case ListUtil.search (fn (name', args) =>
+                                                                         if name = name' then
+                                                                             SOME (name', args)
+                                                                         else
+                                                                             NONE) cons of
+                                                    SOME (n, []) =>
+                                                    L'.PCon (L'.Default, nameOfNum (t1, n), [], NONE)
+                                                  | SOME (n, [_]) =>
+                                                    L'.PCon (L'.Default, nameOfNum (t1, n), [], SOME (L'.PWild, loc))
+                                                  | _ => fail 7
+                            in
+                                SOME ((p, loc) :: ps)
+                            end
+                end
+    in
+        case I (map (fn x => [x]) ps, [t]) of
+            NONE => NONE
+          | SOME [p] => SOME p
+          | _ => fail 7
+    end
+
+(*datatype coverage =
          Wild
        | None
        | Datatype of coverage IM.map
@@ -1360,16 +1592,16 @@
               | p :: ps => merge (coverage p, combinedCoverage ps)
 
         fun enumerateCases depth t =
-            if depth = 0 then
+            (TextIO.print "enum'\n"; if depth <= 0 then
                 [Wild]
             else
                 let
-                    fun dtype cons =
+                    val dtype =
                         ListUtil.mapConcat (fn (_, n, to) =>
                                                case to of
                                                    NONE => [Datatype (IM.insert (IM.empty, n, Wild))]
                                                  | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c)))
-                                                                 (enumerateCases (depth-1) t)) cons
+                                                                 (enumerateCases (depth-1) t))
                 in
                     case #1 (hnormCon env t) of
                         L'.CNamed n =>
@@ -1393,8 +1625,11 @@
                                              val this = enumerateCases depth t
                                              val rest = exponentiate rest
                                          in
+                                             TextIO.print ("Before (" ^ Int.toString (length this)
+                                                           ^ ", " ^ Int.toString (length rest) ^ ")\n");
                                              ListUtil.mapConcat (fn fmap =>
                                                                     map (fn c => SM.insert (fmap, x, c)) this) rest
+                                             before TextIO.print "After\n"
                                          end
                                        | _ => raise Fail "exponentiate: Not CName"
                              in
@@ -1406,7 +1641,7 @@
                              end
                            | _ => [Wild])
                       | _ => [Wild]
-                end
+                end before TextIO.print "/enum'\n")
 
         fun coverageImp (c1, c2) =
             let
@@ -1487,10 +1722,11 @@
                                                     ("c", p_con env (c, ErrorMsg.dummySpan))];
                          raise Fail "isTotal: Not a datatype")
                 end
-              | Record _ => List.all (fn c2 => coverageImp (c, c2)) (enumerateCases depth t)
+              | Record _ => List.all (fn c2 => coverageImp (c, c2))
+                                     (TextIO.print "enum\n"; enumerateCases depth t before TextIO.print "/enum\n")
     in
         isTotal (combinedCoverage ps, t)
-    end
+    end*)
 
 fun unmodCon env (c, loc) =
     case c of
@@ -1835,10 +2071,9 @@
                                      end)
                                  gs1 pes
             in
-                if exhaustive (env, et, map #1 pes', loc) then
-                    ()
-                else
-                    expError env (Inexhaustive loc);
+                case exhaustive (env, et, map #1 pes', loc) of
+                    NONE => ()
+                  | SOME p => expError env (Inexhaustive (loc, p));
 
                 ((L'.ECase (e', pes', {disc = et, result = result}), loc), result, gs)
             end
@@ -1851,8 +2086,7 @@
                 ((L'.ELet (eds, e), loc), t, gs1 @ gs2)
             end
     in
-        (*prefaces "elabExp" [("e", SourcePrint.p_exp eAll),
-                            ("t", PD.string (LargeReal.toString (Time.toReal (Time.- (Time.now (), befor)))))];*)
+        (*prefaces "/elabExp" [("e", SourcePrint.p_exp eAll)];*)
         r
     end