Пример #1
0
 def test_ess_sampling(self, dtype, device, lap, M):
     """Smoke-test ``ess`` on the Laplacian of a graph from ``rand_udg``.

     Builds a graph with ``rand_udg(N, ...)``, forms its Laplacian of kind
     ``lap`` and calls ``ess(L, M, k)`` with ``k = 2``, printing the result.
     ``ArpackNoConvergence`` raised by ``ess`` is caught and reported with a
     message instead of failing the test.

     Args:
         dtype: dtype forwarded to ``rand_udg``.
         device: device forwarded to ``rand_udg``.
         lap: Laplacian type forwarded to ``laplace(..., lap_type=lap)``.
         M: forwarded to ``ess`` (presumably the number of samples -- confirm).
     """
     k = 2  # NOTE(review): meaning of ess's third argument is not visible here; confirm
     g = rand_udg(N, dtype=dtype, device=device)
     L = laplace(g, lap_type=lap)
     # Guard only the call that can fail to converge, so errors from graph or
     # Laplacian construction still surface as real test failures.
     try:
         S = ess(L, M, k)
     except ArpackNoConvergence:
         print("No convergence error")
     else:
         print(S)
Пример #2
0
def test_computing_sets(dtype):
    """Print the ``computing_sets`` output for a graph built by ``rand_udg``.

    ``computing_sets`` is run with ``T=8e-5`` and ``p_hops=3`` twice: first on
    the graph as built, then on the result of ``g.set_value_(None)``.
    """

    def _report(graph):
        # Compute with the fixed threshold/hop settings and dump both outputs.
        found_sets, set_lengths = computing_sets(graph, T=8e-5, p_hops=3)
        print(found_sets)
        print(set_lengths)

    graph = rand_udg(N, 0.1, dtype=dtype)
    _report(graph)
    # set_value_ is called only after the first report has been produced.
    _report(graph.set_value_(None))
Пример #3
0
    def test_recon(self, dtype, device, lap, M):
        """Reconstruct random signals from the nodes picked by ``ess``.

        Samples nodes with ``ess`` on the ``lap``-type Laplacian, draws
        ``num_sig`` random signals, keeps only the sampled rows, reconstructs
        them with ``recon_ess`` using ``g.U(lap)`` and ``bd=int(3 * N / 4)``,
        and passes the estimate and the originals to ``snr_and_mse``.
        """
        graph = rand_udg(N, dtype=dtype, device=device)
        sample_idx = ess(laplace(graph, lap_type=lap), M)

        num_sig = 10
        signals = th.rand(N, num_sig, dtype=dtype, device=device)
        observed = signals[sample_idx, :]
        bandwidth = int(3 * N / 4)
        estimate = recon_ess(observed, sample_idx, graph.U(lap), bd=bandwidth)
        snr_and_mse(estimate, signals)
Пример #4
0
def test_recon_bsgda():
    """Sample nodes with ``bsgda`` and reconstruct random signals from them.

    Uses ``mu = 0.01`` for both ``bsgda`` and ``recon_bsgda`` and prints the
    sampled nodes and threshold. It then reconstructs 100 random double-precision
    signals from their sampled rows using ``g.L("comb")``. Finally it computes
    ``snr_and_mse`` on the estimate.
    """
    mu = 0.01
    graph = rand_udg(N, 0.1, dtype=torch.double)
    picked, threshold = bsgda(graph, K * 8, mu=mu, epsilon=1e-8)

    for label, value in (("sampled nodes: ", picked), ("thresh(T):     ", threshold)):
        print(label, value)

    signals = torch.randn(N, 100, dtype=torch.double)
    estimate = recon_bsgda(signals[picked], picked, graph.L("comb"), mu)
    # Two-value unpack kept so a non-pair return from snr_and_mse still fails.
    _, _ = snr_and_mse(estimate, signals)
Пример #5
0
def test_consistency_greedy_sampling(dtype):
    """Check that three sampling routes return the same selection and value.

    On one graph from ``rand_udg`` (after ``set_value_(None)``), the test compares:

    * ``greedy_sampling``,
    * ``computing_sets`` followed by ``solving_set_covering``,
    * ``greedy_gda_sampling``.

    It prints all results, then asserts that the set-covering route agrees
    with the other two on both returned values.
    """
    graph = rand_udg(N, 0.3, dtype=dtype).set_value_(None)

    greedy_sel, greedy_vf = greedy_sampling(graph, K, T, p_hops=max_hops)

    candidate_sets, set_lengths = computing_sets(graph, T, p_hops=max_hops)
    cover_sel, cover_vf = solving_set_covering(candidate_sets, set_lengths, K)

    gda_sel, gda_vf = greedy_gda_sampling(graph, K, T, p_hops=max_hops)

    for item in (greedy_sel, cover_sel, gda_sel, greedy_vf, cover_vf, gda_vf):
        print(item)

    # NOTE(review): if the second return values are floats or tensors, exact
    # equality may be fragile -- confirm their types.
    assert cover_sel == greedy_sel
    assert cover_vf == greedy_vf

    assert cover_sel == gda_sel
    assert cover_vf == gda_vf
Пример #6
0
def test_bsgda(dtype):
    """Run ``bsgda`` on a ``rand_udg`` graph after ``fill_value_(1.0)``.

    The second argument to ``bsgda`` is ``K * 2``. The test prints the sampled
    nodes and the returned threshold.
    """
    graph = rand_udg(N, 0.1, dtype=dtype)
    graph = graph.fill_value_(1.0)
    num_requested = K * 2  # presumably the sample budget -- confirm against bsgda
    picked, threshold = bsgda(graph, num_requested)
    print("sampled nodes: ", picked)
    print("thresh(T):     ", threshold)