changeset 830:d07980bf1444

Defer pattern-matching exhaustiveness checks and normalize pattern types more thoroughly
author Adam Chlipala <adamc@hcoop.net>
date Sat, 30 May 2009 14:44:29 -0400
parents 20fe00fd81da
children 5e1a4b12c83a
files lib/ur/list.ur lib/ur/list.urs src/elaborate.sml
diffstat 3 files changed, 67 insertions(+), 230 deletions(-) [+]
line wrap: on
line diff
--- a/lib/ur/list.ur	Sat May 30 13:29:00 2009 -0400
+++ b/lib/ur/list.ur	Sat May 30 14:44:29 2009 -0400
@@ -74,6 +74,19 @@
         mapM' []
     end
 
+fun mapXM [m ::: (Type -> Type)] (_ : monad m) [a] [ctx ::: {Unit}] f =
+    let
+        fun mapXM' ls =
+            case ls of
+                [] => return <xml/>
+              | x :: ls =>
+                this <- f x;
+                rest <- mapXM' ls;
+                return <xml>{this}{rest}</xml>
+    in
+        mapXM'
+    end
+
 fun filter [a] f =
     let
         fun fil acc ls =
--- a/lib/ur/list.urs	Sat May 30 13:29:00 2009 -0400
+++ b/lib/ur/list.urs	Sat May 30 14:44:29 2009 -0400
@@ -15,7 +15,10 @@
 val mapX : a ::: Type -> ctx ::: {Unit} -> (a -> xml ctx [] []) -> t a -> xml ctx [] []
 
 val mapM : m ::: (Type -> Type) -> monad m -> a ::: Type -> b ::: Type
-           -> (a -> m b) -> list a -> m (list b)
+           -> (a -> m b) -> t a -> m (t b)
+
+val mapXM : m ::: (Type -> Type) -> monad m -> a ::: Type -> ctx ::: {Unit}
+            -> (a -> m (xml ctx [] [])) -> t a -> m (xml ctx [] [])
 
 val filter : a ::: Type -> (a -> bool) -> t a -> t a
 
--- a/src/elaborate.sml	Sat May 30 13:29:00 2009 -0400
+++ b/src/elaborate.sml	Sat May 30 14:44:29 2009 -0400
@@ -625,6 +625,8 @@
  val mayDelay = ref false
  val delayedUnifs = ref ([] : (ErrorMsg.span * E.env * L'.kind * record_summary * record_summary) list)
 
+ val delayedExhaustives = ref ([] : (E.env * L'.con * L'.pat list * ErrorMsg.span) list)
+
  fun unifyRecordCons env (loc, c1, c2) =
      let
          fun rkindof c =
@@ -1398,11 +1400,12 @@
                     val loc = #2 q1
 
                     fun unapp (t, acc) =
-                        case t of
-                            L'.CApp ((t, _), arg) => unapp (t, arg :: acc)
+                        case #1 t of
+                            L'.CApp (t, arg) => unapp (t, arg :: acc)
                           | _ => (t, rev acc)
 
-                    val (t1, args) = unapp (#1 (hnormCon env q1), [])
+                    val (t1, args) = unapp (hnormCon env q1, [])
+                    val t1 = hnormCon env t1
                     fun doSub t = foldl (fn (arg, t) => subConInCon (0, arg) t) t args
 
                     fun dtype (dtO, names) =
@@ -1423,7 +1426,7 @@
                     fun default () = (NONE, IS.singleton 0, [])
 
                     val (dtO, unused, cons) =
-                        case t1 of
+                        case #1 t1 of
                             L'.CNamed n =>
                             let
                                 val dt = E.lookupDatatype env n
@@ -1452,22 +1455,25 @@
                                                   NONE => []
                                                 | SOME t => [("", doSub t)])) cons)
                             end
