def test_zero3d():
    """Line tree 0 -> 1 -> 2 (root 2) where node 1 has mu == 0.

    The root's lam entry is NaN; the solution is checked against known values.
    """
    instance = TreeInstance(
        y=np.array([0.0, 0.0, 3.0]),
        mu=np.array([1.0, 0.0, 1.0]),
        lam=np.array([0.5, 1.0, np.nan]),
        parent=np.array([1, 2, 2], dtype=np.int32),
        root=2,
    )
    instance.solve()
    expected = [0.5, 2.5, 2.5]
    assert np.allclose(instance.x, expected)
def test_nonzero():
    """Same 4-node tree as test_zero4, but with mu == 1 at every node."""
    instance = TreeInstance(
        y=np.array([0.0, 0.0, 0.0, 2.0]),
        mu=np.ones(4),
        lam=np.array([1.0, 0.3, np.nan, 1.0]),
        parent=np.array([1, 2, 2, 2], dtype=np.int32),
        root=2,
    )
    instance.solve()
    expected = [0.15, 0.15, 0.7, 1.0]
    assert np.allclose(instance.x, expected)
def test_zero3b():
    """Line tree 0 -> 1 -> 2 (root 2) where only node 0 has mu != 0.

    The solution must be (numerically) zero everywhere; the dual, gamma and
    dual_diff properties of the solved instance are checked as well.
    """
    t = TreeInstance(
        y=np.array([0., 0., 3.]),
        mu=np.array([1., 0., 0.]),
        lam=np.array([1.0, 0.5, 0.3]),
        parent=np.array([1, 2, 2], dtype=np.int32),
        root=2,
    )
    t.solve()
    assert np.allclose(t.x, 0)

    alpha = t.dual
    # Only the root may carry a NaN dual value.
    nan_nodes = np.flatnonzero(np.isnan(alpha)).tolist()
    assert nan_nodes == [t.root], f'alpha={alpha}, root={t.root}'

    tol = 1e-10
    g = t.gamma
    assert (g > -tol).all(), f'gamma={g}'
    assert (g < +tol).all(), f'gamma={g}'

    v = t.dual_diff
    assert (v > -tol).all(), f'v={v}\nx={t.x}\nalpha={t.dual}\nlam={t.lam}'
def test_demo_3x7():
    """Solve the 21-node demo instance (root 0) and check optimality hints."""
    y_text = ("0.62 0.73 0.71 1.5 1.17 0.43 1.08 0.62 "
              "1.73 0.95 1.46 1.6 1.16 0.38 0.9 0.32 "
              "-0.48 0.95 1.08 0.02 0.4")
    y = np.array([float(token) for token in y_text.split()])
    parent = np.array([0, 4, 5, 0, 3, 4, 7, 8, 5, 6, 7,
                       8, 9, 14, 17, 12, 15, 16, 19, 16, 17])
    lam = 1.0

    prob = TreeInstance(y, parent, lam=lam)
    assert prob.root == 0
    # The constructor is expected to normalise the parent array to int32.
    assert prob.parent.dtype == np.int32

    prob.solve()
    # The solution preserves the mean of the input.
    assert abs(prob.x.mean() - prob.y.mean()) < 1e-15
    # The solution takes exactly two distinct values.
    assert len(np.unique(prob.x)) == 2
    # Duals from index 2 on satisfy |alpha| <= lam (up to 1e-12).
    # The builtin max is used deliberately; np.max would treat NaN differently.
    assert max(np.abs(prob.dual[2:]) - lam) < 1e-12
    assert max(np.abs(prob.gamma)) < 1e-15
def test_gap_x0(t, gap=1e-10):
    """Load test instance `t`, solve it and check dual feasibility.

    The root's dual must be NaN. Every other dual must be finite and bounded
    by its lam (plus `gap`), and gamma must be numerically zero.
    """
    t = TreeInstance.load(test_dir(t))
    t.solve()
    alpha = t.dual
    assert np.isnan(alpha[t.root])
    # Select the non-root nodes by index. The previous `[:-1]` slicing
    # silently assumed the root is the last node.
    non_root = np.arange(len(alpha)) != t.root
    assert not np.isnan(alpha[non_root]).any()
    assert all(np.abs(alpha[non_root]) <= t.lam[non_root] + gap)
    assert min(t.gamma) >= -1e-12
    assert max(t.gamma) < 1e-10
def test_gamma_line3(eps=1e-14):
    """Line graph 1 -> 0 -> 2 (root 2) with scalar mu and lam.

    Checks the post-order, the step sizes (x - y) / lam, the root's NaN dual
    and that gamma is numerically zero.
    """
    ti = TreeInstance(root=2, y=np.array([0.8, -0.6, 0.]), mu=1.0, lam=0.2,
                      parent=np.array([2, 0, 2], dtype=np.int32))
    po = post_order(ti.parent, include_root=True)
    assert (po == [1, 0, 2]).all()
    ti.solve()
    # Compare with a tolerance. Truncating via astype(int) turns a
    # round-off result such as 0.9999999999999998 into 0.
    steps = (ti.x - ti.y)[po] / ti.lam
    assert np.allclose(steps, [1, -2, 1]), steps
    assert np.isnan(ti.dual[ti.root])
    assert ti.gamma.min() >= -eps, ti.gamma.min()
    assert ti.gamma.max() <= +eps, f'gamma = {ti.gamma.max()}\n{ti}'
