예제 #1
0
def test_iter_bfs():
    """Check that ``iter_bfs`` yields every node of the tree in breadth-first order."""

    # Each node's ``no`` is its expected position in the breadth-first traversal.
    tree = BinaryBranch(BinaryBranch(Leaf(no=4), Leaf(no=5), no=2),
                        Leaf(no=3),
                        no=1)

    # Compare the whole sequence rather than looping with ``enumerate``: a
    # traversal that stops early (or yields nothing) must fail the test
    # instead of passing vacuously.
    visited = [node.no for node in tree.iter_bfs()]
    assert visited == [1, 2, 3, 4, 5]
예제 #2
0
def test_iter_branches():
    """Check that ``iter_branches`` yields every branch in the expected order."""

    # Each branch's ``no`` is its expected position in the iteration; leaves
    # carry no number.
    tree = BinaryBranch(
        BinaryBranch(BinaryBranch(Leaf(), Leaf(), no=3), Leaf(), no=2),
        BinaryBranch(Leaf(), Leaf(), no=4),
        no=1,
    )

    # Compare the whole sequence so that a truncated or empty iteration fails
    # the test instead of passing vacuously.
    visited = [branch.no for branch in tree.iter_branches()]
    assert visited == [1, 2, 3, 4]
예제 #3
0
def test_iter_edges():
    """Check that ``iter_edges`` yields every (parent, child) pair in the expected order."""

    tree = BinaryBranch(BinaryBranch(Leaf(no=3), Leaf(no=4), no=2),
                        Leaf(no=5),
                        no=1)

    order = [(1, 2), (2, 3), (2, 4), (1, 5)]

    # Compare the full list of edges: too few edges must fail (instead of
    # passing silently) and too many must fail with an assertion error
    # (instead of an IndexError while indexing ``order``).
    edges = [(parent.no, child.no) for parent, child in tree.iter_edges()]
    assert edges == order
예제 #4
0
파일: hst.py 프로젝트: mipo57/river
def make_padded_tree(limits, height, padding, rng=random, **node_params):
    """Recursively build a random binary tree of the given height.

    At every level a feature is drawn with probability proportional to the
    width of its current interval, and a split threshold is drawn uniformly
    from that interval after shrinking it by ``padding`` (a fraction of the
    width) on both sides. The left subtree is then built with the feature's
    interval narrowed to ``(lower, threshold)`` and the right subtree with
    ``(threshold, upper)``.

    Parameters
    ----------
    limits
        Mapping from feature to a ``(lower, upper)`` pair. It is only read;
        each recursive call receives its own narrowed copy, so the caller's
        mapping is left untouched even if an error is raised part-way.
    height
        Number of branch levels to build. ``0`` produces a single ``Leaf``.
    padding
        Fraction of the interval width excluded at each end when drawing the
        split point, to avoid too narrow a split.
    rng
        Object providing ``choices`` and ``uniform``; defaults to the
        ``random`` module.
    **node_params
        Extra keyword arguments forwarded to every ``Leaf`` and ``HSTBranch``.

    Returns
    -------
    A ``Leaf`` when ``height == 0``, otherwise an ``HSTBranch`` whose
    ``left``/``right`` children are trees of height ``height - 1``.

    """

    if height == 0:
        return Leaf(**node_params)

    # Randomly pick a feature, weighting each one by the gap between its limits.
    features = list(limits.keys())
    on = rng.choices(
        population=features,
        weights=[limits[f][1] - limits[f][0] for f in features],
    )[0]

    # Pick a split point; use padding to avoid too narrow a split.
    lower = limits[on][0]
    upper = limits[on][1]
    margin = padding * (upper - lower)
    at = rng.uniform(lower + margin, upper - margin)

    # Build the children on narrowed copies of the limits instead of mutating
    # and restoring the caller's mapping. Re-assigning an existing key keeps
    # its position, so the feature order (and thus the sequence of random
    # draws) is the same as with in-place updates.
    left = make_padded_tree(
        limits={**limits, on: (lower, at)},
        height=height - 1,
        padding=padding,
        rng=rng,
        **node_params,
    )
    right = make_padded_tree(
        limits={**limits, on: (at, upper)},
        height=height - 1,
        padding=padding,
        rng=rng,
        **node_params,
    )

    return HSTBranch(left=left, right=right, feature=on, threshold=at, **node_params)
예제 #5
0
def test_height():
    """Check the ``height`` property on a lopsided tree and on some of its subtrees."""

    # Assemble the tree bottom-up, deepest subtree first.
    deepest = BinaryBranch(Leaf(), Leaf())
    left_left = BinaryBranch(deepest, Leaf())
    left_right = BinaryBranch(Leaf(), Leaf())
    left = BinaryBranch(left_left, left_right)
    right = BinaryBranch(Leaf(), Leaf())
    tree = BinaryBranch(left, right)

    assert tree.height == 5

    first_child = tree.children[0]
    assert first_child.height == 4

    second_child = tree.children[1]
    assert second_child.height == 2

    # A lone leaf has height 1.
    lone_leaf = tree.children[1].children[0]
    assert lone_leaf.height == 1
예제 #6
0
def test_size():
    """Check that ``n_nodes`` equals ``n_branches + n_leaves`` and matches the expected counts."""

    # Assemble the tree bottom-up, deepest subtree first.
    deepest = BinaryBranch(Leaf(), Leaf())
    left_left = BinaryBranch(deepest, Leaf())
    left_right = BinaryBranch(Leaf(), Leaf())
    left = BinaryBranch(left_left, left_right)
    right = BinaryBranch(Leaf(), Leaf())
    tree = BinaryBranch(left, right)

    # Whole tree: 6 branches and 7 leaves.
    assert tree.n_nodes == tree.n_branches + tree.n_leaves == 6 + 7

    # Left subtree: 4 branches and 5 leaves.
    first_child = tree.children[0]
    assert first_child.n_nodes == first_child.n_branches + first_child.n_leaves == 4 + 5

    # Right subtree: 1 branch and 2 leaves.
    second_child = tree.children[1]
    assert second_child.n_nodes == second_child.n_branches + second_child.n_leaves == 1 + 2

    # A lone leaf: no branch, one leaf.
    lone_leaf = tree.children[1].children[0]
    assert lone_leaf.n_nodes == lone_leaf.n_branches + lone_leaf.n_leaves == 0 + 1
예제 #7
0
def test_iter_leaves():
    """Check that ``iter_leaves`` yields every leaf in the expected order."""

    # Each leaf's ``no`` is its expected position in the iteration.
    tree = BinaryBranch(BinaryBranch(Leaf(no=1), Leaf(no=2)), Leaf(no=3))

    # Compare the whole sequence so that a truncated or empty iteration fails
    # the test instead of passing vacuously.
    visited = [leaf.no for leaf in tree.iter_leaves()]
    assert visited == [1, 2, 3]