def test_log_prob(num_edges):
    """Check that ``SpanningTree.log_prob`` is normalized over its support.

    Builds a ``SpanningTree`` on ``num_edges + 1`` vertices with one random
    logit per unordered vertex pair. It then enumerates the full support and
    asserts that the probabilities of all enumerated trees sum to one, i.e.
    their log-sum-exp is ~0.
    """
    # Seed depends on the parametrization so each size sees different logits.
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    tree_dist = SpanningTree(torch.randn(num_pairs))

    # Enumeration may be unsupported for some sizes; treat that as an xfail.
    with xfail_if_not_implemented():
        trees = tree_dist.enumerate_support()

    tree_log_probs = tree_dist.log_prob(trees)
    assert tree_log_probs.shape == (len(trees),)

    # A normalized distribution has total mass 1, so log(total) == 0.
    log_total = torch.logsumexp(tree_log_probs, dim=0).item()
    assert abs(log_total) < 1e-6, log_total
def test_edge_mean_function(num_edges):
    """Check ``SpanningTree.edge_mean`` against brute-force edge marginals.

    Enumerates every tree in the support and accumulates each tree's
    probability onto the flat pair index of every edge it contains. The
    resulting per-pair marginals are compared with the ``(V, V)`` matrix
    returned by ``edge_mean``.
    """
    # Seed depends on the parametrization so each size sees different logits.
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = num_edges + 1
    num_pairs = num_vertices * (num_vertices - 1) // 2
    tree_dist = SpanningTree(torch.randn(num_pairs))

    # Enumeration may be unsupported for some sizes; treat that as an xfail.
    with xfail_if_not_implemented():
        trees = tree_dist.enumerate_support()

    # Flat index of each (lo, hi) edge in the pair ordering. This triangular
    # formula assumes lo < hi within every edge -- presumably SpanningTree's
    # storage convention; TODO confirm.
    lo = trees[..., 0]
    hi = trees[..., 1]
    pair_index = lo + hi * (hi - 1) // 2

    # Broadcast each tree's probability across all of that tree's edges.
    tree_probs = tree_dist.log_prob(trees).exp().unsqueeze(1).expand_as(pair_index)
    expected = torch.zeros(num_pairs).scatter_add_(
        0, pair_index.flatten(), tree_probs.flatten()
    )

    marginals = tree_dist.edge_mean
    assert marginals.shape == (num_vertices, num_vertices)

    # Read the matrix entries in the same pair order as ``expected``;
    # presumably make_complete_graph yields that ordering -- verify.
    rows, cols = make_complete_graph(num_vertices)
    max_error = (marginals[rows, cols] - expected).abs().max()
    assert max_error < 1e-5, (marginals, expected)