Mercurial > urweb
diff src/elaborate.sml @ 191:aa54250f58ac
Parametrized datatypes through explify
author | Adam Chlipala <adamc@hcoop.net> |
---|---|
date | Fri, 08 Aug 2008 10:28:32 -0400 |
parents | 8e9f97508f0d |
children | df5fd8f6913a |
line wrap: on
line diff
--- a/src/elaborate.sml Thu Aug 07 13:09:26 2008 -0400 +++ b/src/elaborate.sml Fri Aug 08 10:28:32 2008 -0400 @@ -933,19 +933,32 @@ val pterror = (perror, terror) val rerror = (pterror, (env, bound)) - fun pcon (pc, po, to, dn, dk) = + fun pcon (pc, po, xs, to, dn, dk) = case (po, to) of (NONE, SOME _) => (expError env (PatHasNoArg loc); rerror) | (SOME _, NONE) => (expError env (PatHasArg loc); rerror) - | (NONE, NONE) => (((L'.PCon (dk, pc, NONE), loc), dn), - (env, bound)) + | (NONE, NONE) => + let + val k = (L'.KType, loc) + val unifs = map (fn _ => cunif (loc, k)) xs + val dn = foldl (fn (u, dn) => (L'.CApp (dn, u), loc)) dn unifs + in + (((L'.PCon (dk, pc, unifs, NONE), loc), dn), + (env, bound)) + end | (SOME p, SOME t) => let val ((p', pt), (env, bound)) = elabPat (p, (env, denv, bound)) + + val k = (L'.KType, loc) + val unifs = map (fn _ => cunif (loc, k)) xs + val t = ListUtil.foldli (fn (i, u, t) => subConInCon (i, u) t) t unifs + val dn = foldl (fn (u, dn) => (L'.CApp (dn, u), loc)) dn unifs in - (((L'.PCon (dk, pc, SOME p'), loc), dn), + ignore (checkPatCon (env, denv) p' pt t); + (((L'.PCon (dk, pc, unifs, SOME p'), loc), dn), (env, bound)) end in @@ -969,7 +982,7 @@ (case E.lookupConstructor env x of NONE => (expError env (UnboundConstructor (loc, [], x)); rerror) - | SOME (dk, n, to, dn) => pcon (L'.PConVar n, po, to, (L'.CNamed dn, loc), dk)) + | SOME (dk, n, xs, to, dn) => pcon (L'.PConVar n, po, xs, to, (L'.CNamed dn, loc), dk)) | L.PCon (m1 :: ms, x, po) => (case E.lookupStr env m1 of NONE => (expError env (UnboundStrInExp (loc, m1)); @@ -985,7 +998,7 @@ case E.projectConstructor env {str = str, sgn = sgn, field = x} of NONE => (expError env (UnboundConstructor (loc, m1 :: ms, x)); rerror) - | SOME (dk, _, to, dn) => pcon (L'.PConProj (n, ms, x), po, to, dn, dk) + | SOME (dk, _, xs, to, dn) => pcon (L'.PConProj (n, ms, x), po, xs, to, dn, dk) end) | L.PRecord (xps, flex) => @@ -1035,7 +1048,7 @@ in case E.projectConstructor env {str = str, sgn = sgn, field = x} of NONE => raise Fail "exhaustive: Can't project constructor" - | SOME (_, n, _, _) => n + | SOME (_, n, _, _, _) => n end fun coverage (p, _) = @@ -1043,8 +1056,8 @@ L'.PWild => Wild | L'.PVar _ => Wild | L'.PPrim _ => None - | L'.PCon (_, pc, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild)) - | L'.PCon (_, pc, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p)) + | L'.PCon (_, pc, _, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild)) + | L'.PCon (_, pc, _, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p)) | L'.PRecord xps => Record [foldl (fn ((x, p, _), fmap) => SM.insert (fmap, x, coverage p)) SM.empty xps] @@ -1158,8 +1171,13 @@ (total, gs' @ gs) end) (true, gs) cons + + fun unapp t = + case t of + L'.CApp ((t, _), _) => unapp t + | _ => t in - case t of + case unapp t of L'.CNamed n => let val dt = E.lookupDatatype env n @@ -1173,7 +1191,7 @@ in case E.projectDatatype env {str = str, sgn = sgn, field = x} of NONE => raise Fail "isTotal: Can't project datatype" - | SOME cons => dtype cons + | SOME (_, cons) => dtype cons end | L'.CError => (true, gs) | _ => raise Fail "isTotal: Not a datatype" @@ -1206,7 +1224,11 @@ (expError env (UnboundExp (loc, s)); (eerror, cerror, [])) | E.Rel (n, t) => ((L'.ERel n, loc), t, []) - | E.Named (n, t) => ((L'.ENamed n, loc), t, [])) + | E.Named (n, t) => + if Char.isUpper (String.sub (s, 0)) then + elabHead (env, denv) (L'.ENamed n, loc) t + else + ((L'.ENamed n, loc), t, [])) | L.EVar (m1 :: ms, s) => (case E.lookupStr env m1 of NONE => (expError env (UnboundStrInExp (loc, m1)); @@ -1572,11 +1594,13 @@ ([(L'.SgiCon (x, n, k', c'), loc)], (env', denv, gs' @ gs)) end - | L.SgiDatatype (x, xcs) => + | L.SgiDatatype (x, xs, xcs) => let val k = (L'.KType, loc) - val (env, n) = E.pushCNamed env x k NONE + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs + val (env, n) = E.pushCNamed env x k' NONE val t = (L'.CNamed n, loc) + val t = ListUtil.foldli (fn (i, _, t) => (L'.CApp (t, (L'.CRel i, loc)), loc)) t xs val (xcs, (used, env, gs)) = ListUtil.foldlMap @@ -1591,6 +1615,7 @@ checkKind env t' tk k; (SOME t', (L'.TFun (t', t), loc), gs' @ gs) end + val t = foldl (fn (x, t) => (L'.TCFun (L'.Implicit, x, k, t), loc)) t xs val (env, n') = E.pushENamed env x t in @@ -1601,8 +1626,10 @@ ((x, n', to), (SS.add (used, x), env, gs')) end) (SS.empty, env, []) xcs + + val env = E.pushDatatype env n xs xcs in - ([(L'.SgiDatatype (x, n, xcs), loc)], (env, denv, gs)) + ([(L'.SgiDatatype (x, n, xs, xcs), loc)], (env, denv, gs)) end | L.SgiDatatypeImp (_, [], _) => raise Fail "Empty SgiDatatypeImp" @@ -1625,12 +1652,14 @@ (case E.projectDatatype env {sgn = sgn, str = str, field = s} of NONE => (conError env (UnboundDatatype (loc, s)); ([], (env, denv, gs))) - | SOME xncs => + | SOME (xs, xncs) => let val k = (L'.KType, loc) + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs + val t = (L'.CModProj (n, ms, s), loc) - val (env, n') = E.pushCNamed env x k (SOME t) - val env = E.pushDatatype env n' xncs + val (env, n') = E.pushCNamed env x k' (SOME t) + val env = E.pushDatatype env n' xs xncs val t = (L'.CNamed n', loc) val env = foldl (fn ((x, n, to), env) => @@ -1638,11 +1667,15 @@ val t = case to of NONE => t | SOME t' => (L'.TFun (t', t), loc) + + val t = foldr (fn (x, t) => + (L'.TCFun (L'.Implicit, x, k, t), loc)) + t xs in E.pushENamedAs env x n t end) env xncs in - ([(L'.SgiDatatypeImp (x, n', n, ms, s, xncs), loc)], (env, denv, gs)) + ([(L'.SgiDatatypeImp (x, n', n, ms, s, xs, xncs), loc)], (env, denv, gs)) end) | _ => (strError env (NotDatatype loc); ([], (env, denv, []))) @@ -1720,7 +1753,7 @@ else (); (SS.add (cons, x), vals, sgns, strs)) - | L'.SgiDatatype (x, _, xncs) => + | L'.SgiDatatype (x, _, _, xncs) => let val vals = foldl (fn ((x, _, _), vals) => (if SS.member (vals, x) then @@ -1736,7 +1769,7 @@ (); (SS.add (cons, x), vals, sgns, strs) end - | L'.SgiDatatypeImp (x, _, _, _, _, _) => + | L'.SgiDatatypeImp (x, _, _, _, _, _, _) => (if SS.member (cons, x) then sgnError env (DuplicateCon (loc, x)) else @@ -1828,8 +1861,8 @@ | L'.SgnConst sgis => (L'.SgnConst (map (fn (L'.SgiConAbs (x, n, k), loc) => (L'.SgiCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc) - | (L'.SgiDatatype (x, n, xncs), loc) => - (L'.SgiDatatypeImp (x, n, str, strs, x, xncs), loc) + | (L'.SgiDatatype (x, n, xs, xncs), loc) => + (L'.SgiDatatypeImp (x, n, str, strs, x, xs, xncs), loc) | (L'.SgiStr (x, n, sgn), loc) => (L'.SgiStr (x, n, selfify env {str = str, strs = strs @ [x], sgn = sgn}), loc) | x => x) sgis), #2 sgn) @@ -1878,10 +1911,10 @@ end | L'.SgiCon (x, n, k, c) => (L'.DCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc) - | L'.SgiDatatype (x, n, xncs) => - (L'.DDatatypeImp (x, n, str, strs, x, xncs), loc) - | L'.SgiDatatypeImp (x, n, m1, ms, x', xncs) => - (L'.DDatatypeImp (x, n, m1, ms, x', xncs), loc) + | L'.SgiDatatype (x, n, xs, xncs) => + (L'.DDatatypeImp (x, n, str, strs, x, xs, xncs), loc) + | L'.SgiDatatypeImp (x, n, m1, ms, x', xs, xncs) => + (L'.DDatatypeImp (x, n, m1, ms, x', xs, xncs), loc) | L'.SgiVal (x, n, t) => (L'.DVal (x, n, t, (L'.EModProj (str, strs, x), loc)), loc) | L'.SgiStr (x, n, sgn) => @@ -1998,9 +2031,20 @@ case sgi1 of L'.SgiConAbs (x', n1, k1) => found (x', n1, k1, NONE) | L'.SgiCon (x', n1, k1, c1) => found (x', n1, k1, SOME c1) - | L'.SgiDatatype (x', n1, _) => found (x', n1, (L'.KType, loc), NONE) - | L'.SgiDatatypeImp (x', n1, m1, ms, s, _) => - found (x', n1, (L'.KType, loc), SOME (L'.CModProj (m1, ms, s), loc)) + | L'.SgiDatatype (x', n1, xs, _) => + let + val k = (L'.KType, loc) + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs + in + found (x', n1, k', NONE) + end + | L'.SgiDatatypeImp (x', n1, m1, ms, s, xs, _) => + let + val k = (L'.KType, loc) + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs + in + found (x', n1, k', SOME (L'.CModProj (m1, ms, s), loc)) + end | _ => NONE end) @@ -2023,15 +2067,18 @@ NONE | _ => NONE) - | L'.SgiDatatype (x, n2, xncs2) => + | L'.SgiDatatype (x, n2, xs2, xncs2) => seek (fn sgi1All as (sgi1, _) => let - fun found (n1, xncs1) = + fun found (n1, xs1, xncs1) = let fun mismatched ue = (sgnError env (SgiMismatchedDatatypes (sgi1All, sgi2All, ue)); SOME (env, denv)) + val k = (L'.KType, loc) + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs1 + fun good () = let val env = E.sgiBinds env sgi2All @@ -2044,6 +2091,7 @@ SOME (env, denv) end + val env = foldl (fn (x, env) => E.pushCRel env x k) env xs1 fun xncBad ((x1, _, t1), (x2, _, t2)) = String.compare (x1, x2) <> EQUAL orelse case (t1, t2) of @@ -2052,7 +2100,8 @@ not (List.null (unifyCons (env, denv) t1 t2)) | _ => true in - (if length xncs1 <> length xncs2 + (if xs1 <> xs2 + orelse length xncs1 <> length xncs2 orelse ListPair.exists xncBad (xncs1, xncs2) then mismatched NONE else @@ -2061,33 +2110,34 @@ end in case sgi1 of - L'.SgiDatatype (x', n1, xncs1) => + L'.SgiDatatype (x', n1, xs, xncs1) => if x' = x then - found (n1, xncs1) + found (n1, xs, xncs1) else NONE - | L'.SgiDatatypeImp (x', n1, _, _, _, xncs1) => + | L'.SgiDatatypeImp (x', n1, _, _, _, xs, xncs1) => if x' = x then - found (n1, xncs1) + found (n1, xs, xncs1) else NONE | _ => NONE end) - | L'.SgiDatatypeImp (x, n2, m12, ms2, s2, _) => + | L'.SgiDatatypeImp (x, n2, m12, ms2, s2, xs, _) => seek (fn sgi1All as (sgi1, _) => case sgi1 of - L'.SgiDatatypeImp (x', n1, m11, ms1, s1, _) => + L'.SgiDatatypeImp (x', n1, m11, ms1, s1, _, _) => if x = x' then let val k = (L'.KType, loc) + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs val t1 = (L'.CModProj (m11, ms1, s1), loc) val t2 = (L'.CModProj (m12, ms2, s2), loc) fun good () = let - val env = E.pushCNamedAs env x n1 k (SOME t1) - val env = E.pushCNamedAs env x n2 k (SOME t2) + val env = E.pushCNamedAs env x n1 k' (SOME t1) + val env = E.pushCNamedAs env x n2 k' (SOME t2) in SOME (env, denv) end @@ -2213,11 +2263,17 @@ ([(L'.DCon (x, n, k', c'), loc)], (env', denv, gs' @ gs)) end - | L.DDatatype (x, xcs) => + | L.DDatatype (x, xs, xcs) => let val k = (L'.KType, loc) - val (env, n) = E.pushCNamed env x k NONE + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs + val (env, n) = E.pushCNamed env x k' NONE val t = (L'.CNamed n, loc) + val t = ListUtil.foldli (fn (i, _, t) => (L'.CApp (t, (L'.CRel i, loc)), loc)) t xs + + val (env', denv') = foldl (fn (x, (env', denv')) => + (E.pushCRel env' x k, + D.enter denv')) (env, denv) xs val (xcs, (used, env, gs)) = ListUtil.foldlMap @@ -2227,11 +2283,12 @@ NONE => (NONE, t, gs) | SOME t' => let - val (t', tk, gs') = elabCon (env, denv) t' + val (t', tk, gs') = elabCon (env', denv') t' in - checkKind env t' tk k; + checkKind env' t' tk k; (SOME t', (L'.TFun (t', t), loc), gs' @ gs) end + val t = foldr (fn (x, t) => (L'.TCFun (L'.Implicit, x, k, t), loc)) t xs val (env, n') = E.pushENamed env x t in @@ -2243,9 +2300,9 @@ end) (SS.empty, env, []) xcs - val env = E.pushDatatype env n xcs + val env = E.pushDatatype env n xs xcs in - ([(L'.DDatatype (x, n, xcs), loc)], (env, denv, gs)) + ([(L'.DDatatype (x, n, xs, xcs), loc)], (env, denv, gs)) end | L.DDatatypeImp (_, [], _) => raise Fail "Empty DDatatypeImp" @@ -2268,12 +2325,13 @@ (case E.projectDatatype env {sgn = sgn, str = str, field = s} of NONE => (conError env (UnboundDatatype (loc, s)); ([], (env, denv, gs))) - | SOME xncs => + | SOME (xs, xncs) => let val k = (L'.KType, loc) + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs val t = (L'.CModProj (n, ms, s), loc) - val (env, n') = E.pushCNamed env x k (SOME t) - val env = E.pushDatatype env n' xncs + val (env, n') = E.pushCNamed env x k' (SOME t) + val env = E.pushDatatype env n' xs xncs val t = (L'.CNamed n', loc) val env = foldl (fn ((x, n, to), env) => @@ -2281,11 +2339,15 @@ val t = case to of NONE => t | SOME t' => (L'.TFun (t', t), loc) + + val t = foldr (fn (x, t) => + (L'.TCFun (L'.Implicit, x, k, t), loc)) + t xs in E.pushENamedAs env x n t end) env xncs in - ([(L'.DDatatypeImp (x, n', n, ms, s, xncs), loc)], (env, denv, gs)) + ([(L'.DDatatypeImp (x, n', n, ms, s, xs, xncs), loc)], (env, denv, gs)) end) | _ => (strError env (NotDatatype loc); ([], (env, denv, []))) @@ -2544,7 +2606,7 @@ in ((L'.SgiCon (x, n, k, c), loc) :: sgis, cons, vals, sgns, strs) end - | L'.SgiDatatype (x, n, xncs) => + | L'.SgiDatatype (x, n, xs, xncs) => let val (cons, x) = if SS.member (cons, x) then @@ -2561,9 +2623,9 @@ ((x, n, t), SS.add (vals, x))) vals xncs in - ((L'.SgiDatatype (x, n, xncs), loc) :: sgis, cons, vals, sgns, strs) + ((L'.SgiDatatype (x, n, xs, xncs), loc) :: sgis, cons, vals, sgns, strs) end - | L'.SgiDatatypeImp (x, n, m1, ms, x', xncs) => + | L'.SgiDatatypeImp (x, n, m1, ms, x', xs, xncs) => let val (cons, x) = if SS.member (cons, x) then @@ -2571,7 +2633,7 @@ else (SS.add (cons, x), x) in - ((L'.SgiDatatypeImp (x, n, m1, ms, x', xncs), loc) :: sgis, cons, vals, sgns, strs) + ((L'.SgiDatatypeImp (x, n, m1, ms, x', xs, xncs), loc) :: sgis, cons, vals, sgns, strs) end | L'.SgiVal (x, n, c) => let