# Repository web-view metadata: 46 lines, 1.5 KiB, Python.
# Timestamp: 2020-03-17 17:13:40 +05:30
from collections import defaultdict
from typing import Generic, TypeVar, DefaultDict, Dict
T = TypeVar('T')


class DisjointSet(Generic[T]):
    """Generic Disjoint Set (union-find) implementation.

    Items are added lazily: any hashable value that has never been merged is
    treated as its own singleton component of size 1.
    Uses path compression in ``find_root`` and union by size in
    ``merge_components``.
    """

    def __init__(self) -> None:
        # Maps a non-root item to its parent; items absent from this dict are
        # roots of their own component.
        self._parents: Dict[T, T] = {}
        # Component sizes keyed by root; a root missing from the dict has size 1.
        self._sizes: DefaultDict[T, int] = defaultdict(lambda: 1)

    def find_root(self, x: T) -> T:
        """
        Find the root (representative) of the component that x belongs to.

        Uses path compression: every item on the path from x to the root,
        including x itself, is re-pointed directly at the root.
        Amortized time complexity per query: ~O(1)
        """
        # First pass: walk up to the root. A loop is used instead of recursion
        # because of Python's recursion limit on long parent chains.
        root = x
        parent = self._parents.get(root, root)
        while parent != root:
            root = parent
            parent = self._parents.get(root, root)

        # Second pass: re-point each item on the path at the root. The parent
        # is read before overwriting so that x itself is compressed too.
        while x != root:
            parent = self._parents[x]
            self._parents[x] = root
            x = parent
        return root

    def size_of(self, x: T) -> int:
        """
        Return the size of the component that x belongs to.

        Querying an item that was never merged returns 1 without storing it.
        Time complexity per query: amortized ~O(1)
        """
        return self._sizes.get(self.find_root(x), 1)

    def merge_components(self, x: T, y: T) -> None:
        """
        Merge the components that the items x and y belong to (union by size).

        Merging two items that already share a component is a no-op.
        Amortized time complexity per query: ~O(1)
        """
        x, y = self.find_root(x), self.find_root(y)
        if x == y:
            # Already one component. Without this guard the size would be
            # doubled and then deleted, silently resetting it to 1.
            return
        if self._sizes[x] < self._sizes[y]:
            x, y = y, x  # swap so that x is always the larger component
        # Attach the smaller tree (rooted at y) under the larger root x; y is
        # no longer a root, so its size entry is removed.
        self._parents[y] = x
        self._sizes[x] += self._sizes.pop(y, 1)