Ejemplo n.º 1
0
def assert_grads_agree(logits):
    """Assert two ways of building ``OneOneMatching`` give matching gradients.

    A reference distribution is built with default arguments, and a second
    one is built with ``bp_iters=BP_ITERS``. The gradient of each
    distribution's ``log_partition_function`` with respect to ``logits`` is
    computed. The two gradients must satisfy
    ``torch.allclose(actual, expected, atol=0.2, rtol=1e-3)``.

    :param logits: tensor of matching logits. It must have
        ``requires_grad=True`` for ``torch.autograd.grad`` to succeed.
    :raises AssertionError: if the gradients differ beyond tolerance. The
        message prints both gradients as numpy arrays.
    """
    reference = dist.OneOneMatching(logits)
    approx = dist.OneOneMatching(logits, bp_iters=BP_ITERS)
    (expected,) = torch.autograd.grad(reference.log_partition_function, [logits])
    (actual,) = torch.autograd.grad(approx.log_partition_function, [logits])
    # The failure message stays inside the assert so the ``.numpy()``
    # conversions only run when the check fails.
    assert torch.allclose(actual, expected, atol=0.2, rtol=1e-3), (
        f"Expected:\n{expected.numpy()}\nActual:\n{actual.numpy()}"
    )
Ejemplo n.º 2
0
def test_sample_shape_smoke(num_nodes, sample_shape, dtype, bp_iters):
    """Smoke-test ``sample`` output shape and support membership.

    Draws ``sample_shape`` samples from a ``OneOneMatching`` with random
    ``num_nodes x num_nodes`` logits. It checks that each sample has trailing
    size ``num_nodes`` and that every sample passes ``support.check``.
    """
    matching = dist.OneOneMatching(
        torch.randn(num_nodes, num_nodes, dtype=dtype),
        bp_iters=bp_iters,
    )
    # Sampling may be unimplemented for some configurations; that is an xfail.
    with xfail_if_not_implemented():
        drawn = matching.sample(sample_shape)
    expected_shape = sample_shape + (num_nodes,)
    assert drawn.shape == expected_shape
    assert matching.support.check(drawn).all()
Ejemplo n.º 3
0
def test_log_prob_full(num_nodes, dtype, bp_iters):
    """Check that ``log_prob`` over the full enumerated support normalizes.

    The logits are random and scaled by 10. The test log-sum-exps
    ``log_prob`` over every value from ``enumerate_support`` and requires
    the result to be within ``atol=2.0`` of 0. The loose tolerance is
    presumably for BP approximation error (TODO confirm).
    """
    scaled_logits = torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
    d = dist.OneOneMatching(scaled_logits, bp_iters=bp_iters)
    support = d.enumerate_support()
    log_mass = d.log_prob(support).logsumexp(0)
    log_total = log_mass.item()
    logging.info(
        f"log_total = {log_total:0.3g}, "
        f"log_Z = {d.log_partition_function:0.3g}"
    )
    assert_close(log_total, 0.0, atol=2.0)
Ejemplo n.º 4
0
def test_enumerate(num_nodes, dtype):
    """Check ``enumerate_support`` yields only valid, pairwise-distinct values.

    Builds a ``OneOneMatching`` from random ``num_nodes x num_nodes`` logits
    and enumerates its support. It asserts that every enumerated value passes
    ``support.check`` and that no two values hash equal under ``_hash``.
    """
    logits = torch.randn(num_nodes, num_nodes, dtype=dtype)
    d = dist.OneOneMatching(logits)
    values = d.enumerate_support()
    logging.info("destins = {}, suport size = {}".format(
        num_nodes, len(values)))
    # ``support.check`` on a batch of values returns one flag per value
    # (the same call is reduced with ``.all()`` for batched samples in this
    # file). A bare ``assert`` on a multi-element tensor raises "Boolean
    # value of Tensor with more than one value is ambiguous", so reduce
    # explicitly.
    assert d.support.check(values).all(), "invalid"
    # Presumably ``_hash`` maps each value tensor to a hashable key, so a
    # set collapses duplicates (TODO confirm).
    assert len(set(map(_hash, values))) == len(values), "not unique"
