목차

(old) disjointset.py

DisjointSet v1.0

코드

# N DisjointSet
# I {"version": "1.0", "typing": ["TypeVar"], "const": ["T"]}
class DisjointSet:
    """Disjoint-set (union-find) forest with path halving and union-by-size.

    Elements are registered lazily: the first time an element is passed to
    ``find``, ``union`` or ``size`` it becomes a singleton set. Elements are
    used as dictionary keys, so they must be hashable.
    """

    def __init__(self):
        # Maps every known element to its parent; a root maps to itself.
        self._parent = {}
        # Maps each root to the number of elements in its set. Entries for
        # elements that stopped being roots are left stale and never read.
        self._size = {}

    def union(self, x: T, y: T) -> T:
        """Merge the sets containing *x* and *y* and return the resulting root.

        The root of the larger set becomes the root of the merged set; on a
        tie, the root of *x*'s set is kept. If *x* and *y* are already in the
        same set nothing changes and that set's root is returned. Unseen
        elements are registered as singletons first (via ``find``).
        """
        root_x, root_y = self.find(x), self.find(y)
        if root_x != root_y:
            # Union by size: hang the smaller tree under the larger one so
            # tree height stays logarithmic.
            if self._size[root_x] < self._size[root_y]:
                root_x, root_y = root_y, root_x
            self._parent[root_y] = root_x
            self._size[root_x] += self._size[root_y]
        return root_x

    def find(self, x: T) -> T:
        """Return the root (representative) of the set containing *x*.

        An element not seen before is registered as a singleton set and is
        returned as its own root. While walking up, each visited node is
        re-pointed to its grandparent (path halving), flattening the tree for
        later lookups.
        """
        parent = self._parent
        if x not in parent:
            # First sighting: make x a singleton set.
            parent[x] = x
            self._size[x] = 1
            return x
        while (p := parent[x]) != x:
            # Path halving. Keep these as separate statements: a combined
            # ``x, parent[x] = p, parent[p]`` rebinds ``x`` before the
            # ``parent[x]`` target is evaluated, turning the update into the
            # no-op ``parent[p] = parent[p]``.
            grandparent = parent[p]
            parent[x] = grandparent
            x = grandparent
        return x

    def size(self, x: T) -> int:
        """Return the number of elements in the set containing *x*.

        An unseen element is registered as a singleton and reports size 1.
        """
        return self._size[self.find(x)]