-                          | L'.TRecord (L'.CRecord (_, xts), _) =>
-                            let
-                                val xts = map (fn ((L'.CName x, _), co) => SOME (x, co)
-                                                | _ => NONE) xts
-                            in
-                                if List.all Option.isSome xts then
-                                    let
-                                        val xts = List.mapPartial (fn x => x) xts
-                                        val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) =>
-                                                                         String.compare (x1, x2) = GREATER) xts
-                                    in
-                                        (NONE, IS.empty, [(0, xts)])
-                                    end
-                                else
-                                    default ()
-                            end
+                          | L'.TRecord t =>
+                            (case #1 (hnormCon env t) of
+                                 L'.CRecord (_, xts) =>
+                                 let
+                                     val xts = map (fn ((L'.CName x, _), co) => SOME (x, co)
+                                                     | _ => NONE) xts
+                                 in
+                                     if List.all Option.isSome xts then
+                                         let
+                                             val xts = List.mapPartial (fn x => x) xts
+                                             val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) =>
+                                                                              String.compare (x1, x2) = GREATER) xts
+                                         in
+                                             (NONE, IS.empty, [(0, xts)])
+                                         end
+                                     else
+                                         default ()
+                                 end
+                               | _ => default ())
                           | _ => default ()
                 in
                     if IS.isEmpty unused then
@@ -1489,7 +1495,7 @@
                                                         0 => L'.PRecord (ListPair.map
                                                                                   (fn ((name, t), p) => (name, p, t))
                                                                                   (args, argPs))
-                                                      | _  => L'.PCon (L'.Default, nameOfNum (t1, name), [],
+                                                      | _  => L'.PCon (L'.Default, nameOfNum (#1 t1, name), [],
                                                                        case argPs of
                                                                            [] => NONE
                                                                          | [p] => SOME p
@@ -1518,9 +1524,9 @@
                                                                          else
                                                                              NONE) cons of
                                                     SOME (n, []) =>
