コード例 #1
0
def test_len():
    """len() counts one entry per element, including elements added later."""
    size = 10
    dis = DisjointSet(get_elements(size))
    assert len(dis) == size

    # Adding a brand-new element grows the structure by exactly one.
    dis.add("dummy")
    assert len(dis) == size + 1
コード例 #2
0
def test_self_unions(n):
    """Merging an element with itself must leave the structure unchanged."""
    elements = get_elements(n)
    dis = DisjointSet(elements)

    for item in elements:
        assert dis.connected(item, item)
        # A self-merge is expected to report that nothing was merged.
        assert not dis.merge(item, item)
        assert dis.connected(item, item)
    assert dis.n_subsets == len(elements)

    # No real merge happened: iteration order is preserved and every
    # element is still reported as its own root.
    assert list(dis) == elements
    assert [dis[item] for item in elements] == elements
コード例 #3
0
class Bench(Benchmark):
    """Timing benchmarks for ``DisjointSet`` merge, lookup and membership."""

    params = [[100, 1000, 10000]]
    param_names = ['n']

    def setup(self, n):
        # Random edge list: n pairs of node labels drawn from [0, 10 * n).
        rng = np.random.RandomState(seed=0)
        self.edges = rng.randint(0, 10 * n, (n, 2))
        self.nodes = np.unique(self.edges)

        def merged_copy():
            # Fresh structure over all nodes with every edge already merged.
            result = DisjointSet(self.nodes)
            for left, right in self.edges:
                result.merge(left, right)
            return result

        # Untouched structure: each node starts in its own subset.
        self.disjoint_set = DisjointSet(self.nodes)
        self.pre_merged = merged_copy()
        self.pre_merged_found = merged_copy()
        # Look every node up once so the "already found" benchmark runs on
        # a structure that has been queried before.
        for node in self.nodes:
            self.pre_merged_found[node]

    def time_merge(self, n):
        target = self.disjoint_set
        for left, right in self.edges:
            target.merge(left, right)

    def time_merge_already_merged(self, n):
        target = self.pre_merged
        for left, right in self.edges:
            target.merge(left, right)

    def time_find(self, n):
        target = self.pre_merged
        return [target[node] for node in self.nodes]

    def time_find_already_found(self, n):
        target = self.pre_merged_found
        return [target[node] for node in self.nodes]

    def time_contains(self, n):
        # Probe the first, a middle, and the last known node.
        target = self.pre_merged
        for position in (0, n // 2, -1):
            assert self.nodes[position] in target

    def time_absence(self, n):
        # Probe values that were never inserted.
        target = self.pre_merged
        for missing in (None, "dummy", (1, 2, 3)):
            assert missing not in target
コード例 #4
0
def test_contains(n):
    """Every inserted element is a member; an unknown key is not."""
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert all(item in dis for item in elements)

    assert "dummy" not in dis
コード例 #5
0
    def setup(self, n):
        # Create random edges
        rng = np.random.RandomState(seed=0)
        self.edges = rng.randint(0, 10 * n, (n, 2))
        self.nodes = np.unique(self.edges)
        self.disjoint_set = DisjointSet(self.nodes)

        self.pre_merged = DisjointSet(self.nodes)
        for a, b in self.edges:
            self.pre_merged.merge(a, b)

        self.pre_merged_found = DisjointSet(self.nodes)
        for a, b in self.edges:
            self.pre_merged_found.merge(a, b)
        for x in self.nodes:
            self.pre_merged_found[x]
コード例 #6
0
def test_equal_size_ordering(n, order):
    """Merging two singleton subsets must make the earlier-inserted element the root.

    Elements are paired up in a random (seeded) order. Each pair is merged
    with the arguments in ``order`` ("ab" or reversed). The root must not
    depend on argument order; it must be the element that appears first in
    ``elements``. When ``n`` is odd, the last shuffled index stays unpaired.
    """
    elements = get_elements(n)
    dis = DisjointSet(elements)

    rng = np.random.RandomState(seed=0)
    indices = np.arange(n)
    rng.shuffle(indices)

    # Stop at n - 1 so indices[i + 1] stays in range for odd n; for even n
    # this visits exactly the same pairs as range(0, n, 2).
    for i in range(0, n - 1, 2):
        ia, ib = indices[i], indices[i + 1]
        a, b = elements[ia], elements[ib]
        if order == "ab":
            assert dis.merge(a, b)
        else:
            assert dis.merge(b, a)

        # Both subsets have size one, so the test expects the tie to be
        # broken by insertion order (lower index in ``elements`` wins).
        expected = elements[min(ia, ib)]
        assert dis[a] == expected
        assert dis[b] == expected
コード例 #7
0
def test_binary_tree(kmax):
    """Build a balanced union tree level by level and check every root."""
    n = 2**kmax
    elements = get_elements(n)
    dis = DisjointSet(elements)
    rng = np.random.RandomState(seed=0)

    positions = np.arange(n)
    for k in 2**np.arange(kmax):
        block = 2 * k
        # Join each left half-block [start, start + k) to its right half
        # [start + k, start + 2k) through one randomly chosen pair.
        for start in range(0, n, block):
            off_left, off_right = rng.randint(0, k, size=2)
            left = elements[start + off_left]
            right = elements[start + k + off_right]
            assert not dis.connected(left, right)
            assert dis.merge(left, right)
            assert dis.connected(left, right)

        assert list(dis) == elements
        # The test expects each element's root to be the first element of
        # its block of size 2k.
        roots = [dis[item] for item in elements]
        assert roots == [elements[p] for p in positions - positions % block]
コード例 #8
0
def day9b(s):
    """Return the product of the sizes of the three largest basins in a height grid.

    ``s`` holds one row of single-digit heights per line. A basin is a
    4-connected region of cells with height below 9; cells of height 9 act
    as walls and belong to no basin. Grouping by connectivity means flat
    regions (equal-height neighbours with no lower cell) are still counted
    as one basin.

    Raises ValueError (from the final unpacking) if there are fewer than
    three basins.
    """
    from scipy.cluster.hierarchy import DisjointSet

    grid = np.array([[int(ch) for ch in line] for line in s.splitlines()],
                    dtype=int)
    ymax, xmax = grid.shape
    loc = [(y, x) for y in range(ymax) for x in range(xmax) if grid[y, x] < 9]
    basins = DisjointSet(loc)
    for y, x in loc:
        # Union with the lower and right neighbours only; the upper/left
        # pairs are handled when the loop visits those cells.
        for yy, xx in ((y + 1, x), (y, x + 1)):
            if yy < ymax and xx < xmax and grid[yy, xx] < 9:
                basins.merge((y, x), (yy, xx))
    a, b, c = sorted(basins.subsets(), key=len)[-3:]
    return len(a) * len(b) * len(c)
コード例 #9
0
def test_add(n):
    """Adding elements one by one matches bulk construction; re-adding is a no-op."""
    elements = get_elements(n)
    bulk = DisjointSet(elements)

    incremental = DisjointSet()
    for count, item in enumerate(elements, start=1):
        incremental.add(item)
        assert len(incremental) == count

        # Adding the same element a second time must not change the size.
        incremental.add(item)
        assert len(incremental) == count

    assert list(bulk) == list(incremental)
コード例 #10
0
def test_element_not_present():
    """Lookup, merge and connectivity checks on an unknown element raise KeyError."""
    elements = get_elements(n=10)
    dis = DisjointSet(elements)
    first = elements[0]

    probes = (
        lambda: dis["dummy"],
        lambda: dis.merge(first, "dummy"),
        lambda: dis.connected(first, "dummy"),
    )
    for probe in probes:
        with assert_raises(KeyError):
            probe()
コード例 #11
0
def test_subsets(n):
    """subset()/subsets() must agree with a brute-force grouping by root."""
    elements = get_elements(n)
    dis = DisjointSet(elements)

    rng = np.random.RandomState(seed=0)
    for i, j in rng.randint(0, n, (n, 2)):
        x, y = elements[i], elements[j]

        # Brute-force subset of x: every member whose root matches x's root.
        # Roots are compared as one-element sets, exactly as before.
        target = {dis[x]}
        brute_subset = {item for item in dis if {dis[item]} == target}
        assert brute_subset == dis.subset(x)

        # Brute-force partition, with groups ordered by first-seen root.
        groups = {}
        for item in dis:
            groups.setdefault(dis[item], set()).add(item)
        assert list(groups.values()) == dis.subsets()

        dis.merge(x, y)
        assert dis.subset(x) == dis.subset(y)
コード例 #12
0
def test_linear_union_sequence(n, direction):
    """Chain-merge neighbouring elements forwards or backwards and check roots."""
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert list(dis) == elements

    steps = range(n - 1)
    if direction == "backwards":
        steps = reversed(steps)

    for done, i in enumerate(steps, start=1):
        a, b = elements[i], elements[i + 1]
        assert not dis.connected(a, b)
        assert dis.merge(a, b)
        assert dis.connected(a, b)
        # Each successful merge removes exactly one subset.
        assert dis.n_subsets == n - done

    roots = [dis[item] for item in elements]
    # The test expects elements[0] as the single root when merging forwards,
    # and elements[-2] (root of the first merged pair) otherwise.
    expected_root = elements[0] if direction == "forwards" else elements[-2]
    assert all(expected_root == r for r in roots)
    assert not dis.merge(elements[0], elements[-1])
コード例 #13
0
def test_init():
    """A freshly built DisjointSet has one singleton subset per element."""
    count = 10
    elements = get_elements(count)
    dis = DisjointSet(elements)
    assert dis.n_subsets == count
    # Iteration yields the elements in their original order.
    assert list(dis) == elements