Esempio n. 1
0
def assert_grads_agree(logits):
    """Check that two ``OneTwoMatching`` variants agree on the log-Z gradient.

    One distribution is built from ``logits`` with default arguments and the
    other with ``bp_iters=BP_ITERS``. The gradients of their
    ``log_partition_function`` with respect to ``logits`` must match within
    ``atol=0.2`` / ``rtol=1e-3``.

    Args:
        logits: tensor passed to both distributions; it must participate in
            autograd because it is differentiated against.
    """
    reference = dist.OneTwoMatching(logits)
    candidate = dist.OneTwoMatching(logits, bp_iters=BP_ITERS)
    (expected,) = torch.autograd.grad(reference.log_partition_function, [logits])
    (actual,) = torch.autograd.grad(candidate.log_partition_function, [logits])
    grads_match = torch.allclose(actual, expected, atol=0.2, rtol=1e-3)
    # The message is only rendered on failure, so .numpy() runs lazily.
    assert grads_match, \
        f"Expected:\n{expected.numpy()}\nActual:\n{actual.numpy()}"
Esempio n. 2
0
def test_mode_full_smoke(num_destins, dtype):
    """Smoke test: ``mode()`` on random dense logits lies in the support.

    Skipped when the optional ``lap`` module cannot be imported.
    """
    pytest.importorskip("lap")
    shape = (2 * num_destins, num_destins)
    # Scale the noise by 10, as the original test does.
    logits = torch.randn(*shape, dtype=dtype) * 10
    matching = dist.OneTwoMatching(logits)
    mode = matching.mode()
    assert matching.support.check(mode)
Esempio n. 3
0
def test_log_prob_phylo(num_leaves, dtype, bp_iters):
    """Check that probabilities over the enumerated support roughly normalize.

    Builds logits with ``random_phylo_logits``, enumerates the support, and
    asserts that the log of the total probability mass is within 1.0 of zero.
    """
    logits, _times = random_phylo_logits(num_leaves, dtype)
    d = dist.OneTwoMatching(logits, bp_iters=bp_iters)
    support = d.enumerate_support()
    log_probs = d.log_prob(support)
    log_total = log_probs.logsumexp(0).item()
    logging.info(
        f"log_total = {log_total:0.3g}, "
        f"log_Z = {d.log_partition_function:0.3g}"
    )
    assert_close(log_total, 0.0, atol=1.0)
Esempio n. 4
0
def test_sample_shape_smoke(num_destins, sample_shape, dtype, bp_iters):
    """Smoke test: samples have the requested batch shape and are in support.

    The sampling call is wrapped in ``xfail_if_not_implemented`` so that an
    unimplemented sampler does not count as a hard failure.
    """
    num_sources = 2 * num_destins
    d = dist.OneTwoMatching(
        torch.randn(num_sources, num_destins, dtype=dtype),
        bp_iters=bp_iters,
    )
    with xfail_if_not_implemented():
        draws = d.sample(sample_shape)
    # Each draw has one entry per source, appended to the sample shape.
    expected_shape = sample_shape + (num_sources, )
    assert draws.shape == expected_shape
    assert d.support.check(draws).all()
Esempio n. 5
0
def test_log_prob_phylo_smoke(num_leaves, dtype):
    """Smoke test: log-Z for phylo logits is finite-typed and differentiable.

    Asserts that ``log_partition_function`` has the requested dtype and is not
    NaN, and that its gradient with respect to ``times`` contains no NaN.
    """
    logits, times = random_phylo_logits(num_leaves, dtype)
    log_z = dist.OneTwoMatching(logits, bp_iters=10).log_partition_function
    assert log_z.dtype == dtype
    assert not torch.isnan(log_z)
    (grad_times,) = torch.autograd.grad(log_z, [times])
    assert not torch.isnan(grad_times).any()
Esempio n. 6
0
def test_log_prob_hard(dtype, bp_iters):
    """Check normalization on a 4x2 problem with one forbidden assignment.

    All logits are zero except the last source's second destination, which is
    ``-inf``. The log of the total mass over the enumerated support must be
    within 0.5 of zero.
    """
    rows = [[0.0, 0.0]] * 3 + [[0.0, -math.inf]]
    d = dist.OneTwoMatching(torch.tensor(rows, dtype=dtype), bp_iters=bp_iters)
    support = d.enumerate_support()
    log_probs = d.log_prob(support)
    log_total = log_probs.logsumexp(0).item()
    logging.info(
        f"log_total = {log_total:0.3g}, "
        f"log_Z = {d.log_partition_function:0.3g}"
    )
    assert_close(log_total, 0.0, atol=0.5)