def test_rtree(n=5, seed=2018, eps=1e-14):
    """Random tree creation.

    For (n, seed) == (5, 2015) a hand-picked `y` with a known solution is used.
    Otherwise `y` is drawn from a seeded normal distribution.
    """
    t = Tree.random(n, seed=seed)
    assert t.n == n
    known_case = n == 5 and seed == 2015
    if known_case:
        y = np.array([0.1, 1.7, -0.1, 1., 1.1])
    else:
        np.random.seed(seed)
        y = np.random.normal(size=n).round(1)
    lam = 0.2
    ti = TreeInstance(y, t.parent, lam=lam)
    ti.solve()
    assert abs(ti.x.mean() - ti.y.mean()) < eps
    if known_case:
        # Solver output is floating point; do not rely on exact equality.
        assert np.allclose(ti.x, [0.1, 1.5, 0.1, 1.05, 1.05]), ti.x
    assert ti.gamma.min() >= -eps, ti.gamma.min()
    assert ti.gamma.max() <= +eps, \
        f'seed = {seed}: gamma = {ti.gamma.max()}\n{ti}'
def test_gamma3():
    """Evaluate TreeInstance.gamma on a 3-node tree without solving it first."""
    inst = TreeInstance(
        y=np.array([0., 13, 0.75]),
        mu=np.array([1., 0.1, 0.001]),
        lam=np.array([1.0, 0.5, 0.3]),
        parent=np.array([1, 2, 2], dtype=np.int32),
        root=2,
    )
    gamma = inst.gamma
    assert (gamma >= 0).all(), f'gamma={gamma}'
    assert (gamma < 1e-10).all(), f'gamma={gamma}'
def test_nan():
    """`y` contains `NaN`: solve() must raise RuntimeError naming the entry.

    The same is checked afterwards for a NaN in `mu`.
    """
    t = TreeInstance(y=np.array([0., np.nan, np.nan]),
                     mu=np.array([1., 0., 0.]),
                     lam=np.array([1.0, 0.5, 0.3]),
                     parent=np.array([1, 2, 2], dtype=np.int32),
                     root=2)
    if __asan__:
        # Report the test as skipped instead of silently "passing" without
        # having checked anything.
        pytest.skip("NaN checks are skipped under ASAN")
    with pytest.raises(RuntimeError) as e:
        t.solve()
    assert 'y[1] = nan' in str(e.value)

    t.y = np.array([0., 0, 0])
    t.mu = np.array([1., 0., np.nan])
    with pytest.raises(RuntimeError) as e:
        t.solve()
    assert 'mu[2] = nan' in str(e.value)
def tree5(request):
    """Ten-node instance rooted at 0 with lam = 1.0.

    A finalizer removes `tree5.toml` / `tree5.h5` after use, unless the
    environment variable KEEP_TREE5 is set.
    """
    #         0  1  2  3  4  5  6  7  8  9
    parent = [0, 0, 1, 2, 3, 0, 7, 8, 3, 8]
    y = [8.2, 7.0, 9.5, 6.8, 5.8, 6.3, 4.3, 2.2, 1.2, 2.8]
    lam = 1.0
    tree = Tree(parent)
    assert tree.root == 0

    def cleaner():
        if os.getenv("KEEP_TREE5"):
            return
        for fname in ('tree5.toml', 'tree5.h5'):
            if os.path.exists(fname):
                os.remove(fname)

    request.addfinalizer(cleaner)
    return TreeInstance(y, tree.parent, lam=lam)
def test_zero4():
    """
    Tree with 4 nodes, actually a line graph where just the two end nodes
    contain information (mu == 1 only at nodes 0 and 3).
    """
    t = TreeInstance(
        y=np.array([0., 0., 0., 2.]),
        mu=np.array([1., 0., 0., 1.]),
        lam=np.array([1.0, 0.3, np.nan, 1.0]),
        parent=np.array([1, 2, 2, 2], dtype=np.int32),
        root=2,
    )
    t.solve()
    assert np.allclose(t.x, [0.3, 0.3, 1.7, 1.7])

    alpha = t.dual
    # Only the root may carry a NaN dual value.
    assert np.where(np.isnan(alpha))[0].tolist() == [t.root], \
        f'alpha={alpha}, root={t.root}'

    g = t.gamma
    assert (g > -1e-10).all(), f'gamma={g}'
    assert (g < +1e-10).all(), f'gamma={g}'

    v = t.dual_diff
    assert (v > -1e-10).all(), f'v={v}\nx={t.x}\nalpha={t.dual}\nlam={t.lam}'
def tree1():
    """Load the `tree0.1.toml` test instance from the test directory."""
    path = test_dir("tree0.1.toml")
    return TreeInstance.load(path)
def test_tree5_write_toml_read(tree5):
    """Round-trip the tree5 instance through `tree5.toml` and compare reprs.

    This was previously also named `test_tree5_write_h5_read`. The HDF5 test
    of the same name shadowed it, so pytest never collected this TOML test.
    """
    pytest.importorskip("toml")
    ti = tree5
    ti.save('tree5.toml')
    t2 = TreeInstance.load('tree5.toml')
    assert repr(ti) == repr(t2), f"\n\n{ti}\n\n{t2}\n"
def test_tree5_write_h5_read(tree5):
    """Round-trip the tree5 instance through `tree5.h5` and compare reprs."""
    fname = 'tree5.h5'
    tree5.save(fname)
    loaded = TreeInstance.load(fname)
    assert repr(tree5) == repr(loaded), f"\n\n{tree5}\n\n{loaded}\n"