コード例 #1
0
def test_mini_h5_w():
    """Check tree_dp called with per-node ``lam``/``mu`` arrays against the reference ``xt``."""
    y, parent, xt, lam, root = load_mini_h5()
    size = len(y)
    # Broadcast the scalar lam and a unit mu to one value per entry of y.
    lam_per_node = np.ones(size) * lam
    mu_per_node = np.ones(size)
    solution = tl.tree_dp(
        y=y, parent=parent, lam=lam_per_node, mu=mu_per_node, root=root
    )
    assert np.array_equal(solution.reshape(xt.shape), xt)
コード例 #2
0
def test_mini_h5():
    """Check positional tree_dp against ``xt`` and that no extra references survive."""
    data, par, expected, lam, root = load_mini_h5()
    out = tl.tree_dp(data, par, lam, root)
    # Compare via a temporary reshape so no extra reference to `out` is kept alive.
    assert np.array_equal(out.reshape(expected.shape), expected)
    # sys.getrefcount includes the reference held by its own argument, so a
    # count of at most 2 means only this local name still refers to the array.
    assert sys.getrefcount(par) <= 2
    assert sys.getrefcount(data) <= 2
    assert sys.getrefcount(out) <= 2
コード例 #3
0
def _load_grid_opt(fname: str, lam: float):
    """Return the ``xgrid`` array from *fname*'s companion grid_opt file, or None.

    The companion path is *fname* with its last ``.img`` replaced by
    ``_lam{lam}.grid_opt``.  If *fname* contains no ``.img`` there is no
    companion, so the input file is never mistaken for its own grid solution.
    """
    head, sep, tail = fname.rpartition(".img")
    if not sep:
        return None
    grid_opt_path = f"{head}_lam{lam}.grid_opt{tail}"
    if not path.exists(grid_opt_path):
        return None
    with h5py.File(grid_opt_path, "r") as io:
        print(grid_opt_path)
        return io["xgrid"][()]


def average_tree(fname: str, lam: float, rep: int):
    """Average ``treelas.tree_dp`` solutions over ``rep`` random spanning trees.

    Reads the array ``y`` from the HDF5 file *fname*, builds
    ``BiAdjacent(GridGraph(*y.shape))``, and for each seed ``0 .. rep-1``
    solves ``treelas.tree_dp`` on ``random_spanning_tree(graph, seed=seed)``.

    Returns:
        xsol: list of per-seed solutions, each reshaped to ``y.shape``.
        xopt: element-wise mean of ``xsol`` (``np.mean`` over axis 0).
        grid_opt: ``xgrid`` from the companion ``*_lam{lam}.grid_opt`` file
            when it exists, otherwise None.
    """
    with h5py.File(fname, "r") as io:
        y = io["y"][()]

    grid_opt = _load_grid_opt(fname, lam)
    graph = BiAdjacent(GridGraph(*y.shape))
    # Flatten column-major; each solution is reshaped back with the same order.
    yvec = y.reshape(-1, order="F")
    xsol = []
    for seed in range(rep):
        print("seed =", seed)
        pi = random_spanning_tree(graph, seed=seed)
        # pi is used as a parent array: one in-range index per entry of yvec.
        assert len(pi) == len(yvec)
        assert pi.max() < len(pi)
        assert pi.min() >= 0
        # NOTE(review): lam is scaled by 8 here; the reason is not visible in
        # this file — confirm against the grid solver's lam convention.
        x = treelas.tree_dp(y=yvec, parent=pi, lam=8 * lam, verbose=False)
        xsol.append(x.reshape(*y.shape, order="F"))

    xopt = np.mean(xsol, axis=0)
    return xsol, xopt, grid_opt
コード例 #4
0
ファイル: test_randomspan.py プロジェクト: EQt/treelas
def test_grid_2x2():
    """Smoke-test tree_dp on the seed-0 spanning tree of a 2x2 grid graph."""
    tree = random_spanning_tree(BiAdjacent(GridGraph(2, 2)), seed=0)
    if os.getenv("show"):
        Tree(tree).show()
    np.random.seed(0)
    signal = np.random.randn(len(tree))
    # Only checks that tree_dp runs; the result is not inspected.
    tree_dp(y=signal, parent=tree, lam=0.2)
コード例 #5
0
ファイル: test_randomspan.py プロジェクト: EQt/treelas
def test_grid_3x3():
    """Pin the seed-0 spanning tree of a 3x3 grid graph, then run tree_dp on it."""
    tree = random_spanning_tree(BiAdjacent(GridGraph(3, 3)), seed=0)
    if os.getenv("show"):
        Tree(tree).show()
    expected_parents = [3, 0, 1, 6, 7, 8, 6, 6, 7]
    assert all(tree == expected_parents)
    np.random.seed(0)
    signal = np.random.randn(len(tree))
    # Only checks that tree_dp runs; the result is not inspected.
    tree_dp(y=signal, parent=tree, lam=0.2)
コード例 #6
0
ファイル: test_randomspan.py プロジェクト: EQt/treelas
def test_grid(width):
    """Smoke-test tree_dp on the seed-0 spanning tree of a ``width`` x ``width`` grid."""
    tree = random_spanning_tree(BiAdjacent(GridGraph(width, width)), seed=0)
    np.random.seed(0)
    signal = np.random.randn(len(tree))
    # Only checks that tree_dp runs; the result is not inspected.
    tree_dp(y=signal, parent=tree, lam=0.2)