Ejemplo n.º 5
0
def test_log_prob_hard(dtype, bp_iters):
    """Normalization check on a 2x2 problem containing a ``-inf`` logit.

    The logits are ``[[0, 0], [0, -inf]]``. The entry at ``[1, 1]`` is
    ``-inf``, which presumably rules out any matching that uses it. The test
    log-sum-exps ``log_prob`` over the enumerated support and requires the
    result to be within ``atol=0.5`` of 0.
    """
    hard_logits = torch.tensor([[0.0, 0.0], [0.0, -math.inf]], dtype=dtype)
    d = dist.OneOneMatching(hard_logits, bp_iters=bp_iters)
    support = d.enumerate_support()
    log_total = d.log_prob(support).logsumexp(0).item()
    logging.info(
        f"log_total = {log_total:0.3g}, "
        f"log_Z = {d.log_partition_function:0.3g}"
    )
    assert_close(log_total, 0.0, atol=0.5)
Ejemplo n.º 6
0
def test_mode(num_nodes, dtype):
    """Check ``mode()`` against the brute-force argmax of ``log_prob``.

    The test is skipped when the ``lap`` package is unavailable. It
    enumerates the support, picks the value with the largest ``log_prob``,
    and compares it to ``d.mode()``.
    """
    pytest.importorskip("lap")
    logits = torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
    d = dist.OneOneMatching(logits)
    support = d.enumerate_support()
    log_probs = d.log_prob(support)
    _, best = log_probs.max(0)
    assert_equal(d.mode(), support[best.item()])
Ejemplo n.º 7
0
def test_sample(num_nodes, dtype, bp_iters):
    """Compare empirical sample frequencies with enumerated probabilities.

    The test is skipped when ``lap`` is unavailable. It draws 1000 samples
    and accumulates a ``num_nodes x num_nodes`` frequency table where entry
    ``[i, j]`` is the fraction of samples ``v`` with ``v[i] == j``. The same
    table is built exactly from ``enumerate_support`` weighted by normalized
    ``exp(log_prob)``, and the two tables must agree within ``atol=0.1``.
    """
    pytest.importorskip("lap")
    logits = torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
    d = dist.OneOneMatching(logits, bp_iters=bp_iters)

    # Compute an empirical mean.
    num_samples = 1000
    s = torch.arange(num_nodes)
    actual = torch.zeros_like(logits)
    with xfail_if_not_implemented():
        for v in d.sample([num_samples]):
            actual[s, v] += 1 / num_samples

    # Compute truth via enumeration.
    values = d.enumerate_support()
    probs = d.log_prob(values).exp()
    # Renormalize over the enumerated support. ``log_prob`` may be only
    # approximately normalized when ``bp_iters`` is set (TODO confirm).
    probs /= probs.sum()
    # Build the reference in the same dtype/device as ``actual``.
    # ``torch.zeros(n, n)`` would always be float32, even for double logits.
    expected = torch.zeros_like(logits)
    for v, p in zip(values, probs):
        expected[s, v] += p
    assert_close(actual, expected, atol=0.1)
Ejemplo n.º 8
0
 def fn(logits):
     """Return ``log_partition_function`` of a ``OneOneMatching`` on ``logits``.

     ``bp_iters`` is read from the enclosing scope.
     """
     return dist.OneOneMatching(logits, bp_iters=bp_iters).log_partition_function
Ejemplo n.º 9
0
def test_mode_smoke(num_nodes, dtype):
    """Smoke-test that ``mode()`` returns a value accepted by ``support``.

    The test is skipped when the ``lap`` package is unavailable. The logits
    are random and scaled by 10.
    """
    pytest.importorskip("lap")
    matching = dist.OneOneMatching(
        torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
    )
    best = matching.mode()
    assert matching.support.check(best)