def test_iter_bfs():
    """Check that ``iter_bfs`` visits every node, level by level.

    The five nodes carry a ``no`` attribute set to the position at which the
    test expects them to be yielded (root first, then its two children, then
    the two grandchildren).
    """
    tree = BinaryBranch(
        BinaryBranch(Leaf(no=4), Leaf(no=5), no=2),
        Leaf(no=3),
        no=1,
    )

    # Compare the whole sequence at once: asserting inside a loop would pass
    # vacuously if the iterator yielded too few nodes (or none at all).
    assert [node.no for node in tree.iter_bfs()] == [1, 2, 3, 4, 5]
def test_iter_branches():
    """Check that ``iter_branches`` yields every branch in the expected order.

    Only the four branches carry a ``no`` attribute; it is set to the position
    at which the test expects each branch to be yielded. Leaves are unnumbered
    and must not appear in the output.
    """
    tree = BinaryBranch(
        BinaryBranch(BinaryBranch(Leaf(), Leaf(), no=3), Leaf(), no=2),
        BinaryBranch(Leaf(), Leaf(), no=4),
        no=1,
    )

    # Compare the whole sequence at once: asserting inside a loop would pass
    # vacuously if the iterator yielded too few branches (or none at all).
    assert [branch.no for branch in tree.iter_branches()] == [1, 2, 3, 4]
def test_iter_edges():
    """Check that ``iter_edges`` yields every (parent, child) pair in order.

    Each node carries a ``no`` attribute so that an edge can be identified by
    the pair of numbers of its two endpoints.
    """
    tree = BinaryBranch(
        BinaryBranch(Leaf(no=3), Leaf(no=4), no=2),
        Leaf(no=5),
        no=1,
    )
    expected = [(1, 2), (2, 3), (2, 4), (1, 5)]

    # Compare the full list: indexing ``expected`` inside a loop would miss
    # edges that are never yielded, and would report surplus edges as an
    # IndexError instead of a readable assertion failure.
    observed = [(parent.no, child.no) for parent, child in tree.iter_edges()]
    assert observed == expected
def make_padded_tree(limits, height, padding, rng=random, **node_params):
    """Recursively build a random binary tree of ``HSTBranch`` and ``Leaf`` nodes.

    At each level a feature is drawn at random, a split point is drawn inside
    that feature's current range, and the two halves of the range are handed
    to the left and right subtrees.

    Parameters
    ----------
    limits
        Mapping from feature to a ``(low, high)`` pair. It is only read: each
        child receives its own copy with the chosen feature's range narrowed.
    height
        Number of branch levels still to build. ``0`` yields a ``Leaf``.
    padding
        Fraction of the chosen feature's range that is excluded on each side
        when drawing the split point, to avoid too narrow a split.
    rng
        Object providing ``choices`` and ``uniform``; defaults to the
        ``random`` module.
    **node_params
        Extra keyword arguments forwarded to every ``Leaf`` and ``HSTBranch``.

    Returns
    -------
    A ``Leaf`` when ``height == 0``, otherwise an ``HSTBranch``.
    """
    if height == 0:
        return Leaf(**node_params)

    # Randomly pick a feature
    # We weight each feature by the gap between each feature's limits
    features = list(limits.keys())
    on = rng.choices(
        population=features,
        weights=[limits[i][1] - limits[i][0] for i in features],
    )[0]

    # Pick a split point; use padding to avoid too narrow a split
    a = limits[on][0]
    b = limits[on][1]
    margin = padding * (b - a)
    at = rng.uniform(a + margin, b - margin)

    # Each child gets a fresh mapping instead of a temporarily patched
    # ``limits``: the caller's mapping is never mutated, so it stays intact
    # even if building a subtree raises, and read-only mappings are accepted.
    # The left subtree is built first, keeping the order of ``rng`` draws.
    left = make_padded_tree(
        limits={**limits, on: (a, at)},
        height=height - 1,
        padding=padding,
        rng=rng,
        **node_params,
    )
    right = make_padded_tree(
        limits={**limits, on: (at, b)},
        height=height - 1,
        padding=padding,
        rng=rng,
        **node_params,
    )

    return HSTBranch(left=left, right=right, feature=on, threshold=at, **node_params)
def test_height():
    """Check ``height`` on an unbalanced tree and on some of its subtrees.

    The expected values imply that a lone leaf has height 1 and that a branch
    over two leaves has height 2.
    """

    def two_leaves():
        # Smallest branch used in this test: a parent with two leaf children
        return BinaryBranch(Leaf(), Leaf())

    deep_side = BinaryBranch(
        BinaryBranch(two_leaves(), Leaf()),
        two_leaves(),
    )
    root = BinaryBranch(deep_side, two_leaves())

    first_child = root.children[0]
    second_child = root.children[1]

    assert root.height == 5
    assert first_child.height == 4
    assert second_child.height == 2
    assert second_child.children[0].height == 1
def test_size():
    """Check ``n_nodes``, ``n_branches`` and ``n_leaves`` on a tree and subtrees.

    For each inspected node, ``n_nodes`` must equal ``n_branches + n_leaves``,
    and that sum must equal the expected total.
    """

    def two_leaves():
        # Smallest branch used in this test: a parent with two leaf children
        return BinaryBranch(Leaf(), Leaf())

    root = BinaryBranch(
        BinaryBranch(
            BinaryBranch(two_leaves(), Leaf()),
            two_leaves(),
        ),
        two_leaves(),
    )
    first_child = root.children[0]
    second_child = root.children[1]

    # (node under test, expected branch count, expected leaf count)
    cases = (
        (root, 6, 7),
        (first_child, 4, 5),
        (second_child, 1, 2),
        (second_child.children[0], 0, 1),
    )
    for node, branches, leaves in cases:
        assert node.n_nodes == node.n_branches + node.n_leaves == branches + leaves
def test_iter_leaves():
    """Check that ``iter_leaves`` yields every leaf in the expected order.

    Only the three leaves carry a ``no`` attribute; it is set to the position
    at which the test expects each leaf to be yielded.
    """
    tree = BinaryBranch(
        BinaryBranch(Leaf(no=1), Leaf(no=2)),
        Leaf(no=3),
    )

    # Compare the whole sequence at once: asserting inside a loop would pass
    # vacuously if the iterator yielded too few leaves (or none at all).
    assert [leaf.no for leaf in tree.iter_leaves()] == [1, 2, 3]