changeset 171:c7a6e6dbc318

Elaborating some basic pattern matching
author Adam Chlipala <adamc@hcoop.net>
date Thu, 31 Jul 2008 10:06:27 -0400
parents a158f8c5aa55
children 021f5beb6f8d
files src/elab.sml src/elab_env.sig src/elab_env.sml src/elab_print.sig src/elab_print.sml src/elab_util.sml src/elaborate.sml src/explify.sml src/source_print.sml tests/case.lac
diffstat 10 files changed, 189 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/src/elab.sml	Tue Jul 29 16:38:15 2008 -0400
+++ b/src/elab.sml	Thu Jul 31 10:06:27 2008 -0400
@@ -71,6 +71,17 @@
 
 withtype con = con' located
 
+datatype patCon =
+         PConVar of int
+       | PConProj of int * string list * string
+
+datatype pat' =
+         PWild
+       | PVar of string
+       | PCon of patCon * pat option
+
+withtype pat = pat' located
+
 datatype exp' =
          EPrim of Prim.t
        | ERel of int
@@ -86,6 +97,8 @@
        | ECut of exp * con * { field : con, rest : con }
        | EFold of kind
 
+       | ECase of exp * (pat * exp) list * con
+
        | EError
 
 withtype exp = exp' located
--- a/src/elab_env.sig	Tue Jul 29 16:38:15 2008 -0400
+++ b/src/elab_env.sig	Thu Jul 31 10:06:27 2008 -0400
@@ -54,9 +54,11 @@
     val pushDatatype : env -> int -> (string * int * Elab.con option) list -> env
     type datatyp
     val lookupDatatype : env -> int -> datatyp
-    val lookupConstructor : datatyp -> int -> string * Elab.con option
+    val lookupDatatypeConstructor : datatyp -> int -> string * Elab.con option
     val constructors : datatyp -> (string * int * Elab.con option) list
 
+    val lookupConstructor : env -> string -> (int * Elab.con option * int) option
+
     val pushERel : env -> string -> Elab.con -> env
     val lookupERel : env -> int -> string * Elab.con
 
--- a/src/elab_env.sml	Tue Jul 29 16:38:15 2008 -0400
+++ b/src/elab_env.sml	Thu Jul 31 10:06:27 2008 -0400
@@ -81,6 +81,7 @@
      namedC : (string * kind * con option) IM.map,
 
      datatypes : datatyp IM.map,
+     constructors : (int * con option * int) SM.map,
 
      renameE : con var' SM.map,
      relE : (string * con) list,
@@ -109,6 +110,7 @@
     namedC = IM.empty,
 
     datatypes = IM.empty,
+    constructors = SM.empty,
 
     renameE = SM.empty,
     relE = [],
