(* Persistent union-find over keys of an ordered type.
 *
 * From a caller's point of view a [unionFind] value never changes: [union]
 * returns a new value and its argument remains usable.  Internally, lookups
 * perform path compression by rewriting the forest held in that value's own
 * ref cell; this only shortens pointer chains and never changes the answers
 * that value gives.  Every [union] allocates a fresh ref cell for its result. *)
functor UnionFindFn(K : ORD_KEY) :> sig
    type unionFind

    (* The structure in which no keys have been unified. *)
    val empty : unionFind

    (* [union (uf, x, y)] merges the classes of [x] and [y].  A key never seen
     * before is treated as a singleton class. *)
    val union : unionFind * K.ord_key * K.ord_key -> unionFind

    (* Argument-swapped variant of [union], shaped for [foldl]/[foldr]:
     * [union' ((x, y), uf) = union (uf, x, y)]. *)
    val union' : (K.ord_key * K.ord_key) * unionFind -> unionFind

    (* Every class that some [union] call has produced, each as a list of keys
     * in increasing [K.compare] order, the classes themselves ordered by
     * their representative.  Keys never passed to [union] do not appear;
     * [union (uf, x, x)] on a fresh [x] does record the singleton [[x]]. *)
    val classes : unionFind -> K.ord_key list list
end = struct

    structure M = BinaryMapFn(K)
    structure S = BinarySetFn(K)

    (* A node of the union-find forest: a representative (root) stores its
     * whole class; every other recorded key points towards its root. *)
    datatype entry =
        Set of S.set
      | Pointer of K.ord_key

    (* First component: the union-find forest (mutated only by path
     * compression).  Second component: each representative mapped to its
     * class.  Invariant: a key has a [Set s] entry in the forest exactly when
     * it is a key of the class map, bound to the same [s]. *)
    type unionFind = entry M.map ref * S.set M.map

    val empty : unionFind = (ref M.empty, M.empty)

    (* [findPair (uf, x)] returns [(class of x, representative of x)].
     * A key absent from the forest is its own singleton class.  Every key on
     * the pointer path from [x] is re-pointed directly at the representative
     * (path compression), which is why [uf] is a ref. *)
    fun findPair (uf, x) =
        case M.find (!uf, x) of
            NONE => (S.singleton x, x)
          | SOME (Set set) => (set, x)
          | SOME (Pointer parent) =>
            let
                val (set, rep) = findPair (uf, parent)
            in
                uf := M.insert (!uf, x, Pointer rep);
                (set, rep)
            end

    fun classes (_, cs) = (map S.listItems o M.listItems) cs

    fun union ((uf, cs), x, y) =
        let
            val (xSet, xRep) = findPair (uf, x)
            val (ySet, yRep) = findPair (uf, y)
            val xySet = S.union (xSet, ySet)

            (* Hang y's root under x's root.  The [Set] entry is inserted
             * last so that it wins when xRep and yRep are the same key. *)
            val forest = M.insert (M.insert (!uf, yRep, Pointer xRep),
                                   xRep, Set xySet)

            (* yRep is no longer a representative.  [M.remove] raises
             * [LibBase.NotFound] for absent keys, so look it up first. *)
            val csWithoutY =
                case M.find (cs, yRep) of
                    NONE => cs
                  | SOME _ => #1 (M.remove (cs, yRep))
        in
            (ref forest, M.insert (csWithoutY, xRep, xySet))
        end

    fun union' ((x, y), uf) = union (uf, x, y)

end