comparison src/elaborate.sml @ 175:b2d752455182

Elaborating record patterns
author Adam Chlipala <adamc@hcoop.net>
date Thu, 31 Jul 2008 13:08:57 -0400
parents 7ee424760d2f
children 33d4a8eea484
comparison
equal deleted inserted replaced
174:7ee424760d2f 175:b2d752455182
36 36
37 open Print 37 open Print
38 open ElabPrint 38 open ElabPrint
39 39
40 structure IM = IntBinaryMap 40 structure IM = IntBinaryMap
41 structure SS = BinarySetFn(struct 41
42 type ord_key = string 42 structure SK = struct
43 val compare = String.compare 43 type ord_key = string
44 end) 44 val compare = String.compare
45 end
46
47 structure SS = BinarySetFn(SK)
48 structure SM = BinaryMapFn(SK)
45 49
46 fun elabExplicitness e = 50 fun elabExplicitness e =
47 case e of 51 case e of
48 L.Explicit => L'.Explicit 52 L.Explicit => L'.Explicit
49 | L.Implicit => L'.Implicit 53 | L.Implicit => L'.Implicit
814 | PatUnify of L'.pat * L'.con * L'.con * cunify_error 818 | PatUnify of L'.pat * L'.con * L'.con * cunify_error
815 | UnboundConstructor of ErrorMsg.span * string list * string 819 | UnboundConstructor of ErrorMsg.span * string list * string
816 | PatHasArg of ErrorMsg.span 820 | PatHasArg of ErrorMsg.span
817 | PatHasNoArg of ErrorMsg.span 821 | PatHasNoArg of ErrorMsg.span
818 | Inexhaustive of ErrorMsg.span 822 | Inexhaustive of ErrorMsg.span
823 | DuplicatePatField of ErrorMsg.span * string
819 824
820 fun expError env err = 825 fun expError env err =
821 case err of 826 case err of
822 UnboundExp (loc, s) => 827 UnboundExp (loc, s) =>
823 ErrorMsg.errorAt loc ("Unbound expression variable " ^ s) 828 ErrorMsg.errorAt loc ("Unbound expression variable " ^ s)
854 ErrorMsg.errorAt loc "Constructor expects no argument but is used with argument" 859 ErrorMsg.errorAt loc "Constructor expects no argument but is used with argument"
855 | PatHasNoArg loc => 860 | PatHasNoArg loc =>
856 ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument" 861 ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument"
857 | Inexhaustive loc => 862 | Inexhaustive loc =>
858 ErrorMsg.errorAt loc "Inexhaustive 'case'" 863 ErrorMsg.errorAt loc "Inexhaustive 'case'"
864 | DuplicatePatField (loc, s) =>
865 ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern")
859 866
860 fun checkCon (env, denv) e c1 c2 = 867 fun checkCon (env, denv) e c1 c2 =
861 unifyCons (env, denv) c1 c2 868 unifyCons (env, denv) c1 c2
862 handle CUnify (c1, c2, err) => 869 handle CUnify (c1, c2, err) =>
863 (expError env (Unify (e, c1, c2, err)); 870 (expError env (Unify (e, c1, c2, err));
1019 NONE => (expError env (UnboundConstructor (loc, m1 :: ms, x)); 1026 NONE => (expError env (UnboundConstructor (loc, m1 :: ms, x));
1020 rerror) 1027 rerror)
1021 | SOME (_, to, dn) => pcon (L'.PConProj (n, ms, x), po, to, dn) 1028 | SOME (_, to, dn) => pcon (L'.PConProj (n, ms, x), po, to, dn)
1022 end) 1029 end)
1023 1030
1024 | L.PRecord _ => raise Fail "Elaborate PRecord" 1031 | L.PRecord (xps, flex) =>
1032 let
1033 val (xpts, (env, bound, _)) =
1034 ListUtil.foldlMap (fn ((x, p), (env, bound, fbound)) =>
1035 let
1036 val ((p', t), (env, bound)) = elabPat (p, (env, denv, bound))
1037 in
1038 if SS.member (fbound, x) then
1039 expError env (DuplicatePatField (loc, x))
1040 else
1041 ();
1042 ((x, p', t), (env, bound, SS.add (fbound, x)))
1043 end)
1044 (env, bound, SS.empty) xps
1045
1046 val k = (L'.KType, loc)
1047 val c = (L'.CRecord (k, map (fn (x, _, t) => ((L'.CName x, loc), t)) xpts), loc)
1048 val (flex, c) =
1049 if flex then
1050 let
1051 val flex = cunif (loc, (L'.KRecord k, loc))
1052 in
1053 (SOME flex, (L'.CConcat (c, flex), loc))
1054 end
1055 else
1056 (NONE, c)
1057 in
1058 (((L'.PRecord (map (fn (x, p', _) => (x, p')) xpts, flex), loc),
1059 (L'.TRecord c, loc)),
1060 (env, bound))
1061 end
1062
1025 end 1063 end
1026 1064
1027 datatype coverage = 1065 datatype coverage =
1028 Wild 1066 Wild
1029 | None 1067 | None
1030 | Datatype of coverage IM.map 1068 | Datatype of coverage IM.map
1069 | Record of coverage SM.map list
1031 1070
1032 fun exhaustive (env, denv, t, ps) = 1071 fun exhaustive (env, denv, t, ps) =
1033 let 1072 let
1034 fun pcCoverage pc = 1073 fun pcCoverage pc =
1035 case pc of 1074 case pc of
1048 L'.PWild => Wild 1087 L'.PWild => Wild
1049 | L'.PVar _ => Wild 1088 | L'.PVar _ => Wild
1050 | L'.PPrim _ => None 1089 | L'.PPrim _ => None
1051 | L'.PCon (pc, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild)) 1090 | L'.PCon (pc, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild))
1052 | L'.PCon (pc, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p)) 1091 | L'.PCon (pc, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p))
1053 1092 | L'.PRecord (xps, _) => Record [foldl (fn ((x, p), fmap) =>
1093 SM.insert (fmap, x, coverage p)) SM.empty xps]
1054 fun merge (c1, c2) = 1094 fun merge (c1, c2) =
1055 case (c1, c2) of 1095 case (c1, c2) of
1056 (None, _) => c2 1096 (None, _) => c2
1057 | (_, None) => c1 1097 | (_, None) => c1
1058 1098
1059 | (Wild, _) => Wild 1099 | (Wild, _) => Wild
1060 | (_, Wild) => Wild 1100 | (_, Wild) => Wild
1061 1101
1062 | (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2)) 1102 | (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2))
1063 1103
1104 | (Record fm1, Record fm2) => Record (fm1 @ fm2)
1105
1106 | _ => None
1107
1064 fun combinedCoverage ps = 1108 fun combinedCoverage ps =
1065 case ps of 1109 case ps of
1066 [] => raise Fail "Empty pattern list for coverage checking" 1110 [] => raise Fail "Empty pattern list for coverage checking"
1067 | [p] => coverage p 1111 | [p] => coverage p
1068 | p :: ps => merge (coverage p, combinedCoverage ps) 1112 | p :: ps => merge (coverage p, combinedCoverage ps)
1113
1114 fun enumerateCases t =
1115 let
1116 fun dtype cons =
1117 ListUtil.mapConcat (fn (_, n, to) =>
1118 case to of
1119 NONE => [Datatype (IM.insert (IM.empty, n, Wild))]
1120 | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c)))
1121 (enumerateCases t)) cons
1122 in
1123 case #1 (#1 (hnormCon (env, denv) t)) of
1124 L'.CNamed n =>
1125 (let
1126 val dt = E.lookupDatatype env n
1127 val cons = E.constructors dt
1128 in
1129 dtype cons
1130 end handle E.UnboundNamed _ => [Wild])
1131 | L'.TRecord c =>
1132 (case #1 (#1 (hnormCon (env, denv) c)) of
1133 L'.CRecord (_, xts) =>
1134 let
1135 val xts = map (fn (x, t) => (#1 (hnormCon (env, denv) x), t)) xts
1136
1137 fun exponentiate fs =
1138 case fs of
1139 [] => [SM.empty]
1140 | ((L'.CName x, _), t) :: rest =>
1141 let
1142 val this = enumerateCases t
1143 val rest = exponentiate rest
1144 in
1145 ListUtil.mapConcat (fn fmap =>
1146 map (fn c => SM.insert (fmap, x, c)) this) rest
1147 end
1148 | _ => raise Fail "exponentiate: Not CName"
1149 in
1150 if List.exists (fn ((L'.CName _, _), _) => false
1151 | (c, _) => true) xts then
1152 [Wild]
1153 else
1154 map (fn ls => Record [ls]) (exponentiate xts)
1155 end
1156 | _ => [Wild])
1157 | _ => [Wild]
1158 end
1159
1160 fun coverageImp (c1, c2) =
1161 case (c1, c2) of
1162 (Wild, _) => true
1163
1164 | (Datatype cmap1, Datatype cmap2) =>
1165 List.all (fn (n, c2) =>
1166 case IM.find (cmap1, n) of
1167 NONE => false
1168 | SOME c1 => coverageImp (c1, c2)) (IM.listItemsi cmap2)
1169
1170 | (Record fmaps1, Record fmaps2) =>
1171 List.all (fn fmap2 =>
1172 List.exists (fn fmap1 =>
1173 List.all (fn (x, c2) =>
1174 case SM.find (fmap1, x) of
1175 NONE => true
1176 | SOME c1 => coverageImp (c1, c2))
1177 (SM.listItemsi fmap2))
1178 fmaps1) fmaps2
1179
1180 | _ => false
1069 1181
1070 fun isTotal (c, t) = 1182 fun isTotal (c, t) =
1071 case c of 1183 case c of
1072 None => (false, []) 1184 None => (false, [])
1073 | Wild => (true, []) 1185 | Wild => (true, [])
1107 | SOME cons => dtype cons 1219 | SOME cons => dtype cons
1108 end 1220 end
1109 | L'.CError => (true, gs) 1221 | L'.CError => (true, gs)
1110 | _ => raise Fail "isTotal: Not a datatype" 1222 | _ => raise Fail "isTotal: Not a datatype"
1111 end 1223 end
1224 | Record _ => (List.all (fn c2 => coverageImp (c, c2)) (enumerateCases t), [])
1112 in 1225 in
1113 isTotal (combinedCoverage ps, t) 1226 isTotal (combinedCoverage ps, t)
1114 end 1227 end
1115 1228
1116 fun elabExp (env, denv) (eAll as (e, loc)) = 1229 fun elabExp (env, denv) (eAll as (e, loc)) =