Mercurial > urweb
diff src/elaborate.sml @ 805:e2780d2f4afc
Mutual datatypes through Elaborate
author | Adam Chlipala <adamc@hcoop.net> |
---|---|
date | Sat, 16 May 2009 15:14:17 -0400 |
parents | 9330ba3a2799 |
children | cb30dd2ba353 |
line wrap: on
line diff
--- a/src/elaborate.sml Sat May 16 13:10:52 2009 -0400 +++ b/src/elaborate.sml Sat May 16 15:14:17 2009 -0400 @@ -1971,47 +1971,65 @@ ([(L'.SgiCon (x, n, k', c'), loc)], (env', denv, gs' @ gs)) end - | L.SgiDatatype (x, xs, xcs) => + | L.SgiDatatype dts => let val k = (L'.KType, loc) - val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs - val (env, n) = E.pushCNamed env x k' NONE - val t = (L'.CNamed n, loc) - val nxs = length xs - 1 - val t = ListUtil.foldli (fn (i, _, t) => (L'.CApp (t, (L'.CRel (nxs - i), loc)), loc)) t xs - - val (env', denv') = foldl (fn (x, (env', denv')) => - (E.pushCRel env' x k, - D.enter denv')) (env, denv) xs - - val (xcs, (used, env, gs)) = - ListUtil.foldlMap - (fn ((x, to), (used, env, gs)) => - let - val (to, t, gs') = case to of - NONE => (NONE, t, gs) - | SOME t' => - let - val (t', tk, gs') = elabCon (env', denv') t' - in - checkKind env' t' tk k; - (SOME t', (L'.TFun (t', t), loc), gs' @ gs) - end - val t = foldl (fn (x, t) => (L'.TCFun (L'.Implicit, x, k, t), loc)) t xs - - val (env, n') = E.pushENamed env x t - in - if SS.member (used, x) then - strError env (DuplicateConstructor (x, loc)) - else - (); - ((x, n', to), (SS.add (used, x), env, gs')) - end) - (SS.empty, env, []) xcs - - val env = E.pushDatatype env n xs xcs + + val (dts, env) = ListUtil.foldlMap (fn ((x, xs, xcs), env) => + let + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs + val (env, n) = E.pushCNamed env x k' NONE + in + ((x, n, xs, xcs), env) + end) + env dts + + val (dts, env) = ListUtil.foldlMap + (fn ((x, n, xs, xcs), env) => + let + val t = (L'.CNamed n, loc) + val nxs = length xs - 1 + val t = ListUtil.foldli (fn (i, _, t) => + (L'.CApp (t, (L'.CRel (nxs - i), loc)), loc)) t xs + + val (env', denv') = foldl (fn (x, (env', denv')) => + (E.pushCRel env' x k, + D.enter denv')) (env, denv) xs + + val (xcs, (used, env, gs)) = + ListUtil.foldlMap + (fn ((x, to), (used, env, gs)) => + let + val (to, t, gs') = case to of + NONE => (NONE, t, gs) + | SOME t' => + let + val (t', tk, gs') = + elabCon (env', denv') t' + in + checkKind env' t' tk k; + (SOME t', + (L'.TFun (t', t), loc), + gs' @ gs) + end + val t = foldl (fn (x, t) => (L'.TCFun (L'.Implicit, x, k, t), loc)) + t xs + + val (env, n') = E.pushENamed env x t + in + if SS.member (used, x) then + strError env (DuplicateConstructor (x, loc)) + else + (); + ((x, n', to), (SS.add (used, x), env, gs')) + end) + (SS.empty, env, []) xcs + in + ((x, n, xs, xcs), E.pushDatatype env n xs xcs) + end) + env dts in - ([(L'.SgiDatatype (x, n, xs, xcs), loc)], (env, denv, gs)) + ([(L'.SgiDatatype dts, loc)], (env, denv, gs)) end | L.SgiDatatypeImp (_, [], _) => raise Fail "Empty SgiDatatypeImp" @@ -2199,21 +2217,31 @@ else (); (SS.add (cons, x), vals, sgns, strs)) - | L'.SgiDatatype (x, _, _, xncs) => + | L'.SgiDatatype dts => let - val vals = foldl (fn ((x, _, _), vals) => - (if SS.member (vals, x) then - sgnError env (DuplicateVal (loc, x)) - else - (); - SS.add (vals, x))) - vals xncs + val (cons, vals) = + let + fun doOne ((x, _, _, xncs), (cons, vals)) = + let + val vals = foldl (fn ((x, _, _), vals) => + (if SS.member (vals, x) then + sgnError env (DuplicateVal (loc, x)) + else + (); + SS.add (vals, x))) + vals xncs + in + if SS.member (cons, x) then + sgnError env (DuplicateCon (loc, x)) + else + (); + (SS.add (cons, x), vals) + end + in + foldl doOne (cons, vals) dts + end in - if SS.member (cons, x) then - sgnError env (DuplicateCon (loc, x)) - else - (); - (SS.add (cons, x), vals, sgns, strs) + (cons, vals, sgns, strs) end | L'.SgiDatatypeImp (x, _, _, _, _, _, _) => (if SS.member (cons, x) then @@ -2318,15 +2346,15 @@ | L'.SgnVar _ => sgn | L'.SgnConst sgis => - (L'.SgnConst (map (fn (L'.SgiConAbs (x, n, k), loc) => - (L'.SgiCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc) - | (L'.SgiDatatype (x, n, xs, xncs), loc) => - (L'.SgiDatatypeImp (x, n, str, strs, x, xs, xncs), loc) + (L'.SgnConst (ListUtil.mapConcat (fn (L'.SgiConAbs (x, n, k), loc) => + [(L'.SgiCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc)] + | (L'.SgiDatatype dts, loc) => + map (fn (x, n, xs, xncs) => (L'.SgiDatatypeImp (x, n, str, strs, x, xs, xncs), loc)) dts | (L'.SgiClassAbs (x, n, k), loc) => - (L'.SgiClass (x, n, k, (L'.CModProj (str, strs, x), loc)), loc) + [(L'.SgiClass (x, n, k, (L'.CModProj (str, strs, x), loc)), loc)] | (L'.SgiStr (x, n, sgn), loc) => - (L'.SgiStr (x, n, selfify env {str = str, strs = strs @ [x], sgn = sgn}), loc) - | x => x) sgis), #2 sgn) + [(L'.SgiStr (x, n, selfify env {str = str, strs = strs @ [x], sgn = sgn}), loc)] + | x => [x]) sgis), #2 sgn) | L'.SgnFun _ => sgn | L'.SgnWhere _ => sgn | L'.SgnProj (m, ms, x) => @@ -2360,46 +2388,47 @@ in case #1 (hnormSgn env sgn) of L'.SgnConst sgis => - ListUtil.foldlMap (fn ((sgi, loc), env') => - let - val d = - case sgi of - L'.SgiConAbs (x, n, k) => - let - val c = (L'.CModProj (str, strs, x), loc) - in - (L'.DCon (x, n, k, c), loc) - end - | L'.SgiCon (x, n, k, c) => - (L'.DCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc) - | L'.SgiDatatype (x, n, xs, xncs) => - (L'.DDatatypeImp (x, n, str, strs, x, xs, xncs), loc) - | L'.SgiDatatypeImp (x, n, m1, ms, x', xs, xncs) => - (L'.DDatatypeImp (x, n, m1, ms, x', xs, xncs), loc) - | L'.SgiVal (x, n, t) => - (L'.DVal (x, n, t, (L'.EModProj (str, strs, x), loc)), loc) - | L'.SgiStr (x, n, sgn) => - (L'.DStr (x, n, sgn, (L'.StrProj (m, x), loc)), loc) - | L'.SgiSgn (x, n, sgn) => - (L'.DSgn (x, n, (L'.SgnProj (str, strs, x), loc)), loc) - | L'.SgiConstraint (c1, c2) => - (L'.DConstraint (c1, c2), loc) - | L'.SgiClassAbs (x, n, k) => - let - val c = (L'.CModProj (str, strs, x), loc) - in - (L'.DCon (x, n, k, c), loc) - end - | L'.SgiClass (x, n, k, _) => - let - val c = (L'.CModProj (str, strs, x), loc) - in - (L'.DCon (x, n, k, c), loc) - end - in - (d, E.declBinds env' d) - end) - env sgis + ListUtil.foldlMapConcat + (fn ((sgi, loc), env') => + let + val d = + case sgi of + L'.SgiConAbs (x, n, k) => + let + val c = (L'.CModProj (str, strs, x), loc) + in + [(L'.DCon (x, n, k, c), loc)] + end + | L'.SgiCon (x, n, k, c) => + [(L'.DCon (x, n, k, (L'.CModProj (str, strs, x), loc)), loc)] + | L'.SgiDatatype dts => + map (fn (x, n, xs, xncs) => (L'.DDatatypeImp (x, n, str, strs, x, xs, xncs), loc)) dts + | L'.SgiDatatypeImp (x, n, m1, ms, x', xs, xncs) => + [(L'.DDatatypeImp (x, n, m1, ms, x', xs, xncs), loc)] + | L'.SgiVal (x, n, t) => + [(L'.DVal (x, n, t, (L'.EModProj (str, strs, x), loc)), loc)] + | L'.SgiStr (x, n, sgn) => + [(L'.DStr (x, n, sgn, (L'.StrProj (m, x), loc)), loc)] + | L'.SgiSgn (x, n, sgn) => + [(L'.DSgn (x, n, (L'.SgnProj (str, strs, x), loc)), loc)] + | L'.SgiConstraint (c1, c2) => + [(L'.DConstraint (c1, c2), loc)] + | L'.SgiClassAbs (x, n, k) => + let + val c = (L'.CModProj (str, strs, x), loc) + in + [(L'.DCon (x, n, k, c), loc)] + end + | L'.SgiClass (x, n, k, _) => + let + val c = (L'.CModProj (str, strs, x), loc) + in + [(L'.DCon (x, n, k, c), loc)] + end + in + (d, foldl (fn (d, env') => E.declBinds env' d) env' d) + end) + env sgis | _ => (strError env (UnOpenable sgn); ([], env)) end @@ -2445,12 +2474,11 @@ let (*val () = prefaces "folder" [("sgis1", p_sgn env (L'.SgnConst sgis1, loc2))]*) - fun seek p = + fun seek' f p = let fun seek env ls = case ls of - [] => (sgnError env (UnmatchedSgi sgi2All); - env) + [] => f env | h :: t => case p (env, h) of NONE => @@ -2474,6 +2502,9 @@ in seek env sgis1 end + + val seek = seek' (fn env => (sgnError env (UnmatchedSgi sgi2All); + env)) in case sgi of L'.SgiConAbs (x, n2, k2) => @@ -2498,12 +2529,23 @@ case sgi1 of L'.SgiConAbs (x', n1, k1) => found (x', n1, k1, NONE) | L'.SgiCon (x', n1, k1, c1) => found (x', n1, k1, SOME c1) - | L'.SgiDatatype (x', n1, xs, _) => + | L'.SgiDatatype dts => let val k = (L'.KType, loc) - val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs + + fun search dts = + case dts of + [] => NONE + | (x', n1, xs, _) :: dts => + let + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs + in + case found (x', n1, k', NONE) of + NONE => search dts + | x => x + end in - found (x', n1, k', NONE) + search dts end | L'.SgiDatatypeImp (x', n1, m1, ms, s, xs, _) => let @@ -2549,66 +2591,93 @@ | _ => NONE end) - | L'.SgiDatatype (x, n2, xs2, xncs2) => - seek (fn (env, sgi1All as (sgi1, _)) => - let - fun found (n1, xs1, xncs1) = - let - fun mismatched ue = - (sgnError env (SgiMismatchedDatatypes (sgi1All, sgi2All, ue)); - SOME env) - - val k = (L'.KType, loc) - val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs1 - - fun good () = - let - val env = E.sgiBinds env sgi1All - val env = if n1 = n2 then - env - else - E.pushCNamedAs env x n2 k' - (SOME (L'.CNamed n1, loc)) - in - SOME env - end - - val env = E.pushCNamedAs env x n1 k' NONE - val env = if n1 = n2 then - env - else - E.pushCNamedAs env x n2 k' (SOME (L'.CNamed n1, loc)) - val env = foldl (fn (x, env) => E.pushCRel env x k) env xs1 - fun xncBad ((x1, _, t1), (x2, _, t2)) = - String.compare (x1, x2) <> EQUAL - orelse case (t1, t2) of - (NONE, NONE) => false - | (SOME t1, SOME t2) => - (unifyCons env t1 t2; false) - | _ => true - in - (if xs1 <> xs2 - orelse length xncs1 <> length xncs2 - orelse ListPair.exists xncBad (xncs1, xncs2) then - mismatched NONE - else - good ()) - handle CUnify ue => mismatched (SOME ue) - end - in - case sgi1 of - L'.SgiDatatype (x', n1, xs, xncs1) => - if x' = x then - found (n1, xs, xncs1) + | L'.SgiDatatype dts2 => + let + fun found' (sgi1All, (x1, n1, xs1, xncs1), (x2, n2, xs2, xncs2), env) = + if x1 <> x2 then + NONE + else + let + fun mismatched ue = + (sgnError env (SgiMismatchedDatatypes (sgi1All, sgi2All, ue)); + SOME env) + + val k = (L'.KType, loc) + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs1 + + fun good () = + let + val env = E.sgiBinds env sgi1All + val env = if n1 = n2 then + env + else + E.pushCNamedAs env x1 n2 k' + (SOME (L'.CNamed n1, loc)) + in + SOME env + end + + val env = E.pushCNamedAs env x1 n1 k' NONE + val env = if n1 = n2 then + env + else + E.pushCNamedAs env x1 n2 k' (SOME (L'.CNamed n1, loc)) + val env = foldl (fn (x, env) => E.pushCRel env x k) env xs1 + fun xncBad ((x1, _, t1), (x2, _, t2)) = + String.compare (x1, x2) <> EQUAL + orelse case (t1, t2) of + (NONE, NONE) => false + | (SOME t1, SOME t2) => + (unifyCons env t1 t2; false) + | _ => true + in + (if xs1 <> xs2 + orelse length xncs1 <> length xncs2 + orelse ListPair.exists xncBad (xncs1, xncs2) then + mismatched NONE else - NONE - | L'.SgiDatatypeImp (x', n1, _, _, _, xs, xncs1) => - if x' = x then - found (n1, xs, xncs1) - else - NONE - | _ => NONE - end) + good ()) + handle CUnify ue => mismatched (SOME ue) + end + in + seek' + (fn _ => + let + fun seekOne (dt2, env) = + seek (fn (env, sgi1All as (sgi1, _)) => + case sgi1 of + L'.SgiDatatypeImp (x', n1, _, _, _, xs, xncs1) => + found' (sgi1All, (x', n1, xs, xncs1), dt2, env) + | _ => NONE) + + fun seekAll (dts, env) = + case dts of + [] => env + | dt :: dts => seekAll (dts, seekOne (dt, env)) + in + seekAll (dts2, env) + end) + (fn (env, sgi1All as (sgi1, _)) => + let + fun found dts1 = + let + fun iter (dts1, dts2, env) = + case (dts1, dts2) of + ([], []) => SOME env + | (dt1 :: dts1, dt2 :: dts2) => + (case found' (sgi1All, dt1, dt2, env) of + NONE => NONE + | SOME env => iter (dts1, dts2, env)) + | _ => NONE + in + iter (dts1, dts2, env) + end + in + case sgi1 of + L'.SgiDatatype dts1 => found dts1 + | _ => NONE + end) + end | L'.SgiDatatypeImp (x, n2, m12, ms2, s2, xs, _) => seek (fn (env, sgi1All as (sgi1, _)) => @@ -3033,58 +3102,63 @@ ([(L'.DCon (x, n, k', c'), loc)], (env', denv, enD gs' @ gs)) end - | L.DDatatype (x, xs, xcs) => + | L.DDatatype dts => let - val positive = List.all (fn (_, to) => - case to of - NONE => true - | SOME t => positive x t) xcs - val k = (L'.KType, loc) - val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs - val (env, n) = E.pushCNamed env x k' NONE - val t = (L'.CNamed n, loc) - val nxs = length xs - 1 - val t = ListUtil.foldli (fn (i, _, t) => (L'.CApp (t, (L'.CRel (nxs - i), loc)), loc)) t xs - - val (env', denv') = foldl (fn (x, (env', denv')) => - (E.pushCRel env' x k, - D.enter denv')) (env, denv) xs - - val (xcs, (used, env, gs')) = - ListUtil.foldlMap - (fn ((x, to), (used, env, gs)) => - let - val (to, t, gs') = case to of - NONE => (NONE, t, gs) - | SOME t' => - let - val (t', tk, gs') = elabCon (env', denv') t' - in - checkKind env' t' tk k; - (SOME t', (L'.TFun (t', t), loc), enD gs' @ gs) - end - val t = foldr (fn (x, t) => (L'.TCFun (L'.Implicit, x, k, t), loc)) t xs - - val (env, n') = E.pushENamed env x t - in - if SS.member (used, x) then - strError env (DuplicateConstructor (x, loc)) - else - (); - ((x, n', to), (SS.add (used, x), env, gs')) - end) - (SS.empty, env, []) xcs - - val env = E.pushDatatype env n xs xcs - val d' = (L'.DDatatype (x, n, xs, xcs), loc) + + val (dts, env) = ListUtil.foldlMap + (fn ((x, xs, xcs), env) => + let + val k' = foldl (fn (_, k') => (L'.KArrow (k, k'), loc)) k xs + val (env, n) = E.pushCNamed env x k' NONE + in + ((x, n, xs, xcs), env) + end) + env dts + + val (dts, (env, gs')) = ListUtil.foldlMap + (fn ((x, n, xs, xcs), (env, gs')) => + let + val t = (L'.CNamed n, loc) + val nxs = length xs - 1 + val t = ListUtil.foldli + (fn (i, _, t) => + (L'.CApp (t, (L'.CRel (nxs - i), loc)), loc)) t xs + + val (env', denv') = foldl (fn (x, (env', denv')) => + (E.pushCRel env' x k, + D.enter denv')) (env, denv) xs + + val (xcs, (used, env, gs')) = + ListUtil.foldlMap + (fn ((x, to), (used, env, gs)) => + let + val (to, t, gs') = case to of + NONE => (NONE, t, gs) + | SOME t' => + let + val (t', tk, gs') = elabCon (env', denv') t' + in + checkKind env' t' tk k; + (SOME t', (L'.TFun (t', t), loc), enD gs' @ gs) + end + val t = foldr (fn (x, t) => (L'.TCFun (L'.Implicit, x, k, t), loc)) t xs + + val (env, n') = E.pushENamed env x t + in + if SS.member (used, x) then + strError env (DuplicateConstructor (x, loc)) + else + (); + ((x, n', to), (SS.add (used, x), env, gs')) + end) + (SS.empty, env, gs') xcs + in + ((x, n, xs, xcs), (E.pushDatatype env n xs xcs, gs')) + end) + (env, []) dts in - (*if positive then - () - else - declError env (Nonpositive d');*) - - ([d'], (env, denv, gs' @ gs)) + ([(L'.DDatatype dts, loc)], (env, denv, gs' @ gs)) end | L.DDatatypeImp (_, [], _) => raise Fail "Empty DDatatypeImp" @@ -3484,24 +3558,31 @@ in ((L'.SgiCon (x, n, k, c), loc) :: sgis, cons, vals, sgns, strs) end - | L'.SgiDatatype (x, n, xs, xncs) => + | L'.SgiDatatype dts => let - val (cons, x) = - if SS.member (cons, x) then - (cons, "?" ^ x) - else - (SS.add (cons, x), x) - - val (xncs, vals) = - ListUtil.foldlMap - (fn ((x, n, t), vals) => - if SS.member (vals, x) then - (("?" ^ x, n, t), vals) + fun doOne ((x, n, xs, xncs), (cons, vals)) = + let + val (cons, x) = + if SS.member (cons, x) then + (cons, "?" ^ x) else - ((x, n, t), SS.add (vals, x))) - vals xncs + (SS.add (cons, x), x) + + val (xncs, vals) = + ListUtil.foldlMap + (fn ((x, n, t), vals) => + if SS.member (vals, x) then + (("?" ^ x, n, t), vals) + else + ((x, n, t), SS.add (vals, x))) + vals xncs + in + ((x, n, xs, xncs), (cons, vals)) + end + + val (dts, (cons, vals)) = ListUtil.foldlMap doOne (cons, vals) dts in - ((L'.SgiDatatype (x, n, xs, xncs), loc) :: sgis, cons, vals, sgns, strs) + ((L'.SgiDatatype dts, loc) :: sgis, cons, vals, sgns, strs) end | L'.SgiDatatypeImp (x, n, m1, ms, x', xs, xncs) => let