コード例 #1
0
ファイル: test_spanning_tree.py プロジェクト: zeta1999/pyro
def test_log_prob(num_edges):
    """Check that ``SpanningTree.log_prob`` is normalized over its support.

    Random edge logits are drawn for a graph with ``num_edges + 1`` vertices.
    Every element of ``enumerate_support()`` is then scored. The resulting
    log-probabilities must log-sum-exp to zero, i.e. the probabilities sum
    to one.
    """
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = num_edges + 1
    # One logit per unordered vertex pair: V * (V - 1) / 2 entries.
    num_pairs = num_vertices * (num_vertices - 1) // 2
    dist = SpanningTree(torch.randn(num_pairs))

    # Only support enumeration is allowed to be unimplemented.
    with xfail_if_not_implemented():
        trees = dist.enumerate_support()

    tree_log_probs = dist.log_prob(trees)
    assert tree_log_probs.shape == (len(trees),)
    normalizer = tree_log_probs.logsumexp(0).item()
    assert abs(normalizer) < 1e-6, normalizer
コード例 #2
0
def test_edge_mean_function(num_edges):
    """Compare ``SpanningTree.edge_mean`` with a brute-force edge marginal.

    For each enumerated support element, its probability is added to every
    vertex pair it contains. This yields a length-``V * (V - 1) / 2`` vector
    of expected marginals. That vector must match ``edge_mean`` read at the
    ``make_complete_graph`` pair positions, to within ``1e-5``.
    """
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    dist = SpanningTree(torch.randn(num_pairs))

    # Only support enumeration is allowed to be unimplemented.
    with xfail_if_not_implemented():
        trees = dist.enumerate_support()

    # Map each (v1, v2) edge to a linear pair index. The formula only gives
    # distinct indices in range(num_pairs) when v1 < v2, which this test
    # presumably relies on for the stored edges.
    lo = trees[..., 0]
    hi = trees[..., 1]
    pair_index = lo + hi * (hi - 1) // 2
    # Broadcast each tree's probability over all of that tree's edges.
    tree_probs = dist.log_prob(trees).exp()[:, None].expand_as(pair_index)
    expected = torch.zeros(num_pairs)
    expected.scatter_add_(0, pair_index.reshape(-1), tree_probs.reshape(-1))

    actual = dist.edge_mean
    assert actual.shape == (num_vertices, num_vertices)
    rows, cols = make_complete_graph(num_vertices)
    assert (actual[rows, cols] - expected).abs().max() < 1e-5, (actual, expected)