def test_zero3c():
    """Tree with 3 nodes and zero input.

    The tree is the path 0 -> 1 -> 2, rooted at node 2. Node 1 carries no
    data weight (``mu[1] == 0``), so only the two end nodes hold information.

    Expected primal solution ``x = [0.5, 0.5, 2.5]``. This matches the
    objective ``0.5 * sum(mu_i * (x_i - y_i)**2) + sum(lam_i * |x_i - x_parent(i)|)``
    (NOTE(review): assumed objective; confirm against TreeInstance). The two
    data nodes are effectively coupled through the cheaper edge weight
    ``min(lam[0], lam[1]) = 0.5``, so each moves 0.5 towards the other. The
    data-free node 1 sits with node 0, which zeroes the heavier edge.

    Afterwards the dual quantities exposed by the instance are checked:
    only the root has a NaN dual, ``gamma`` vanishes, and ``dual_diff`` is
    non-negative, all up to ``tol``.
    """
    tol = 1e-10  # numerical slack for the dual-feasibility checks

    # lam[2] belongs to the root; presumably unused since the root has no parent edge.
    t = TreeInstance(
        y=np.array([0., 0., 3.]),
        mu=np.array([1., 0., 1.]),
        lam=np.array([1.0, 0.5, 0.3]),
        parent=np.array([1, 2, 2], dtype=np.int32),
        root=2,
    )
    t.solve()
    assert np.allclose(t.x, [0.5, 0.5, 2.5])

    # Exactly the root (and no other node) is expected to have a NaN dual.
    alpha = t.dual
    assert np.where(np.isnan(alpha))[0].tolist() == [t.root], \
        f'alpha={alpha}, root={t.root}'

    # gamma must be numerically zero everywhere (checked from both sides).
    g = t.gamma
    assert (g > -tol).all(), f'gamma={g}'
    assert (g < +tol).all(), f'gamma={g}'

    # dual_diff must be non-negative up to tol.
    # For interactive inspection while debugging: t.show(wait=False)
    v = t.dual_diff
    assert (v > -tol).all(), f'v={v}\nx={t.x}\nalpha={t.dual}\nlam={t.lam}'
def test_zero4():
    """Tree with 4 nodes that is actually a line graph.

    The graph is the path 0 - 1 - 2 - 3, rooted at node 2 (nodes 1 and 3
    both point to the root). Only the two end nodes carry data
    (``mu = [1, 0, 0, 1]``).

    Expected primal solution ``x = [0.3, 0.3, 1.7, 1.7]``. This matches the
    objective ``0.5 * sum(mu_i * (x_i - y_i)**2) + sum(lam_i * |x_i - x_parent(i)|)``
    (NOTE(review): assumed objective; confirm against TreeInstance). The end
    nodes are effectively coupled through the cheapest edge on the path,
    ``lam[1] = 0.3``, so each moves 0.3 towards the other. The data-free
    inner nodes snap to the side whose heavier edge they zero out.

    Afterwards the dual quantities exposed by the instance are checked:
    only the root has a NaN dual, ``gamma`` vanishes, and ``dual_diff`` is
    non-negative, all up to ``tol``.
    """
    tol = 1e-10  # numerical slack for the dual-feasibility checks

    # lam[2] is the root's entry; NaN is passed, presumably because it is unused.
    t = TreeInstance(
        y=np.array([0., 0., 0., 2.]),
        mu=np.array([1., 0., 0., 1.]),
        lam=np.array([1.0, 0.3, np.nan, 1.0]),
        parent=np.array([1, 2, 2, 2], dtype=np.int32),
        root=2,
    )
    t.solve()
    assert np.allclose(t.x, [0.3, 0.3, 1.7, 1.7])
    # For interactive inspection while debugging: t.show(wait=False)

    # Exactly the root (and no other node) is expected to have a NaN dual.
    alpha = t.dual
    assert np.where(np.isnan(alpha))[0].tolist() == [t.root], \
        f'alpha={alpha}, root={t.root}'

    # gamma must be numerically zero everywhere (checked from both sides).
    g = t.gamma
    assert (g > -tol).all(), f'gamma={g}'
    assert (g < +tol).all(), f'gamma={g}'

    # dual_diff must be non-negative up to tol.
    v = t.dual_diff
    assert (v > -tol).all(), f'v={v}\nx={t.x}\nalpha={t.dual}\nlam={t.lam}'