コード例 #1
0
ファイル: test_spanning_tree.py プロジェクト: zeta1999/pyro
def test_enumerate_support(num_edges):
    """Enumerated support of ``SpanningTree`` has the expected shape and size.

    Builds a ``SpanningTree`` over ``num_edges + 1`` vertices with random
    edge logits (one per vertex pair). It then checks three things about
    the enumerated support:

    - it is a rank-3 tensor;
    - its trailing dims equal ``event_shape``;
    - its leading dim matches ``NUM_SPANNING_TREES`` for that vertex count.
    """
    # Seed depends on the size so each parametrization is deterministic.
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = num_edges + 1
    # One logit per unordered vertex pair.
    num_pairs = num_vertices * (num_vertices - 1) // 2
    tree_dist = SpanningTree(torch.randn(num_pairs))

    with xfail_if_not_implemented():
        support = tree_dist.enumerate_support()

    assert support.dim() == 3
    assert support.shape[1:] == tree_dist.event_shape
    assert len(support) == NUM_SPANNING_TREES[num_vertices]
コード例 #2
0
ファイル: test_spanning_tree.py プロジェクト: zeta1999/pyro
def test_sample_tree_gof(method, backend, num_edges, pattern):
    """Goodness-of-fit test of ``sample_tree`` against ``SpanningTree.log_prob``.

    Draws many trees with ``sample_tree`` over ``num_edges + 1`` vertices,
    counts how often each distinct tree occurs, and runs a Pearson
    chi-squared multinomial goodness-of-fit test (``goftests``). The test
    compares those counts with the probabilities from
    ``SpanningTree(edge_logits).log_prob``.

    Args:
        method: ``"approx"`` resets the chain with a fresh
            ``sample_tree(edge_logits)`` draw before every step. Any other
            value keeps stepping a single chain.
        backend: forwarded to every ``sample_tree`` call.
        num_edges: number of edges per tree; there are ``num_edges + 1``
            vertices.
        pattern: how edge logits are chosen:

            - ``"uniform"``: all zeros.
            - ``"random"``: ``torch.rand``.
            - ``"sparse"``: ``torch.rand`` with every pair except
              consecutive vertices ``(v, v + 1)`` set to ``-inf``.

    Raises:
        ValueError: if ``pattern`` is not one of the values above.
    """
    goftests = pytest.importorskip('goftests')
    pyro.set_rng_seed(2**32 - num_edges)
    V = 1 + num_edges
    K = V * (V - 1) // 2

    if pattern == "uniform":
        edge_logits = torch.zeros(K)
        num_samples = 10 * NUM_SPANNING_TREES[V]
    elif pattern == "random":
        edge_logits = torch.rand(K)
        num_samples = 30 * NUM_SPANNING_TREES[V]
    elif pattern == "sparse":
        edge_logits = torch.rand(K)
        # Keep only pairs with v1 + 1 == v2 (consecutive vertices); the pair
        # (v1, v2), v1 < v2, lives at flat index v1 + v2 * (v2 - 1) // 2.
        for v2 in range(V):
            for v1 in range(v2):
                if v1 + 1 < v2:
                    edge_logits[v1 + v2 * (v2 - 1) // 2] = -float('inf')
        num_samples = 10 * NUM_SPANNING_TREES[V]
    else:
        # Fail loudly here rather than with an UnboundLocalError further down.
        raise ValueError('unknown pattern: {!r}'.format(pattern))

    # Generate many samples, keyed by each tree's edge list.
    counts = Counter()
    tensors = {}
    # Initialize using approximate sampler, to ensure feasibility.
    edges = sample_tree(edge_logits, backend=backend)
    for _ in range(num_samples):
        if method == "approx":
            # Reset the chain with an approximate sample, then perform 1 step of mcmc.
            edges = sample_tree(edge_logits, backend=backend)
        edges = sample_tree(edge_logits, edges, backend=backend)
        key = tuple((v1.item(), v2.item()) for v1, v2 in edges)
        counts[key] += 1
        tensors[key] = edges
    if pattern != "sparse":
        # Every spanning tree is expected to show up at least once.
        assert len(counts) == NUM_SPANNING_TREES[V]

    # Check accuracy using a Pearson's chi-squared test on the most frequent
    # trees; goftests is told when the rare tail has been dropped.
    max_categories = 100
    keys = [k for k, _ in counts.most_common(max_categories)]
    truncated = len(keys) < len(counts)
    observed = torch.tensor([counts[k] for k in keys])
    trees = torch.stack([tensors[k] for k in keys])
    probs = SpanningTree(edge_logits).log_prob(trees).exp()
    gof = goftests.multinomial_goodness_of_fit(probs.numpy(),
                                               observed.numpy(),
                                               num_samples,
                                               plot=True,
                                               truncated=truncated)
    logging.info('gof = %s', gof)
    # The reset-every-step ("approx") variant gets a looser threshold.
    if method == "approx":
        assert gof >= 0.0001
    else:
        assert gof >= 0.005
コード例 #3
0
ファイル: test_spanning_tree.py プロジェクト: zeta1999/pyro
def test_log_prob(num_edges):
    """``log_prob`` over the full enumerated support is normalized.

    Evaluates ``log_prob`` on every enumerated tree. It checks that there is
    one value per tree and that the log of the total probability mass is
    zero to within ``1e-6``.
    """
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    tree_dist = SpanningTree(torch.randn(num_pairs))
    with xfail_if_not_implemented():
        support = tree_dist.enumerate_support()

    log_probs = tree_dist.log_prob(support)
    # Exactly one log-probability per enumerated tree.
    assert log_probs.shape == (len(support), )
    # log(sum(p)) is 0 when the probabilities sum to 1.
    log_total = torch.logsumexp(log_probs, dim=0).item()
    assert abs(log_total) < 1e-6, log_total
コード例 #4
0
ファイル: test_spanning_tree.py プロジェクト: zeta1999/pyro
def test_partition_function(num_edges):
    """``log_partition_function`` agrees with brute-force enumeration.

    The reference value is computed in two steps:

    - sum the edge logits of each enumerated tree;
    - take the log-sum-exp of those sums over all trees.
    """
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.randn(num_pairs)
    tree_dist = SpanningTree(edge_logits)
    with xfail_if_not_implemented():
        support = tree_dist.enumerate_support()

    # Map each (u, v) edge of each tree to its flat index into edge_logits.
    u, v = support[..., 0], support[..., 1]
    flat_index = u + v * (v - 1) // 2
    tree_scores = edge_logits[flat_index].sum(-1)
    expected = tree_scores.logsumexp(0)

    actual = tree_dist.log_partition_function
    assert (actual - expected).abs() < 1e-6, (actual, expected)
コード例 #5
0
def test_mode(num_edges, backend):
    """``mode`` is the enumerated tree with the largest total edge logit.

    ``backend`` is passed to ``SpanningTree`` through ``sampler_options``.
    """
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.randn(num_pairs)
    tree_dist = SpanningTree(edge_logits, sampler_options={"backend": backend})
    with xfail_if_not_implemented():
        support = tree_dist.enumerate_support()

    # Score every enumerated tree by summing the logits of its edges.
    u, v = support[..., 0], support[..., 1]
    tree_scores = edge_logits[u + v * (v - 1) // 2].sum(-1)
    expected = support[tree_scores.argmax(0)]

    actual = tree_dist.mode
    assert torch.all(actual == expected)
コード例 #6
0
def test_edge_mean_function(num_edges):
    """``edge_mean`` matches brute-force per-edge marginal probabilities.

    For each candidate edge, the expected value is the total probability of
    the enumerated trees that contain it. ``edge_mean`` must be a (V, V)
    matrix, and it is compared on the entries selected by
    ``make_complete_graph(V)``.
    """
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    tree_dist = SpanningTree(torch.randn(num_pairs))

    with xfail_if_not_implemented():
        support = tree_dist.enumerate_support()
    u, v = support[..., 0], support[..., 1]
    flat_index = u + v * (v - 1) // 2
    # Repeat each tree's probability once per edge of that tree, then
    # accumulate into the flat per-edge totals.
    tree_probs = tree_dist.log_prob(support).exp().unsqueeze(1).expand_as(flat_index)
    expected = torch.zeros(num_pairs).scatter_add_(
        0, flat_index.flatten(), tree_probs.flatten())

    actual = tree_dist.edge_mean
    assert actual.shape == (num_vertices, num_vertices)
    rows, cols = make_complete_graph(num_vertices)
    assert (actual[rows, cols] - expected).abs().max() < 1e-5, (actual, expected)