diff src/elaborate.sml @ 172:021f5beb6f8d

Pattern match coverage checking
author Adam Chlipala <adamc@hcoop.net>
date Thu, 31 Jul 2008 10:31:30 -0400
parents c7a6e6dbc318
children 8221b95cc24c
line wrap: on
line diff
--- a/src/elaborate.sml	Thu Jul 31 10:06:27 2008 -0400
+++ b/src/elaborate.sml	Thu Jul 31 10:31:30 2008 -0400
@@ -37,6 +37,7 @@
 open Print
 open ElabPrint
 
+structure IM = IntBinaryMap
 structure SS = BinarySetFn(struct
                            type ord_key = string
                            val compare = String.compare
@@ -814,6 +815,7 @@
      | UnboundConstructor of ErrorMsg.span * string
      | PatHasArg of ErrorMsg.span
      | PatHasNoArg of ErrorMsg.span
+     | Inexhaustive of ErrorMsg.span
 
 fun expError env err =
     case err of
@@ -852,6 +854,8 @@
         ErrorMsg.errorAt loc "Constructor expects no argument but is used with argument"
       | PatHasNoArg loc =>
         ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument"
+      | Inexhaustive loc =>
+        ErrorMsg.errorAt loc "Inexhaustive 'case'"
 
 fun checkCon (env, denv) e c1 c2 =
     unifyCons (env, denv) c1 c2
@@ -1000,6 +1004,71 @@
           | L.PCon _ => raise Fail "uhoh"
     end
 
+datatype coverage =
+         Wild
+       | Datatype of coverage IM.map
+
+fun exhaustive (env, denv, t, ps) =
+    let
+        fun pcCoverage pc =
+            case pc of
+                L'.PConVar n => n
+              | _ => raise Fail "uh oh^2"
+
+        fun coverage (p, _) =
+            case p of
+                L'.PWild => Wild
+              | L'.PVar _ => Wild
+              | L'.PCon (pc, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild))
+              | L'.PCon (pc, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p))
+
+        fun merge (c1, c2) =
+            case (c1, c2) of
+                (Wild, _) => Wild
+              | (_, Wild) => Wild
+
+              | (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2))
+
+        fun combinedCoverage ps =
+            case ps of
+                [] => raise Fail "Empty pattern list for coverage checking"
+              | [p] => coverage p
+              | p :: ps => merge (coverage p, combinedCoverage ps)
+
+        fun isTotal (c, t) =
+            case c of
+                Wild => (true, nil)
+              | Datatype cm =>
+                let
+                    val ((t, _), gs) = hnormCon (env, denv) t
+                in
+                    case t of
+                        L'.CNamed n =>
+                        let
+                            val dt = E.lookupDatatype env n
+                            val cons = E.constructors dt
+                        in
+                            foldl (fn ((_, n, to), (total, gs)) =>
+                                      case IM.find (cm, n) of
+                                          NONE => (false, gs)
+                                        | SOME c' =>
+                                          case to of
+                                              NONE => (total, gs)
+                                            | SOME t' =>
+                                              let
+                                                  val (total, gs') = isTotal (c', t')
+                                              in
+                                                  (total, gs' @ gs)
+                                              end)
+                                  (true, gs) cons
+                        end
+                      | L'.CError => (true, gs)
+                      | _ => raise Fail "isTotal: Not a datatype"
+                end
+    in
+        isTotal (combinedCoverage ps, t)
+    end
+
 fun elabExp (env, denv) (eAll as (e, loc)) =
     let
         
@@ -1227,8 +1296,15 @@
                                          ((p', e'), gs1 @ gs2 @ gs3 @ gs)
                                      end)
                                  gs1 pes
+
+                val (total, gs') = exhaustive (env, denv, et, map #1 pes')
             in
-                ((L'.ECase (e', pes', result), loc), result, gs)
+                if total then
+                    ()
+                else
+                    expError env (Inexhaustive loc);
+
+                ((L'.ECase (e', pes', result), loc), result, gs' @ gs)
             end
     end