def test_len():
    """``len`` reports the element count and grows by one after ``add``."""
    size = 10
    dis = DisjointSet(get_elements(size))
    assert len(dis) == size

    dis.add("dummy")
    assert len(dis) == size + 1
def test_self_unions(n):
    """Merging an element with itself is a no-op.

    ``merge(x, x)`` must report that nothing changed. Afterwards every
    element is still its own singleton root, and iteration order is the
    insertion order.
    """
    elements = get_elements(n)
    dis = DisjointSet(elements)

    for element in elements:
        assert dis.connected(element, element)
        assert not dis.merge(element, element)
        assert dis.connected(element, element)
    assert dis.n_subsets == len(elements)

    assert elements == list(dis)
    assert elements == [dis[element] for element in elements]
class Bench(Benchmark):
    """Benchmarks for ``DisjointSet`` merges, root lookups and membership.

    Parametrized by ``n``, the number of random edges drawn in ``setup``.
    Edge endpoints are sampled from ``[0, 10 * n)``.
    """

    params = [[100, 1000, 10000]]
    param_names = ['n']

    def setup(self, n):
        """Build a seeded random edge list and three ``DisjointSet`` fixtures."""
        # Fixed seed so every run benchmarks the same random graph.
        rng = np.random.RandomState(seed=0)
        edges = rng.randint(0, 10 * n, (n, 2))
        nodes = np.unique(edges)
        self.edges = edges
        self.nodes = nodes

        # All singletons: time_merge starts from scratch.
        self.disjoint_set = DisjointSet(nodes)

        def merged_structure():
            # A fresh DisjointSet over `nodes` with every edge merged in.
            structure = DisjointSet(nodes)
            for left, right in edges:
                structure.merge(left, right)
            return structure

        self.pre_merged = merged_structure()

        # Same merges, followed by a root lookup on every node. This
        # presumably lets any path shortening done by lookups happen before
        # timing starts.
        self.pre_merged_found = merged_structure()
        for node in nodes:
            self.pre_merged_found[node]

    def time_merge(self, n):
        """Merge every edge into an all-singletons structure."""
        target = self.disjoint_set
        for left, right in self.edges:
            target.merge(left, right)

    def time_merge_already_merged(self, n):
        """Re-merge edges whose endpoints are already connected."""
        target = self.pre_merged
        for left, right in self.edges:
            target.merge(left, right)

    def time_find(self, n):
        """Look up the root of every node after merging."""
        target = self.pre_merged
        return [target[node] for node in self.nodes]

    def time_find_already_found(self, n):
        """Look up roots again on a structure that was already queried once."""
        target = self.pre_merged_found
        return [target[node] for node in self.nodes]

    def time_contains(self, n):
        """Membership tests for nodes that are present."""
        target = self.pre_merged
        assert self.nodes[0] in target
        assert self.nodes[n // 2] in target
        assert self.nodes[-1] in target

    def time_absence(self, n):
        """Membership tests for keys that were never added."""
        target = self.pre_merged
        for missing in (None, "dummy", (1, 2, 3)):
            assert missing not in target
def test_contains(n):
    """Every inserted element is a member; a never-added key is not."""
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert all(element in dis for element in elements)
    assert "dummy" not in dis
def setup(self, n): # Create random edges rng = np.random.RandomState(seed=0) self.edges = rng.randint(0, 10 * n, (n, 2)) self.nodes = np.unique(self.edges) self.disjoint_set = DisjointSet(self.nodes) self.pre_merged = DisjointSet(self.nodes) for a, b in self.edges: self.pre_merged.merge(a, b) self.pre_merged_found = DisjointSet(self.nodes) for a, b in self.edges: self.pre_merged_found.merge(a, b) for x in self.nodes: self.pre_merged_found[x]
def test_equal_size_ordering(n, order):
    """When two equal-size subsets merge, the earlier-inserted element wins.

    Elements are paired up in a seeded random order and each pair is
    merged either as ``(a, b)`` or ``(b, a)``. In both cases the resulting
    root must be the pair member with the smaller index in ``elements``.
    """
    elements = get_elements(n)
    dis = DisjointSet(elements)

    rng = np.random.RandomState(seed=0)
    perm = np.arange(n)
    rng.shuffle(perm)

    for start in range(0, len(perm), 2):
        idx_a, idx_b = perm[start], perm[start + 1]
        a, b = elements[idx_a], elements[idx_b]

        merged = dis.merge(a, b) if order == "ab" else dis.merge(b, a)
        assert merged

        root = elements[min(idx_a, idx_b)]
        assert dis[a] == root
        assert dis[b] == root
def test_binary_tree(kmax):
    """Join blocks of doubling width and check every element's root.

    At level ``k`` (``k = 1, 2, 4, ...``), each run of ``2 * k`` consecutive
    elements is joined by merging a random member of its left half with a
    random member of its right half. After each level, every element's
    root must be the first element of its run.
    """
    n = 2**kmax
    elements = get_elements(n)
    dis = DisjointSet(elements)
    rng = np.random.RandomState(seed=0)

    for k in 2**np.arange(kmax):
        width = 2 * k
        for lo in range(0, n, width):
            off_left, off_right = rng.randint(0, k, size=2)
            a = elements[lo + off_left]
            b = elements[lo + k + off_right]
            assert not dis.connected(a, b)
            assert dis.merge(a, b)
            assert dis.connected(a, b)

        assert elements == list(dis)
        roots = [dis[element] for element in elements]
        run_starts = np.arange(n) - np.arange(n) % width
        assert roots == [elements[s] for s in run_starts]
def day9b(s):
    """Return the product of the sizes of the three largest basins in a height map.

    The name and logic suggest this is Advent of Code 2021, day 9, part 2.

    Every location with height below 9 starts as its own subset of a
    ``DisjointSet``. Each one is then merged with its lowest orthogonal
    neighbour whenever that neighbour is strictly lower, so locations end
    up grouped with whatever they drain into.

    Args:
        s: Height-map text, one row of single-digit heights per line.
            Blank lines and surrounding whitespace are ignored.

    Returns:
        The product of the three largest basin sizes, as an ``int``.

    Raises:
        ValueError: if the map contains fewer than three basins.
    """
    from scipy.cluster.hierarchy import DisjointSet

    # Drop blank lines so a trailing empty line cannot make the array ragged.
    rows = [line.strip() for line in s.splitlines() if line.strip()]
    # atleast_2d keeps the (rows, cols) unpacking valid when there are no rows.
    grid = np.atleast_2d(
        np.array([[int(ch) for ch in row] for row in rows], dtype=int)
    )
    ymax, xmax = grid.shape

    loc = [(y, x) for y in range(ymax) for x in range(xmax) if grid[y, x] < 9]
    basins = DisjointSet(loc)
    for y, x in loc:
        # Only in-bounds neighbours below 9 can belong to a basin. Testing
        # grid[y, x] here instead would be vacuous, since every cell in
        # `loc` is already < 9.
        neighbors = [
            (grid[ny, nx], ny, nx)
            for ny, nx in ((y - 1, x), (y, x + 1), (y, x - 1), (y + 1, x))
            if 0 <= ny < ymax and 0 <= nx < xmax and grid[ny, nx] < 9
        ]
        if not neighbors:
            continue
        # Ties on height are broken by coordinates (tuple ordering).
        val, yy, xx = min(neighbors)
        if grid[y, x] > val:
            basins.merge((y, x), (yy, xx))

    sizes = sorted((len(subset) for subset in basins.subsets()), reverse=True)
    if len(sizes) < 3:
        raise ValueError(f"expected at least 3 basins, found {len(sizes)}")
    return sizes[0] * sizes[1] * sizes[2]
def test_add(n):
    """``add`` preserves insertion order and ignores repeated elements."""
    elements = get_elements(n)
    reference = DisjointSet(elements)
    incremental = DisjointSet()

    for count, element in enumerate(elements, start=1):
        incremental.add(element)
        assert len(incremental) == count
        # A second add of the same element must leave the size unchanged.
        incremental.add(element)
        assert len(incremental) == count

    assert list(reference) == list(incremental)
def test_element_not_present():
    """Indexing, merging or connecting with an unknown key raises KeyError."""
    elements = get_elements(n=10)
    dis = DisjointSet(elements)
    known = elements[0]

    probes = (
        lambda: dis["dummy"],
        lambda: dis.merge(known, "dummy"),
        lambda: dis.connected(known, "dummy"),
    )
    for probe in probes:
        with assert_raises(KeyError):
            probe()
def test_subsets(n):
    """``subset`` and ``subsets`` agree with a brute-force grouping by root.

    Random pairs are merged one at a time. Before each merge, both views
    are checked against sets rebuilt from the current roots.
    """
    elements = get_elements(n)
    dis = DisjointSet(elements)
    rng = np.random.RandomState(seed=0)

    for i, j in rng.randint(0, n, (n, 2)):
        x, y = elements[i], elements[j]

        # Roots are compared wrapped in one-element sets (hash, then
        # equality) -- presumably to avoid relying on plain `==` across the
        # mixed element types.
        members_of_x = {e for e in dis if {dis[e]} == {dis[x]}}
        assert members_of_x == dis.subset(x)

        grouped = {}
        for e in dis:
            grouped.setdefault(dis[e], set()).add(e)
        assert list(grouped.values()) == dis.subsets()

        dis.merge(x, y)
        assert dis.subset(x) == dis.subset(y)
def test_linear_union_sequence(n, direction):
    """Chain-merge neighbouring elements and check the final root.

    Pairs ``(elements[i], elements[i + 1])`` are merged for increasing
    ``i`` when ``direction == "forwards"`` and for decreasing ``i``
    otherwise. The subset count must drop by one per merge.
    """
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert elements == list(dis)

    order = range(n - 1)
    if direction == "backwards":
        order = reversed(order)

    for step, i in enumerate(order):
        left, right = elements[i], elements[i + 1]
        assert not dis.connected(left, right)
        assert dis.merge(left, right)
        assert dis.connected(left, right)
        assert dis.n_subsets == n - 1 - step

    roots = [dis[e] for e in elements]
    expected_root = elements[0] if direction == "forwards" else elements[-2]
    assert all(expected_root == r for r in roots)

    # Everything is already one subset, so this merge must report no change.
    assert not dis.merge(elements[0], elements[-1])
def test_init():
    """A new DisjointSet holds each element as a singleton, in insertion order."""
    count = 10
    elements = get_elements(count)
    dis = DisjointSet(elements)

    assert dis.n_subsets == count
    assert list(dis) == elements