changeset 175:b2d752455182

Elaborating record patterns
author Adam Chlipala <adamc@hcoop.net>
date Thu, 31 Jul 2008 13:08:57 -0400
parents 7ee424760d2f
children 33d4a8eea484
files src/elab.sml src/elab_print.sml src/elaborate.sml src/lacweb.grm src/source_print.sml tests/rpat.lac
diffstat 6 files changed, 150 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/src/elab.sml	Thu Jul 31 11:28:55 2008 -0400
+++ b/src/elab.sml	Thu Jul 31 13:08:57 2008 -0400
@@ -80,6 +80,7 @@
        | PVar of string
        | PPrim of Prim.t
        | PCon of patCon * pat option
+       | PRecord of (string * pat) list * con option
 
 withtype pat = pat' located
 
--- a/src/elab_print.sml	Thu Jul 31 11:28:55 2008 -0400
+++ b/src/elab_print.sml	Thu Jul 31 13:08:57 2008 -0400
@@ -220,8 +220,19 @@
       | PCon (pc, SOME p) => parenIf par (box [p_patCon env pc,
                                                space,
                                                p_pat' true env p])
+      | PRecord (xps, flex) =>
+        let
+            val pps = map (fn (x, p) => box [string x, space, string "=", space, p_pat env p]) xps
+        in
+            box [string "{",
+                 p_list_sep (box [string ",", space]) (fn x => x)
+                 (case flex of
+                      NONE => pps
+                    | SOME _ => pps @ [string "..."]),
+                 string "}"]
+        end
 
-val p_pat = p_pat' false
+and p_pat x = p_pat' false x
 
 fun p_exp' par env (e, _) =
     case e of
--- a/src/elaborate.sml	Thu Jul 31 11:28:55 2008 -0400
+++ b/src/elaborate.sml	Thu Jul 31 13:08:57 2008 -0400
@@ -38,10 +38,14 @@
 open ElabPrint
 
 structure IM = IntBinaryMap
-structure SS = BinarySetFn(struct
-                           type ord_key = string
-                           val compare = String.compare
-                           end)
+
+structure SK = struct
+type ord_key = string
+val compare = String.compare
+end
+
+structure SS = BinarySetFn(SK)
+structure SM = BinaryMapFn(SK)
 
 fun elabExplicitness e =
     case e of
@@ -816,6 +820,7 @@
      | PatHasArg of ErrorMsg.span
      | PatHasNoArg of ErrorMsg.span
      | Inexhaustive of ErrorMsg.span
+     | DuplicatePatField of ErrorMsg.span * string
 
 fun expError env err =
     case err of
@@ -856,6 +861,8 @@
         ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument"
       | Inexhaustive loc =>
         ErrorMsg.errorAt loc "Inexhaustive 'case'"
+      | DuplicatePatField (loc, s) =>
+        ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern")
 
 fun checkCon (env, denv) e c1 c2 =
     unifyCons (env, denv) c1 c2
@@ -1021,13 +1028,45 @@
                        | SOME (_, to, dn) => pcon (L'.PConProj (n, ms, x), po, to, dn)
                  end)
 
-          | L.PRecord _ => raise Fail "Elaborate PRecord"
+          | L.PRecord (xps, flex) =>
+            let
+                val (xpts, (env, bound, _)) =
+                    ListUtil.foldlMap (fn ((x, p), (env, bound, fbound)) =>
+                                          let
+                                              val ((p', t), (env, bound)) = elabPat (p, (env, denv, bound))
+                                          in
+                                              if SS.member (fbound, x) then
+                                                  expError env (DuplicatePatField (loc, x))
+                                              else
+                                                  ();
+                                              ((x, p', t), (env, bound, SS.add (fbound, x)))
+                                          end)
+                    (env, bound, SS.empty) xps
+
+                val k = (L'.KType, loc)
+                val c = (L'.CRecord (k, map (fn (x, _, t) => ((L'.CName x, loc), t)) xpts), loc)
+                val (flex, c) =
+                    if flex then
+                        let
+                            val flex = cunif (loc, (L'.KRecord k, loc))
+                        in
+                            (SOME flex, (L'.CConcat (c, flex), loc))
+                        end
+                    else
+                        (NONE, c)
+            in
+                (((L'.PRecord (map (fn (x, p', _) => (x, p')) xpts, flex), loc),
+                  (L'.TRecord c, loc)),
+                 (env, bound))
+            end
+                                           
     end
 
 datatype coverage =
          Wild
        | None
        | Datatype of coverage IM.map
+       | Record of coverage SM.map list
 
 fun exhaustive (env, denv, t, ps) =
     let
@@ -1050,7 +1089,8 @@
               | L'.PPrim _ => None
               | L'.PCon (pc, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild))
               | L'.PCon (pc, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p))
-
+              | L'.PRecord (xps, _) => Record [foldl (fn ((x, p), fmap) =>
+                                                         SM.insert (fmap, x, coverage p)) SM.empty xps]
         fun merge (c1, c2) =
             case (c1, c2) of
                 (None, _) => c2
@@ -1061,12 +1101,84 @@
 
               | (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2))
 
+              | (Record fm1, Record fm2) => Record (fm1 @ fm2)
+
+              | _ => None
+
         fun combinedCoverage ps =
             case ps of
                 [] => raise Fail "Empty pattern list for coverage checking"
               | [p] => coverage p
               | p :: ps => merge (coverage p, combinedCoverage ps)
 
+        fun enumerateCases t =
+            let
+                fun dtype cons =
+                    ListUtil.mapConcat (fn (_, n, to) =>
+                                           case to of
+                                               NONE => [Datatype (IM.insert (IM.empty, n, Wild))]
+                                             | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c)))
+                                                         (enumerateCases t)) cons
+            in
+                case #1 (#1 (hnormCon (env, denv) t)) of
+                    L'.CNamed n =>
+                    (let
+                         val dt = E.lookupDatatype env n
+                         val cons = E.constructors dt
+                     in
+                         dtype cons
+                     end handle E.UnboundNamed _ => [Wild])
+                  | L'.TRecord c =>
+                    (case #1 (#1 (hnormCon (env, denv) c)) of
+                         L'.CRecord (_, xts) =>
+                         let
+                             val xts = map (fn (x, t) => (#1 (hnormCon (env, denv) x), t)) xts
+
+                             fun exponentiate fs =
+                                 case fs of
+                                     [] => [SM.empty]
+                                   | ((L'.CName x, _), t) :: rest =>
+                                     let
+                                         val this = enumerateCases t
+                                         val rest = exponentiate rest
+                                     in
+                                         ListUtil.mapConcat (fn fmap =>
+                                                                map (fn c => SM.insert (fmap, x, c)) this) rest
+                                     end
+                                   | _ => raise Fail "exponentiate: Not CName"
+                         in
+                             if List.exists (fn ((L'.CName _, _), _) => false
+                                              | (c, _) => true) xts then
+                                 [Wild]
+                             else
+                                 map (fn ls => Record [ls]) (exponentiate xts)
+                         end
+                       | _ => [Wild])
+                  | _ => [Wild]
+            end
+
+        fun coverageImp (c1, c2) =
+            case (c1, c2) of
+                (Wild, _) => true
+
+              | (Datatype cmap1, Datatype cmap2) =>
+                List.all (fn (n, c2) =>
+                             case IM.find (cmap1, n) of
+                                 NONE => false
+                               | SOME c1 => coverageImp (c1, c2)) (IM.listItemsi cmap2)
+
+              | (Record fmaps1, Record fmaps2) =>
+                List.all (fn fmap2 =>
+                             List.exists (fn fmap1 =>
+                                             List.all (fn (x, c2) =>
+                                                          case SM.find (fmap1, x) of
+                                                              NONE => true
+                                                            | SOME c1 => coverageImp (c1, c2))
+                                                      (SM.listItemsi fmap2))
+                                         fmaps1) fmaps2
+
+              | _ => false
+
         fun isTotal (c, t) =
             case c of
                 None => (false, [])
@@ -1109,6 +1221,7 @@
                       | L'.CError => (true, gs)
                       | _ => raise Fail "isTotal: Not a datatype"
                 end
+              | Record _ => (List.all (fn c2 => coverageImp (c, c2)) (enumerateCases t), [])
     in
         isTotal (combinedCoverage ps, t)
     end
--- a/src/lacweb.grm	Thu Jul 31 11:28:55 2008 -0400
+++ b/src/lacweb.grm	Thu Jul 31 13:08:57 2008 -0400
@@ -356,9 +356,9 @@
        | UNIT                           (PRecord ([], false), s (UNITleft, UNITright))
        | LBRACE rpat RBRACE             (PRecord rpat, s (LBRACEleft, RBRACEright))
 
-rpat   : STRING EQ pat                  ([(STRING, pat)], false)
+rpat   : CSYMBOL EQ pat                 ([(CSYMBOL, pat)], false)
        | DOTDOTDOT                      ([], true)
-       | STRING EQ pat COMMA rpat       ((STRING, pat) :: #1 rpat, #2 rpat)
+       | CSYMBOL EQ pat COMMA rpat      ((CSYMBOL, pat) :: #1 rpat, #2 rpat)
 
 rexp   :                                ([])
        | ident EQ eexp                  ([(ident, eexp)])
--- a/src/source_print.sml	Thu Jul 31 11:28:55 2008 -0400
+++ b/src/source_print.sml	Thu Jul 31 13:08:57 2008 -0400
@@ -173,14 +173,14 @@
                                                   p_pat' true p])
       | PRecord (xps, flex) =>
         let
-            val pps = map (fn (x, p) => box [string "x", space, string "=", space, p_pat p]) xps
+            val pps = map (fn (x, p) => box [string x, space, string "=", space, p_pat p]) xps
         in
             box [string "{",
                  p_list_sep (box [string ",", space]) (fn x => x)
                  (if flex then
-                      pps
+                      pps @ [string "..."]
                   else
-                      pps @ [string "..."]),
+                      pps),
                  string "}"]
         end
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/rpat.lac	Thu Jul 31 13:08:57 2008 -0400
@@ -0,0 +1,13 @@
+val f = fn x : {A : int} => case x of {A = _} => 0
+val f = fn x : {A : int} => case x of {A = _, ...} => 0
+val f = fn x : {A : int, B : int} => case x of {A = _, ...} => 0
+val f = fn x : {A : int, B : int} => case x of {A = 1, B = 2} => 0 | {A = _, ...} => 1
+
+datatype t = A | B
+
+val f = fn x => case x of {A = A, B = 2} => 0 | {A = A, ...} => 0 | {A = B, ...} => 0
+
+val f = fn x => case x of {A = {A = A, ...}, B = B} => 0
+        | {B = A, ...} => 1
+        | {A = {A = B, B = A}, B = B} => 2
+        | {A = {A = B, B = B}, B = B} => 3