-                                                    L'.PCon (L'.Default, nameOfNum (t1, n), [], NONE)
+                                                    L'.PCon (L'.Default, nameOfNum (#1 t1, n), [], NONE)
                                                   | SOME (n, [_]) =>
-                                                    L'.PCon (L'.Default, nameOfNum (t1, n), [], SOME (L'.PWild, loc))
+                                                    L'.PCon (L'.Default, nameOfNum (#1 t1, n), [], SOME (L'.PWild, loc))
                                                   | _ => fail 7
                             in
                                 SOME ((p, loc) :: ps)
@@ -1533,211 +1539,6 @@
           | _ => fail 7
     end
 
-(*datatype coverage =
-         Wild
-       | None
-       | Datatype of coverage IM.map
-       | Record of coverage SM.map list
-
-fun c2s c =
-    case c of
-        Wild => "Wild"
-      | None => "None"
-      | Datatype _ => "Datatype"
-      | Record _ => "Record"
-
-fun exhaustive (env, t, ps, loc) =
-    let
-        fun depth (p, _) =
-            case p of
-                L'.PWild => 0
-              | L'.PVar _ => 0
-              | L'.PPrim _ => 0
-              | L'.PCon (_, _, _, NONE) => 1
-              | L'.PCon (_, _, _, SOME p) => 1 + depth p
-              | L'.PRecord xps => foldl (fn ((_, p, _), n) => Int.max (depth p, n)) 0 xps
-
-        val depth = 1 + foldl (fn (p, n) => Int.max (depth p, n)) 0 ps
-
-        fun pcCoverage pc =
-            case pc of
-                L'.PConVar n => n
-              | L'.PConProj (m1, ms, x) =>
-                let
-                    val (str, sgn) = E.chaseMpath env (m1, ms)
-                in
-                    case E.projectConstructor env {str = str, sgn = sgn, field = x} of
-                        NONE => raise Fail "exhaustive: Can't project constructor"
-                      | SOME (_, n, _, _, _) => n
-                end
-
-        fun coverage (p, _) =
-            case p of
-                L'.PWild => Wild
-              | L'.PVar _ => Wild
-              | L'.PPrim _ => None
-              | L'.PCon (_, pc, _, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild))
-              | L'.PCon (_, pc, _, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p))
-              | L'.PRecord xps => Record [foldl (fn ((x, p, _), fmap) =>
-                                                    SM.insert (fmap, x, coverage p)) SM.empty xps]
-
-        fun merge (c1, c2) =
-            case (c1, c2) of
-                (None, _) => c2
-              | (_, None) => c1
-                
-              | (Wild, _) => Wild
-              | (_, Wild) => Wild
-
-              | (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2))
-
-              | (Record fm1, Record fm2) => Record (fm1 @ fm2)
-
-              | _ => None
-
-        fun combinedCoverage ps =
-            case ps of
-                [] => raise Fail "Empty pattern list for coverage checking"
-              | [p] => coverage p
-              | p :: ps => merge (coverage p, combinedCoverage ps)
-
-        fun enumerateCases depth t =
-            (TextIO.print "enum'\n"; if depth <= 0 then
-                [Wild]
-            else
-                let
-                    val dtype =
-                        ListUtil.mapConcat (fn (_, n, to) =>
-                                               case to of
-                                                   NONE => [Datatype (IM.insert (IM.empty, n, Wild))]
-                                                 | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c)))
-                                                                 (enumerateCases (depth-1) t))
-                in
-                    case #1 (hnormCon env t) of
-                        L'.CNamed n =>
-                        (let
-                             val dt = E.lookupDatatype env n
-                             val cons = E.constructors dt
-                         in
-                             dtype cons
-                         end handle E.UnboundNamed _ => [Wild])
-                      | L'.TRecord c =>
-                        (case #1 (hnormCon env c) of
-                             L'.CRecord (_, xts) =>
-                             let
-                                 val xts = map (fn (x, t) => (hnormCon env x, t)) xts
-
-                                 fun exponentiate fs =
-                                     case fs of
-                                         [] => [SM.empty]
-                                       | ((L'.CName x, _), t) :: rest =>
-                                         let
-                                             val this = enumerateCases depth t
-                                             val rest = exponentiate rest
-                                         in
-                                             TextIO.print ("Before (" ^ Int.toString (length this)
-                                                           ^ ", " ^ Int.toString (length rest) ^ ")\n");
-                                             ListUtil.mapConcat (fn fmap =>
-                                                                    map (fn c => SM.insert (fmap, x, c)) this) rest
-                                             before TextIO.print "After\n"
-                                         end
-                                       | _ => raise Fail "exponentiate: Not CName"
-                             in
-                                 if List.exists (fn ((L'.CName _, _), _) => false
-                                                  | (c, _) => true) xts then
-                                     [Wild]
-                                 else
-                                     map (fn ls => Record [ls]) (exponentiate xts)
-                             end
-                           | _ => [Wild])
-                      | _ => [Wild]
-                end before TextIO.print "/enum'\n")
-
-        fun coverageImp (c1, c2) =
-            let
-                val r =
-                    case (c1, c2) of
-                        (Wild, _) => true
-
-                      | (Datatype cmap1, Datatype cmap2) =>
-                        List.all (fn (n, c2) =>
-                                     case IM.find (cmap1, n) of
-                                         NONE => false
-                                       | SOME c1 => coverageImp (c1, c2)) (IM.listItemsi cmap2)
-                      | (Datatype cmap1, Wild) =>
-                        List.all (fn (n, c1) => coverageImp (c1, Wild)) (IM.listItemsi cmap1)
-
-                      | (Record fmaps1, Record fmaps2) =>
-                        List.all (fn fmap2 =>
-                                     List.exists (fn fmap1 =>
-                                                     List.all (fn (x, c2) =>
-                                                                  case SM.find (fmap1, x) of
-                                                                      NONE => true
-                                                                    | SOME c1 => coverageImp (c1, c2))
-                                                              (SM.listItemsi fmap2))
-                                                 fmaps1) fmaps2
-
-                      | (Record fmaps1, Wild) =>
-                        List.exists (fn fmap1 =>
-                                        List.all (fn (x, c1) => coverageImp (c1, Wild))
-                                        (SM.listItemsi fmap1)) fmaps1
-
-                      | _ => false
-            in
-                (*TextIO.print ("coverageImp(" ^ c2s c1 ^ ", " ^ c2s c2 ^ ") = " ^ Bool.toString r ^ "\n");*)
-                r
-            end
-
-        fun isTotal (c, t) =
-            case c of
-                None => false
-              | Wild => true
-              | Datatype cm =>
-                let
-                    val (t, _) = hnormCon env t
-
-                    val dtype =
-                        List.all (fn (_, n, to) =>
-                                     case IM.find (cm, n) of
-                                         NONE => false
-                                       | SOME c' =>
-                                         case to of
-                                             NONE => true
-                                           | SOME t' => isTotal (c', t'))
-
-                    fun unapp t =
-                        case t of
-                            L'.CApp ((t, _), _) => unapp t
-                          | _ => t
-                in
-                    case unapp t of
-                        L'.CNamed n =>
-                        let
-                            val dt = E.lookupDatatype env n
-                            val cons = E.constructors dt
-                        in
-                            dtype cons
-                        end
-                      | L'.CModProj (m1, ms, x) =>
-                        let
-                            val (str, sgn) = E.chaseMpath env (m1, ms)
-                        in
-                            case E.projectDatatype env {str = str, sgn = sgn, field = x} of
-                                NONE => raise Fail "isTotal: Can't project datatype"
-                              | SOME (_, cons) => dtype cons
-                        end
-                      | L'.CError => true
-                      | c =>
-                        (prefaces "Not a datatype" [("loc", PD.string (ErrorMsg.spanToString loc)),
-                                                    ("c", p_con env (c, ErrorMsg.dummySpan))];
-                         raise Fail "isTotal: Not a datatype")
-                end
-              | Record _ => List.all (fn c2 => coverageImp (c, c2))
-                                     (TextIO.print "enum\n"; enumerateCases depth t before TextIO.print "/enum\n")
-    in
-        isTotal (combinedCoverage ps, t)
-    end*)
-
 fun unmodCon env (c, loc) =
     case c of
         L'.CNamed n =>
@@ -2083,7 +1884,10 @@
             in
                 case exhaustive (env, et, map #1 pes', loc) of
                     NONE => ()
-                  | SOME p => expError env (Inexhaustive (loc, p));
+                  | SOME p => if !mayDelay then
+                                  delayedExhaustives := (env, et, map #1 pes', loc) :: !delayedExhaustives
+                              else
+                                  expError env (Inexhaustive (loc, p));
 
                 ((L'.ECase (e', pes', {disc = et, result = result}), loc), result, gs)
             end
@@ -2113,6 +1917,13 @@
 
                     val pt = normClassConstraint env pt
                 in
+                    case exhaustive (env, et, [p'], loc) of
+                        NONE => ()
+                      | SOME p => if !mayDelay then
+                                      delayedExhaustives := (env, et, [p'], loc) :: !delayedExhaustives
+                                  else
+                                      expError env (Inexhaustive (loc, p));
+
                     ((L'.EDVal (p', pt, e'), loc), (env', gs1 @ gs))
                 end
               | L.EDValRec vis =>
@@ -3956,6 +3767,7 @@
     let
         val () = mayDelay := true
         val () = delayedUnifs := []
+        val () = delayedExhaustives := []
 
         val (sgn, gs) = elabSgn (env, D.empty) (L.SgnConst basis, ErrorMsg.dummySpan)
         val () = case gs of
@@ -4153,6 +3965,15 @@
         else
             app (fn f => f ()) (!checks);
 
+        if ErrorMsg.anyErrors () then
+            ()
+        else
+            app (fn all as (_, _, _, loc) =>
+                    case exhaustive all of
+                        NONE => ()
+                      | SOME p => expError env (Inexhaustive (loc, p)))
+                (!delayedExhaustives);
+
         (*preface ("file", p_file env' file);*)
 
         (L'.DFfiStr ("Basis", basis_n, sgn), ErrorMsg.dummySpan)