def test_sample_tree_approx_smoke(num_edges, backend):
    """Smoke-test the approximate sampler ``sample_tree(edge_logits)``.

    Draws spanning trees with ``num_edges`` edges (``num_edges + 1`` vertices)
    from random edge logits and only checks that sampling does not raise.
    """
    pyro.set_rng_seed(num_edges)
    num_vertices = num_edges + 1
    # One logit per unordered vertex pair: V choose 2 entries.
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.rand(num_pairs)
    # Repeat the draw only for the "cpp" backend or small graphs
    # (presumably to keep the other backend's runtime down on large graphs).
    num_draws = 10 if (backend == "cpp" or num_edges <= 30) else 1
    for _ in range(num_draws):
        sample_tree(edge_logits, backend=backend)
def test_sample_tree_gof(method, backend, num_edges, pattern):
    """Goodness-of-fit test of ``sample_tree`` against ``SpanningTree`` probabilities.

    Draws many spanning trees and tallies how often each distinct tree occurs.
    Those counts are compared against ``SpanningTree(edge_logits).log_prob``
    with Pearson's chi-squared test (``goftests.multinomial_goodness_of_fit``).

    :param str method: ``"approx"`` resets the chain with an approximate sample
        before every MCMC step; any other value runs one continuous MCMC chain.
    :param str backend: sampler backend forwarded to ``sample_tree``.
    :param int num_edges: number of edges per tree; the graph has
        ``num_edges + 1`` vertices.
    :param str pattern: ``"uniform"``, ``"random"`` or ``"sparse"`` edge logits.
    :raises ValueError: if ``pattern`` is not one of the recognised values.
    """
    goftests = pytest.importorskip('goftests')
    pyro.set_rng_seed(2**32 - num_edges)
    E = num_edges
    V = 1 + E
    K = V * (V - 1) // 2  # one logit per unordered vertex pair
    if pattern == "uniform":
        edge_logits = torch.zeros(K)
        num_samples = 10 * NUM_SPANNING_TREES[V]
    elif pattern == "random":
        edge_logits = torch.rand(K)
        num_samples = 30 * NUM_SPANNING_TREES[V]
    elif pattern == "sparse":
        edge_logits = torch.rand(K)
        # Forbid every edge except those joining consecutive vertices (v, v + 1).
        for v2 in range(V):
            for v1 in range(v2):
                if v1 + 1 < v2:
                    edge_logits[v1 + v2 * (v2 - 1) // 2] = -float('inf')
        num_samples = 10 * NUM_SPANNING_TREES[V]
    else:
        # Without this, edge_logits would be unbound below.
        raise ValueError('Unknown pattern: {}'.format(pattern))

    # Generate many samples, keyed by each tree's edge list.
    counter = Counter()
    tree_by_key = {}
    # Initialize using approximate sampler, to ensure feasibility.
    edges = sample_tree(edge_logits, backend=backend)
    for _ in range(num_samples):
        if method == "approx":
            # Reset the chain with an approximate sample, then perform 1 step of mcmc.
            edges = sample_tree(edge_logits, backend=backend)
        edges = sample_tree(edge_logits, edges, backend=backend)
        key = tuple((v1.item(), v2.item()) for v1, v2 in edges)
        counter[key] += 1
        tree_by_key[key] = edges
    if pattern != "sparse":
        # With no forbidden edges, every spanning tree should have been observed
        # (NUM_SPANNING_TREES[V] presumably counts trees on V vertices).
        assert len(counter) == NUM_SPANNING_TREES[V]

    # Check accuracy using a Pearson's chi-squared test on the (up to) 100
    # most common trees; flag truncation so the test accounts for the rest.
    keys = [k for k, _ in counter.most_common(100)]
    truncated = (len(keys) < len(counter))
    observed = torch.tensor([counter[k] for k in keys])
    trees = torch.stack([tree_by_key[k] for k in keys])
    probs = SpanningTree(edge_logits).log_prob(trees).exp()
    gof = goftests.multinomial_goodness_of_fit(probs.numpy(), observed.numpy(), num_samples,
                                               plot=True, truncated=truncated)
    logging.info('gof = %s', gof)
    # The chain-resetting "approx" method is held to a looser threshold.
    if method == "approx":
        assert gof >= 0.0001
    else:
        assert gof >= 0.005
def test_sample_tree_mcmc_smoke(num_edges, backend):
    """Smoke-test the MCMC sampler ``sample_tree(edge_logits, edges)``.

    Starts from a path-shaped tree and runs a few MCMC steps; only checks
    that each step completes without raising.
    """
    pyro.set_rng_seed(num_edges)
    num_vertices = num_edges + 1
    # One logit per unordered vertex pair: V choose 2 entries.
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.rand(num_pairs)
    # Initial tree: the path 0-1-2-...-(V-1).
    path = [(v, v + 1) for v in range(num_vertices - 1)]
    edges = torch.tensor(path, dtype=torch.long)
    # Take extra steps only for the "cpp" backend or small graphs.
    if backend == "cpp" or num_edges <= 30:
        num_steps = 10
    else:
        num_steps = 1
    for _ in range(num_steps):
        edges = sample_tree(edge_logits, edges, backend=backend)