# Example #1
0
def test_zero3d():
    """Solve a three-node line rooted at node 2 whose middle node has mu == 0.

    The parent array [1, 2, 2] chains 0 -> 1 -> 2; only nodes 0 and 2 carry
    weight, and the solution is compared against a known answer.
    """
    instance = TreeInstance(
        y=np.array([0., 0., 3.]),
        mu=np.array([1., 0., 1.]),
        lam=np.array([0.5, 1.0, np.nan]),
        parent=np.array([1, 2, 2], dtype=np.int32),
        root=2,
    )
    instance.solve()
    expected = [0.5, 2.5, 2.5]
    assert np.allclose(instance.x, expected)
# Example #2
0
def test_nonzero():
    """Four-node tree rooted at node 2 with unit weight mu == 1 on every node.

    Uses the same y, lam and parent arrays as ``test_zero4``; only ``mu``
    differs (all ones here).
    """
    weights = np.array([1., 1., 1., 1.])
    inst = TreeInstance(
        y=np.array([0., 0., 0., 2.]),
        mu=weights,
        lam=np.array([1.0, 0.3, np.nan, 1.0]),
        parent=np.array([1, 2, 2, 2], dtype=np.int32),
        root=2,
    )
    inst.solve()
    assert np.allclose(inst.x, [0.15, 0.15, 0.7, 1.0])
# Example #3
0
def test_zero3b():
    """Three-node line rooted at node 2 where only node 0 (with y == 0) has weight.

    The solution must be identically zero; the dual must be NaN exactly at
    the root, gamma must vanish, and dual_diff must be non-negative (all up
    to a 1e-10 tolerance).
    """
    tol = 1e-10
    inst = TreeInstance(
        y=np.array([0., 0., 3.]),
        mu=np.array([1., 0., 0.]),
        lam=np.array([1.0, 0.5, 0.3]),
        parent=np.array([1, 2, 2], dtype=np.int32),
        root=2,
    )
    inst.solve()
    assert np.allclose(inst.x, 0)

    alpha = inst.dual
    nan_nodes = np.nonzero(np.isnan(alpha))[0].tolist()
    assert nan_nodes == [inst.root], f'alpha={alpha}, root={inst.root}'

    g = inst.gamma
    assert (g > -tol).all(), f'gamma={g}'
    assert (g < +tol).all(), f'gamma={g}'

    v = inst.dual_diff
    assert (v > -tol).all(), \
        f'v={v}\nx={inst.x}\nalpha={inst.dual}\nlam={inst.lam}'
# Example #4
0
def test_demo_3x7():
    """Solve the 21-node demo instance with lam = 1 and check the result.

    Also checks that TreeInstance detects root 0 and converts the parent
    array to int32.
    """
    lam = 1.0
    values = ("0.62 0.73 0.71 1.5 1.17 0.43 1.08 0.62 "
              "1.73 0.95 1.46 1.6 1.16 0.38 0.9 0.32 "
              "-0.48 0.95 1.08 0.02 0.4")
    y = np.array(values.split(), dtype=float)
    parent = np.array(
        [0, 4, 5, 0, 3, 4, 7, 8, 5, 6, 7, 8, 9, 14, 17, 12, 15, 16, 19, 16, 17]
    )

    prob = TreeInstance(y, parent, lam=lam)
    assert prob.root == 0
    assert prob.parent.dtype == np.int32

    prob.solve()
    # The solution preserves the mean of y ...
    assert abs(prob.x.mean() - prob.y.mean()) < 1e-15
    # ... and takes exactly two distinct values.
    assert len(np.unique(prob.x)) == 2
    assert max(np.abs(prob.dual[2:]) - lam) < 1e-12
    assert max(np.abs(prob.gamma)) < 1e-15
# Example #5
0
def test_gap_x0(t, gap=1e-10):
    """Load instance file ``t``, solve it and check the dual certificate.

    After solving, the dual ``alpha`` must be NaN exactly at the root and
    finite everywhere else, ``|alpha| <= lam + gap`` must hold on every
    non-root node, and ``gamma`` must vanish up to small tolerances.

    :param t: name of an instance file, resolved via ``test_dir``.
    :param gap: slack allowed in the bound ``|alpha| <= lam``.
    """
    ti = TreeInstance.load(test_dir(t))
    ti.solve()
    alpha = ti.dual
    # Select the non-root nodes explicitly: the root is not necessarily the
    # last index, so slicing with [:-1] would mis-handle other layouts.
    non_root = np.arange(len(alpha)) != ti.root
    assert np.isnan(alpha[ti.root])
    assert not np.isnan(alpha[non_root]).any()
    # Broadcast so a scalar lam is handled the same as a per-node array.
    lam = np.broadcast_to(ti.lam, alpha.shape)
    assert all(np.abs(alpha[non_root]) <= lam[non_root] + gap)
    assert min(ti.gamma) >= -1e-12
    assert max(ti.gamma) < 1e-10