Esempio n. 7
0
def test_log_prob_full(num_destins, dtype, bp_iters):
    """Check that probabilities over the support roughly normalize.

    Uses random dense logits (standard normal scaled by 10) with twice as many
    sources as destinations, and asserts that the log of the total mass over
    the enumerated support is within 1.0 of zero.
    """
    shape = (2 * num_destins, num_destins)
    logits = torch.randn(*shape, dtype=dtype) * 10
    d = dist.OneTwoMatching(logits, bp_iters=bp_iters)
    support = d.enumerate_support()
    log_probs = d.log_prob(support)
    log_total = log_probs.logsumexp(0).item()
    logging.info(
        f"log_total = {log_total:0.3g}, "
        f"log_Z = {d.log_partition_function:0.3g}"
    )
    assert_close(log_total, 0.0, atol=1.0)
Esempio n. 8
0
def test_enumerate(num_destins, dtype):
    """Check that ``enumerate_support`` yields valid, distinct values.

    Builds a ``OneTwoMatching`` over random logits with twice as many sources
    as destinations, enumerates its support, and asserts that every value
    passes the support constraint and that no value is repeated (compared via
    ``_hash``).
    """
    num_sources = 2 * num_destins
    logits = torch.randn(num_sources, num_destins, dtype=dtype)
    d = dist.OneTwoMatching(logits)
    values = d.enumerate_support()
    logging.info("destins = {}, support size = {}".format(
        num_destins, len(values)))
    # `values` is a batch, so reduce the check with .all() (as the sampling
    # test does). This stays correct whether the constraint returns a single
    # flag or one flag per value.
    assert d.support.check(values).all(), "invalid"
    assert len(set(map(_hash, values))) == len(values), "not unique"
Esempio n. 9
0
def test_mode_phylo(num_leaves, dtype):
    """Check ``mode()`` against brute-force argmax over the support.

    Uses logits from ``random_phylo_logits``. Skipped when the optional
    ``lap`` module cannot be imported.
    """
    pytest.importorskip("lap")
    logits, _times = random_phylo_logits(num_leaves, dtype)
    d = dist.OneTwoMatching(logits)
    support = d.enumerate_support()
    # Locate the highest-probability enumerated value.
    _, best = d.log_prob(support).max(0)
    expected = support[best.item()]
    actual = d.mode()
    assert_equal(actual, expected)
Esempio n. 10
0
def test_mode_full(num_destins, dtype):
    """Check ``mode()`` against brute-force argmax over the support.

    Uses random dense logits (standard normal scaled by 10) with twice as many
    sources as destinations. Skipped when the optional ``lap`` module cannot
    be imported.
    """
    pytest.importorskip("lap")
    shape = (2 * num_destins, num_destins)
    logits = torch.randn(*shape, dtype=dtype) * 10
    d = dist.OneTwoMatching(logits)
    support = d.enumerate_support()
    # Locate the highest-probability enumerated value.
    _, best = d.log_prob(support).max(0)
    expected = support[best.item()]
    actual = d.mode()
    assert_equal(actual, expected)
Esempio n. 11
0
def test_sample_phylo(num_leaves, dtype, bp_iters):
    """Compare empirical sample marginals with enumerated marginals.

    Draws 1000 samples and accumulates the per-(source, destination)
    assignment frequencies. It then computes the same table from the
    enumerated support weighted by normalized probabilities, and asserts that
    the two agree to within 0.1. Skipped when the optional ``lap`` module
    cannot be imported; sampling is wrapped in ``xfail_if_not_implemented``.
    """
    pytest.importorskip("lap")
    logits, _times = random_phylo_logits(num_leaves, dtype)
    num_sources, num_destins = logits.shape
    d = dist.OneTwoMatching(logits, bp_iters=bp_iters)

    # Empirical marginals from Monte Carlo samples.
    num_samples = 1000
    source_index = torch.arange(num_sources)
    actual = torch.zeros_like(logits)
    with xfail_if_not_implemented():
        for assignment in d.sample([num_samples]):
            actual[source_index, assignment] += 1 / num_samples

    # Reference marginals from exhaustive enumeration.
    support = d.enumerate_support()
    weights = d.log_prob(support).exp()
    weights /= weights.sum()
    expected = torch.zeros(num_sources, num_destins)
    for assignment, weight in zip(support, weights):
        expected[source_index, assignment] += weight
    assert_close(actual, expected, atol=0.1)
Esempio n. 12
0
def test_mode_phylo_smoke(num_leaves, dtype):
    """Smoke test: ``mode()`` on phylo logits with ``bp_iters=10`` is in support.

    Skipped when the optional ``lap`` module cannot be imported.
    """
    pytest.importorskip("lap")
    logits, _times = random_phylo_logits(num_leaves, dtype)
    matching = dist.OneTwoMatching(logits, bp_iters=10)
    mode = matching.mode()
    assert matching.support.check(mode)
Esempio n. 13
0
 def fn(logits):
     """Return the log partition function of ``OneTwoMatching(logits)``.

     ``bp_iters`` is taken from the enclosing scope.
     """
     return dist.OneTwoMatching(
         logits, bp_iters=bp_iters
     ).log_partition_function