示例#1
0
def test_sample_tree_approx_smoke(num_edges, backend):
    """Smoke-test ``sample_tree`` called without an initial tree.

    Builds one random logit per unordered vertex pair of a graph with
    ``num_edges + 1`` vertices and checks that drawing trees from it runs
    without raising. Only crashes are detected; samples are not inspected.

    :param int num_edges: number of edges in each sampled tree.
    :param backend: forwarded unchanged to ``sample_tree``.
    """
    pyro.set_rng_seed(num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.rand(num_pairs)
    # Large graphs get a single trial unless the "cpp" backend is used
    # (presumably the other backends are too slow there — TODO confirm).
    fast_enough = backend == "cpp" or num_edges <= 30
    num_trials = 10 if fast_enough else 1
    for _ in range(num_trials):
        sample_tree(edge_logits, backend=backend)
示例#2
0
def test_sample_tree_gof(method, backend, num_edges, pattern):
    """Goodness-of-fit test of ``sample_tree`` against ``SpanningTree``.

    Draws many spanning trees on a graph with ``num_edges + 1`` vertices (one
    logit per unordered vertex pair) and tallies how often each distinct tree
    occurs. It then compares those counts with the probabilities from
    ``SpanningTree(edge_logits).log_prob`` via
    ``goftests.multinomial_goodness_of_fit``. The test is skipped when
    ``goftests`` is not installed.

    :param str method: ``"approx"`` resets the chain with an approximate
        sample before every MCMC step. Any other value runs one continuous
        MCMC chain.
    :param backend: forwarded unchanged to ``sample_tree``.
    :param int num_edges: number of edges in each tree.
    :param str pattern: how edge logits are generated. ``"uniform"`` uses all
        zeros. ``"random"`` uses ``torch.rand`` values. ``"sparse"`` uses
        random values with every pair other than ``(v, v + 1)`` masked to
        ``-inf``.
    :raises ValueError: if ``pattern`` is not one of the values above.
    """
    goftests = pytest.importorskip('goftests')
    pyro.set_rng_seed(2**32 - num_edges)
    E = num_edges
    V = 1 + E
    K = V * (V - 1) // 2

    if pattern == "uniform":
        edge_logits = torch.zeros(K)
        num_samples = 10 * NUM_SPANNING_TREES[V]
    elif pattern == "random":
        edge_logits = torch.rand(K)
        num_samples = 30 * NUM_SPANNING_TREES[V]
    elif pattern == "sparse":
        edge_logits = torch.rand(K)
        # Pair (v1, v2) with v1 < v2 lives at flat index v1 + v2 * (v2 - 1) // 2.
        # Mask every non-adjacent pair so only the path edges (v, v + 1) remain.
        for v2 in range(V):
            for v1 in range(v2):
                if v1 + 1 < v2:
                    edge_logits[v1 + v2 * (v2 - 1) // 2] = -float('inf')
        num_samples = 10 * NUM_SPANNING_TREES[V]
    else:
        # Without this branch edge_logits/num_samples would be unbound below.
        raise ValueError('Unknown pattern: {}'.format(pattern))

    # Generate many samples.
    counts = Counter()
    tensors = {}
    # Initialize using approximate sampler, to ensure feasibility.
    edges = sample_tree(edge_logits, backend=backend)
    for _ in range(num_samples):
        if method == "approx":
            # Reset the chain with an approximate sample, then perform 1 step of mcmc.
            edges = sample_tree(edge_logits, backend=backend)
        edges = sample_tree(edge_logits, edges, backend=backend)
        # Hashable key identifying the tree by its (v1, v2) edge list.
        key = tuple((v1.item(), v2.item()) for v1, v2 in edges)
        counts[key] += 1
        tensors[key] = edges
    if pattern != "sparse":
        # Presumably NUM_SPANNING_TREES[V] is the number of distinct trees on
        # V vertices, i.e. every tree must have been visited — TODO confirm.
        assert len(counts) == NUM_SPANNING_TREES[V]

    # Check accuracy using a Pearson's chi-squared test over (at most) the
    # 100 most frequently observed trees.
    keys = [k for k, _ in counts.most_common(100)]
    truncated = len(keys) < len(counts)
    observed = torch.tensor([counts[k] for k in keys])
    samples = torch.stack([tensors[k] for k in keys])
    probs = SpanningTree(edge_logits).log_prob(samples).exp()
    gof = goftests.multinomial_goodness_of_fit(probs.numpy(),
                                               observed.numpy(),
                                               num_samples,
                                               plot=True,
                                               truncated=truncated)
    logging.info('gof = %s', gof)
    # The "approx" method is held to a looser threshold than pure MCMC.
    if method == "approx":
        assert gof >= 0.0001
    else:
        assert gof >= 0.005
示例#3
0
def test_sample_tree_mcmc_smoke(num_edges, backend):
    """Smoke-test ``sample_tree`` when it is given a starting tree (MCMC mode).

    Starts from the path 0 - 1 - ... - num_edges and repeatedly feeds each
    result back in as the next starting tree, checking only that nothing
    raises.

    :param int num_edges: number of edges in each tree.
    :param backend: forwarded unchanged to ``sample_tree``.
    """
    pyro.set_rng_seed(num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.rand(num_pairs)
    # Initial tree: the path graph, one (v, v + 1) row per edge.
    edges = torch.tensor([[v, v + 1] for v in range(num_edges)],
                         dtype=torch.long)
    num_steps = 1
    if backend == "cpp" or num_edges <= 30:
        num_steps = 10
    for _ in range(num_steps):
        edges = sample_tree(edge_logits, edges, backend=backend)