# Example #6
0
def test_gamma_line3(eps=1e-14):
    """Three-node line 1 - 0 - 2 rooted at 2 with mu = 1 and lam = 0.2.

    Checks the post-order, that ``(x - y) / lam`` in post-order is exactly
    [1, -2, 1], that the dual is NaN at the root, and that gamma vanishes
    up to ``eps``.
    """
    ti = TreeInstance(root=2,
                      y=np.array([0.8, -0.6, 0.]),
                      mu=1.0,
                      lam=0.2,
                      parent=np.array([2, 0, 2], dtype=np.int32))

    po = post_order(ti.parent, include_root=True)
    assert (po == [1, 0, 2]).all()

    ti.solve()
    # Divide before permuting so that an array-valued lam stays aligned with
    # its nodes. Round instead of truncating: astype(int) alone would turn a
    # rounding result like 0.9999999 into 0.
    steps = np.rint((ti.x - ti.y) / ti.lam)[po]
    assert steps.astype(int).tolist() == [1, -2, 1]

    assert np.isnan(ti.dual[ti.root])

    assert ti.gamma.min() >= -eps, ti.gamma.min()
    assert ti.gamma.max() <= +eps, f'gamma = {ti.gamma.max()}\n{ti}'
# Example #7
0
# File: test_tree.py  Project: EQt/treelas
def test_rtree(n=5, seed=2018, eps=1e-14):
    """Random tree creation.

    Builds ``Tree.random(n, seed=seed)`` and solves it with lam = 0.2. For
    the reference configuration ``n == 5 and seed == 2015``, a fixed ``y``
    is used and the known solution is checked. Otherwise ``y`` is drawn from
    ``np.random.normal`` after seeding the global NumPy RNG with ``seed``,
    then rounded to one decimal. In all cases the mean of ``y`` must be
    preserved and gamma must vanish up to ``eps``.
    """
    t = Tree.random(n, seed=seed)
    assert t.n == n
    reference = n == 5 and seed == 2015
    if reference:
        y = np.array([0.1, 1.7, -0.1, 1., 1.1])
    else:
        np.random.seed(seed)
        y = np.random.normal(size=n).round(1)
    lam = 0.2
    ti = TreeInstance(y, t.parent, lam=lam)
    ti.solve()

    assert abs(ti.x.mean() - ti.y.mean()) < eps
    if reference:
        # Compare with a tolerance: exact float equality would depend on the
        # solver's order of floating-point operations.
        assert np.allclose(ti.x, [0.1, 1.5, 0.1, 1.05, 1.05]), ti.x
    assert ti.gamma.min() >= -eps, ti.gamma.min()
    assert ti.gamma.max() <= +eps, \
        f'seed = {seed}: gamma = {ti.gamma.max()}\n{ti}'
# Example #8
0
def test_gamma3():
    """Read TreeInstance.gamma on a small three-node tree without calling solve().

    Every entry must lie in [0, 1e-10).
    """
    inst = TreeInstance(
        y=np.array([0., 13, 0.75]),
        mu=np.array([1., 0.1, 0.001]),
        lam=np.array([1.0, 0.5, 0.3]),
        parent=np.array([1, 2, 2], dtype=np.int32),
        root=2,
    )
    gamma = inst.gamma
    assert (gamma >= 0).all(), f'gamma={gamma}'
    assert (gamma < 1e-10).all(), f'gamma={gamma}'