@@ -131,6 +133,7 @@
          namedC = IM.map (fn (x, k, co) => (x, k, Option.map lift co)) (#namedC env),
 
          datatypes = #datatypes env,
+         constructors = #constructors env,
 
          renameE = #renameE env,
          relE = map (fn (x, c) => (x, lift c)) (#relE env),
@@ -154,6 +157,7 @@
      namedC = IM.insert (#namedC env, n, (x, k, co)),
 
      datatypes = #datatypes env,
+     constructors = #constructors env,
 
      renameE = #renameE env,
      relE = #relE env,
@@ -192,6 +196,9 @@
      datatypes = IM.insert (#datatypes env, n,
                             foldl (fn ((x, n, to), cons) =>
                                       IM.insert (cons, n, (x, to))) IM.empty xncs),
+     constructors = foldl (fn ((x, n', to), cmap) =>
+                              SM.insert (cmap, x, (n', to, n)))
+                          (#constructors env) xncs,
 
      renameE = #renameE env,
      relE = #relE env,
@@ -208,11 +215,13 @@
         NONE => raise UnboundNamed n
       | SOME x => x
 
-fun lookupConstructor dt n =
+fun lookupDatatypeConstructor dt n =
     case IM.find (dt, n) of
         NONE => raise UnboundNamed n
       | SOME x => x
 
+fun lookupConstructor (env : env) s = SM.find (#constructors env, s)
+
 fun constructors dt = IM.foldri (fn (n, (x, to), ls) => (x, n, to) :: ls) [] dt
 
 fun pushERel (env : env) x t =
@@ -225,6 +234,7 @@
          namedC = #namedC env,
 
          datatypes = #datatypes env,
+         constructors = #constructors env,
 
          renameE = SM.insert (renameE, x, Rel' (0, t)),
          relE = (x, t) :: #relE env,
@@ -247,6 +257,7 @@
      namedC = #namedC env,
 
      datatypes = #datatypes env,
+     constructors = #constructors env,
 
      renameE = SM.insert (#renameE env, x, Named' (n, t)),
      relE = #relE env,
@@ -283,6 +294,7 @@
      namedC = #namedC env,
 
      datatypes = #datatypes env,
+     constructors = #constructors env,
 
      renameE = #renameE env,
      relE = #relE env,
@@ -315,6 +327,7 @@
      namedC = #namedC env,
 
      datatypes = #datatypes env,
+     constructors = #constructors env,
 
      renameE = #renameE env,
      relE = #relE env,
--- a/src/elab_print.sig	Tue Jul 29 16:38:15 2008 -0400
+++ b/src/elab_print.sig	Thu Jul 31 10:06:27 2008 -0400
@@ -31,6 +31,7 @@
     val p_kind : Elab.kind Print.printer
     val p_explicitness : Elab.explicitness Print.printer
     val p_con : ElabEnv.env -> Elab.con Print.printer
+    val p_pat : ElabEnv.env -> Elab.pat Print.printer
     val p_exp : ElabEnv.env -> Elab.exp Print.printer
     val p_decl : ElabEnv.env -> Elab.decl Print.printer
     val p_sgn_item : ElabEnv.env -> Elab.sgn_item Print.printer
--- a/src/elab_print.sml	Tue Jul 29 16:38:15 2008 -0400
+++ b/src/elab_print.sml	Thu Jul 31 10:06:27 2008 -0400
@@ -190,6 +190,38 @@
         CName s => string s
       | _ => p_con env all
 
+fun p_patCon env pc =
+    case pc of
+        PConVar n =>
+        ((if !debug then
+              string (#1 (E.lookupENamed env n) ^ "__" ^ Int.toString n)
+          else
+              string (#1 (E.lookupENamed env n)))
+         handle E.UnboundRel _ => string ("UNBOUND_NAMED" ^ Int.toString n))
+      | PConProj (m1, ms, x) =>
+        let
+            val m1x = #1 (E.lookupStrNamed env m1)
+                handle E.UnboundNamed _ => "UNBOUND_STR_" ^ Int.toString m1
+                  
+            val m1s = if !debug then
+                          m1x ^ "__" ^ Int.toString m1
+                      else
+                          m1x
+        in
+            p_list_sep (string ".") string (m1x :: ms @ [x])
+        end
+
+fun p_pat' par env (p, _) =
+    case p of
+        PWild => string "_"
+      | PVar s => string s
+      | PCon (pc, NONE) => p_patCon env pc
+      | PCon (pc, SOME p) => parenIf par (box [p_patCon env pc,
+                                               space,
+                                               p_pat' true env p])
+
+val p_pat = p_pat' false
+
 fun p_exp' par env (e, _) =
     case e of
         EPrim p => Prim.p_t p
@@ -297,6 +329,19 @@
                               p_con' true env c])
       | EFold _ => string "fold"
 
+      | ECase (e, pes, _) => parenIf par (box [string "case",
+                                               space,
+                                               p_exp env e,
+                                               space,
+                                               string "of",
+                                               space,
+                                               p_list_sep (box [space, string "|", space])
+                                                          (fn (p, e) => box [p_pat env p,
+                                                                             space,
+                                                                             string "=>",
+                                                                             space,
+                                                                             p_exp env e]) pes])
+
       | EError => string "<ERROR>"
 
 and p_exp env = p_exp' false env
--- a/src/elab_util.sml	Tue Jul 29 16:38:15 2008 -0400
+++ b/src/elab_util.sml	Thu Jul 31 10:06:27 2008 -0400
@@ -308,6 +308,17 @@
                          fn k' =>
                             (EFold k', loc))
 
+              | ECase (e, pes, t) =>
+                S.bind2 (mfe ctx e,
+                         fn e' =>
+                            S.bind2 (ListUtil.mapfold (fn (p, e) =>
+                                                         S.map2 (mfe ctx e,
+                                                              fn e' => (p, e'))) pes,
+                                    fn pes' =>
+                                       S.map2 (mfc ctx t,
+                                               fn t' =>
+                                                  (ECase (e', pes', t'), loc))))
+
               | EError => S.return2 eAll
     in
         mfe
--- a/src/elaborate.sml	Tue Jul 29 16:38:15 2008 -0400
+++ b/src/elaborate.sml	Thu Jul 31 10:06:27 2008 -0400
@@ -809,6 +809,11 @@
      | Unif of string * L'.con
      | WrongForm of string * L'.exp * L'.con
      | IncompatibleCons of L'.con * L'.con
+     | DuplicatePatternVariable of ErrorMsg.span * string
+     | PatUnify of L'.pat * L'.con * L'.con * cunify_error
+     | UnboundConstructor of ErrorMsg.span * string
+     | PatHasArg of ErrorMsg.span
+     | PatHasNoArg of ErrorMsg.span
 
 fun expError env err =
     case err of
@@ -833,6 +838,20 @@
         (ErrorMsg.errorAt (#2 c1) "Incompatible constructors";
          eprefaces' [("Con 1", p_con env c1),
                      ("Con 2", p_con env c2)])
+      | DuplicatePatternVariable (loc, s) =>
+        ErrorMsg.errorAt loc ("Duplicate pattern variable " ^ s)
+      | PatUnify (p, c1, c2, uerr) =>
+        (ErrorMsg.errorAt (#2 p) "Unification failure for pattern";
+         eprefaces' [("Pattern", p_pat env p),
+                     ("Have con", p_con env c1),
+                     ("Need con", p_con env c2)];
+         cunifyError env uerr)
+      | UnboundConstructor (loc, s) =>
+        ErrorMsg.errorAt loc ("Unbound constructor " ^ s ^ " in pattern")
+      | PatHasArg loc =>
+        ErrorMsg.errorAt loc "Constructor expects no argument but is used with argument"
+      | PatHasNoArg loc =>
+        ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument"
 
 fun checkCon (env, denv) e c1 c2 =
     unifyCons (env, denv) c1 c2
@@ -840,6 +859,12 @@
            (expError env (Unify (e, c1, c2, err));
             [])
 
+fun checkPatCon (env, denv) p c1 c2 =
+    unifyCons (env, denv) c1 c2
+    handle CUnify (c1, c2, err) =>
+           (expError env (PatUnify (p, c1, c2, err));
+            [])
+
 fun primType env p =
     case p of
         P.Int _ => !int
@@ -903,6 +928,8 @@
       | L'.ECut (_, _, {rest, ...}) => (L'.TRecord rest, loc)
       | L'.EFold dom => foldType (dom, loc)
 
+      | L'.ECase (_, _, t) => t
+
       | L'.EError => cerror
 
 fun elabHead (env, denv) (e as (_, loc)) t =
@@ -927,6 +954,52 @@
         unravel (t, e)
     end
 
+fun elabPat (pAll as (p, loc), (env, bound)) =
+    let
+        val perror = (L'.PWild, loc)
+        val terror = (L'.CError, loc)
+        val pterror = (perror, terror)
+        val rerror = (pterror, (env, bound))
+
+        fun pcon (pc, po, to, dn) =
+
+                case (po, to) of
+                    (NONE, SOME _) => (expError env (PatHasNoArg loc);
+                                       rerror)
+                  | (SOME _, NONE) => (expError env (PatHasArg loc);
+                                       rerror)
+                  | (NONE, NONE) => (((L'.PCon (pc, NONE), loc), (L'.CNamed dn, loc)),
+                                     (env, bound))
+                  | (SOME p, SOME t) =>
+                    let
+                        val ((p', pt), (env, bound)) = elabPat (p, (env, bound))
+                    in
+                        (((L'.PCon (pc, SOME p'), loc), (L'.CNamed dn, loc)),
+                         (env, bound))
+                    end
+    in
+        case p of
+            L.PWild => (((L'.PWild, loc), cunif (loc, (L'.KType, loc))),
+                        (env, bound))
+          | L.PVar x =>
+            let
+                val t = if SS.member (bound, x) then
+                            (expError env (DuplicatePatternVariable (loc, x));
+                             terror)
+                        else
+                            cunif (loc, (L'.KType, loc))
+            in
+                (((L'.PVar x, loc), t),
+                 (E.pushERel env x t, SS.add (bound, x)))
+            end
+          | L.PCon ([], x, po) =>
+            (case E.lookupConstructor env x of
+                 NONE => (expError env (UnboundConstructor (loc, x));
+                          rerror)
+               | SOME (n, to, dn) => pcon (L'.PConVar n, po, to, dn))
+          | L.PCon _ => raise Fail "uhoh"
+    end
+
 fun elabExp (env, denv) (eAll as (e, loc)) =
     let
         
@@ -1138,7 +1211,25 @@
                 ((L'.EFold dom, loc), foldType (dom, loc), [])
             end
 
-          | L.ECase _ => raise Fail "Elaborate ECase"
+          | L.ECase (e, pes) =>
+            let
+                val (e', et, gs1) = elabExp (env, denv) e
+                val result = cunif (loc, (L'.KType, loc))
+                val (pes', gs) = ListUtil.foldlMap
+                                 (fn ((p, e), gs) =>
+                                     let
+                                         val ((p', pt), (env, _)) = elabPat (p, (env, SS.empty))
+
+                                         val gs1 = checkPatCon (env, denv) p' pt et
+                                         val (e', et, gs2) = elabExp (env, denv) e
+                                         val gs3 = checkCon (env, denv) e' et result
+                                     in
+                                         ((p', e'), gs1 @ gs2 @ gs3 @ gs)
+                                     end)
+                                 gs1 pes
+            in
+                ((L'.ECase (e', pes', result), loc), result, gs)
+            end
     end
             
 
@@ -1961,6 +2052,8 @@
                         ((x, n', to), (SS.add (used, x), env, gs'))
                     end)
                 (SS.empty, env, []) xcs
+
+            val env = E.pushDatatype env n xcs
         in
             ([(L'.DDatatype (x, n, xcs), loc)], (env, denv, gs))
         end
--- a/src/explify.sml	Tue Jul 29 16:38:15 2008 -0400
+++ b/src/explify.sml	Thu Jul 31 10:06:27 2008 -0400
@@ -89,6 +89,8 @@
                                                      {field = explifyCon field, rest = explifyCon rest}), loc)
       | L.EFold k => (L'.EFold (explifyKind k), loc)
 
+      | L.ECase _ => raise Fail "Explify ECase"
+
       | L.EError => raise Fail ("explifyExp: EError at " ^ EM.spanToString loc)
 
 fun explifySgi (sgi, loc) =
--- a/src/source_print.sml	Tue Jul 29 16:38:15 2008 -0400
+++ b/src/source_print.sml	Thu Jul 31 10:06:27 2008 -0400
@@ -252,7 +252,7 @@
 
       | ECase (e, pes) => parenIf par (box [string "case",
                                             space,
-                                            p_exp' false e,
+                                            p_exp e,
                                             space,
                                             string "of",
                                             space,
--- a/tests/case.lac	Tue Jul 29 16:38:15 2008 -0400
+++ b/tests/case.lac	Thu Jul 31 10:06:27 2008 -0400
@@ -8,5 +8,9 @@
 
 datatype nat = O | S of nat
 
-val is_two = fn x : int_list =>
+val is_two = fn x : nat =>
         case x of S (S O) => A | _ => B
+
+val zero_is_two = is_two O
+val one_is_two = is_two (S O)
+val two_is_two = is_two (S (S O))