def assert_grads_agree(logits):
    """Check that BP-approximated log-partition gradients match the exact ones.

    Builds one exact and one belief-propagation distribution over the same
    logits and compares d(log Z)/d(logits) between the two.
    """
    exact = dist.OneTwoMatching(logits)
    approx = dist.OneTwoMatching(logits, bp_iters=BP_ITERS)
    expected = torch.autograd.grad(exact.log_partition_function, [logits])[0]
    actual = torch.autograd.grad(approx.log_partition_function, [logits])[0]
    # Loose tolerances: BP gives an approximation, not the exact gradient.
    agree = torch.allclose(actual, expected, atol=0.2, rtol=1e-3)
    assert agree, f"Expected:\n{expected.numpy()}\nActual:\n{actual.numpy()}"
def test_mode_full_smoke(num_destins, dtype):
    """Smoke test: mode() of a random dense problem lies in the support."""
    pytest.importorskip("lap")  # mode() relies on the optional `lap` solver.
    logits = 10 * torch.randn(2 * num_destins, num_destins, dtype=dtype)
    matching = dist.OneTwoMatching(logits)
    mode = matching.mode()
    assert matching.support.check(mode)
def test_log_prob_phylo(num_leaves, dtype, bp_iters):
    """Enumerated log_probs should sum to ~1 (log-total ~0) on phylo logits.

    Uses loose atol=1.0 because the BP-based log_prob is approximate.
    """
    logits, times = random_phylo_logits(num_leaves, dtype)
    d = dist.OneTwoMatching(logits, bp_iters=bp_iters)
    values = d.enumerate_support()
    log_total = d.log_prob(values).logsumexp(0).item()
    # Lazy %-style args: message is only formatted if INFO logging is enabled.
    logging.info("log_total = %0.3g, log_Z = %0.3g",
                 log_total, d.log_partition_function)
    assert_close(log_total, 0.0, atol=1.0)
def test_sample_shape_smoke(num_destins, sample_shape, dtype, bp_iters):
    """Smoke test: sample() honors sample_shape and stays inside the support."""
    num_sources = 2 * num_destins
    matching = dist.OneTwoMatching(
        torch.randn(num_sources, num_destins, dtype=dtype), bp_iters=bp_iters)
    # sample() may be unimplemented for some configurations; xfail in that case.
    with xfail_if_not_implemented():
        samples = matching.sample(sample_shape)
    assert samples.shape == sample_shape + (num_sources,)
    assert matching.support.check(samples).all()
def test_log_prob_phylo_smoke(num_leaves, dtype):
    """Smoke test: log_partition_function and its grad wrt times are finite."""
    logits, times = random_phylo_logits(num_leaves, dtype)
    matching = dist.OneTwoMatching(logits, bp_iters=10)
    log_z = matching.log_partition_function
    assert log_z.dtype == dtype
    assert not torch.isnan(log_z)
    # Gradients flow back through logits to the underlying coalescent times.
    grad_times = torch.autograd.grad(log_z, [times])[0]
    assert not torch.isnan(grad_times).any()
def test_log_prob_hard(dtype, bp_iters):
    """Total probability stays ~1 even with a -inf (forbidden) logit entry."""
    logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, -math.inf]]
    logits = torch.tensor(logits, dtype=dtype)
    d = dist.OneTwoMatching(logits, bp_iters=bp_iters)
    values = d.enumerate_support()
    log_total = d.log_prob(values).logsumexp(0).item()
    # Lazy %-style args: message is only formatted if INFO logging is enabled.
    logging.info("log_total = %0.3g, log_Z = %0.3g",
                 log_total, d.log_partition_function)
    assert_close(log_total, 0.0, atol=0.5)
def test_log_prob_full(num_destins, dtype, bp_iters):
    """Enumerated log_probs should sum to ~1 (log-total ~0) on random logits.

    Logits are scaled by 10 to make the distribution sharply peaked.
    """
    num_sources = 2 * num_destins
    logits = torch.randn(num_sources, num_destins, dtype=dtype) * 10
    d = dist.OneTwoMatching(logits, bp_iters=bp_iters)
    values = d.enumerate_support()
    log_total = d.log_prob(values).logsumexp(0).item()
    # Lazy %-style args: message is only formatted if INFO logging is enabled.
    logging.info("log_total = %0.3g, log_Z = %0.3g",
                 log_total, d.log_partition_function)
    assert_close(log_total, 0.0, atol=1.0)
def test_enumerate(num_destins, dtype):
    """enumerate_support() yields only valid and pairwise-distinct matchings."""
    num_sources = 2 * num_destins
    logits = torch.randn(num_sources, num_destins, dtype=dtype)
    d = dist.OneTwoMatching(logits)
    values = d.enumerate_support()
    # Fixed typo ("suport" -> "support") and switched to lazy %-style args.
    logging.info("destins = %s, support size = %s", num_destins, len(values))
    assert d.support.check(values), "invalid"
    # Hash each matching to detect duplicates in the enumeration.
    assert len(set(map(_hash, values))) == len(values), "not unique"
def test_mode_phylo(num_leaves, dtype):
    """mode() agrees with the argmax over the enumerated support (phylo)."""
    pytest.importorskip("lap")  # mode() relies on the optional `lap` solver.
    logits, times = random_phylo_logits(num_leaves, dtype)
    matching = dist.OneTwoMatching(logits)
    support = matching.enumerate_support()
    best = matching.log_prob(support).max(0).indices.item()
    assert_equal(matching.mode(), support[best])
def test_mode_full(num_destins, dtype):
    """mode() agrees with the argmax over the enumerated support (dense)."""
    pytest.importorskip("lap")  # mode() relies on the optional `lap` solver.
    logits = 10 * torch.randn(2 * num_destins, num_destins, dtype=dtype)
    matching = dist.OneTwoMatching(logits)
    support = matching.enumerate_support()
    best = matching.log_prob(support).max(0).indices.item()
    assert_equal(matching.mode(), support[best])
def test_sample_phylo(num_leaves, dtype, bp_iters):
    """Empirical per-source marginals from sample() match exact enumeration."""
    pytest.importorskip("lap")
    logits, times = random_phylo_logits(num_leaves, dtype)
    num_sources, num_destins = logits.shape
    matching = dist.OneTwoMatching(logits, bp_iters=bp_iters)

    # Monte Carlo estimate of P(source i -> destin j).
    num_samples = 1000
    rows = torch.arange(num_sources)
    actual = torch.zeros_like(logits)
    with xfail_if_not_implemented():
        for sample in matching.sample([num_samples]):
            actual[rows, sample] += 1 / num_samples

    # Exact marginals via full enumeration of the support.
    support = matching.enumerate_support()
    probs = matching.log_prob(support).exp()
    probs /= probs.sum()
    expected = torch.zeros(num_sources, num_destins)
    for value, prob in zip(support, probs):
        expected[rows, value] += prob

    assert_close(actual, expected, atol=0.1)
def test_mode_phylo_smoke(num_leaves, dtype):
    """Smoke test: mode() under BP on phylo logits lies in the support."""
    pytest.importorskip("lap")  # mode() relies on the optional `lap` solver.
    logits, times = random_phylo_logits(num_leaves, dtype)
    matching = dist.OneTwoMatching(logits, bp_iters=10)
    mode = matching.mode()
    assert matching.support.check(mode)
def fn(logits):
    """Map logits to the BP log partition estimate (closes over bp_iters)."""
    return dist.OneTwoMatching(logits, bp_iters=bp_iters).log_partition_function