# Example #9
0
def test_nan():
    """`y` contains `NaN`.

    ``solve()`` must raise ``RuntimeError`` whose message names the first
    NaN entry: first in ``y``, then, after resetting ``y``, in ``mu``. The
    test returns early when the project flag ``__asan__`` is set.
    """
    t = TreeInstance(y=np.array([0., np.nan, np.nan]),
                     mu=np.array([1., 0., 0.]),
                     lam=np.array([1.0, 0.5, 0.3]),
                     parent=np.array([1, 2, 2], dtype=np.int32),
                     root=2)
    if __asan__:
        return
    with pytest.raises(RuntimeError) as e:
        t.solve()
    # The message checks must sit outside the `with` block: statements
    # after the raising call inside it are never executed.
    assert 'y[1] = nan' in str(e.value)
    t.y = np.array([0., 0, 0])
    t.mu = np.array([1., 0., np.nan])
    with pytest.raises(RuntimeError) as e:
        t.solve()
    assert 'mu[2] = nan' in str(e.value)
# Example #10
0
# File: test_tree.py  Project: EQt/treelas
def tree5(request):
    """Build the ten-node TreeInstance (root 0, lam = 1.0) used by the tree5 tests.

    Registers a finalizer on ``request`` that deletes ``tree5.toml`` and
    ``tree5.h5`` from the working directory, unless the environment variable
    ``KEEP_TREE5`` is set to a non-empty value.
    """
    lam = 1.0
    #         0  1  2  3  4  5  6  7  8  9
    parent = [0, 0, 1, 2, 3, 0, 7, 8, 3, 8]
    y = [8.2, 7.0, 9.5, 6.8, 5.8, 6.3, 4.3, 2.2, 1.2, 2.8]
    tree = Tree(parent)
    assert tree.root == 0

    def remove_outputs():
        # Leave the written files in place for inspection when requested.
        if os.getenv("KEEP_TREE5"):
            return
        for fname in ('tree5.toml', 'tree5.h5'):
            if os.path.exists(fname):
                os.remove(fname)

    request.addfinalizer(remove_outputs)
    return TreeInstance(y, tree.parent, lam=lam)
# Example #11
0
def test_zero4():
    """
    Tree with 4 nodes, actually a line graph where just the two end
    nodes contain information.

    The parent array [1, 2, 2, 2] forms the line 0 - 1 - 2 - 3 rooted at 2.
    Only the end nodes 0 and 3 have weight mu == 1. Checks the primal
    solution, that the dual is NaN exactly at the root, that gamma vanishes,
    and that dual_diff is non-negative (tolerance 1e-10).
    """
    t = TreeInstance(y=np.array([0., 0., 0., 2.]),
                     mu=np.array([1., 0., 0., 1.]),
                     lam=np.array([1.0, 0.3, np.nan, 1.0]),
                     parent=np.array([1, 2, 2, 2], dtype=np.int32),
                     root=2)
    t.solve()
    assert np.allclose(t.x, [0.3, 0.3, 1.7, 1.7])
    alpha = t.dual
    assert np.where(np.isnan(alpha))[0].tolist() == [t.root], \
        f'alpha={alpha}, root={t.root}'
    g = t.gamma
    assert (g > -1e-10).all(), f'gamma={g}'
    assert (g < +1e-10).all(), f'gamma={g}'
    v = t.dual_diff
    assert (v > -1e-10).all(), f'v={v}\nx={t.x}\nalpha={t.dual}\nlam={t.lam}'
# Example #12
0
def tree1():
    """Load the TreeInstance stored in ``tree0.1.toml``, located via ``test_dir``."""
    path = test_dir("tree0.1.toml")
    return TreeInstance.load(path)
# Example #13
0
# File: test_tree.py  Project: EQt/treelas
def test_tree5_write_toml_read(tree5):
    """Round-trip the tree5 instance through ``tree5.toml``.

    Skipped when the ``toml`` package is unavailable. This test is named
    ``..._toml_...`` because the previous name duplicated
    ``test_tree5_write_h5_read``. In one module, the later definition
    shadows the earlier one, so pytest would never run this TOML test.
    """
    pytest.importorskip("toml")
    ti = tree5
    ti.save('tree5.toml')
    t2 = TreeInstance.load('tree5.toml')
    assert repr(ti) == repr(t2), f"\n\n{ti}\n\n{t2}\n"
# Example #14
0
# File: test_tree.py  Project: EQt/treelas
def test_tree5_write_h5_read(tree5):
    """Save the tree5 instance to ``tree5.h5``, reload it, and compare reprs."""
    fname = 'tree5.h5'
    tree5.save(fname)
    loaded = TreeInstance.load(fname)
    assert repr(tree5) == repr(loaded), f"\n\n{tree5}\n\n{loaded}\n"