Example 1
def test_persistent_independent_subproblems(num_objects, num_frames, num_detections, bp_iters):
    # solve a random assignment problem
    exists_logits_1 = -2 * torch.rand(num_objects)
    assign_logits_1 = 2 * torch.rand(num_frames, num_detections, num_objects) - 1
    assignment_1 = MarginalAssignmentPersistent(exists_logits_1, assign_logits_1, bp_iters)
    exists_probs_1 = assignment_1.exists_dist.probs
    assign_probs_1 = assignment_1.assign_dist.probs

    # solve another random assignment problem
    exists_logits_2 = -2 * torch.rand(num_objects)
    assign_logits_2 = 2 * torch.rand(num_frames, num_detections, num_objects) - 1
    assignment_2 = MarginalAssignmentPersistent(exists_logits_2, assign_logits_2, bp_iters)
    exists_probs_2 = assignment_2.exists_dist.probs
    assign_probs_2 = assignment_2.assign_dist.probs

    # solve a unioned assignment problem
    exists_logits = torch.cat([exists_logits_1, exists_logits_2])
    assign_logits = torch.full((num_frames, num_detections * 2, num_objects * 2), -INF)
    assign_logits[:, :num_detections, :num_objects] = assign_logits_1
    assign_logits[:, num_detections:, num_objects:] = assign_logits_2
    assignment = MarginalAssignmentPersistent(exists_logits, assign_logits, bp_iters)
    exists_probs = assignment.exists_dist.probs
    assign_probs = assignment.assign_dist.probs

    # check agreement
    assert_equal(exists_probs_1, exists_probs[:num_objects])
    assert_equal(exists_probs_2, exists_probs[num_objects:])
    assert_equal(assign_probs_1[:, :, :-1], assign_probs[:, :num_detections, :num_objects])
    assert_equal(assign_probs_1[:, :, -1], assign_probs[:, :num_detections, -1])
    assert_equal(assign_probs_2[:, :, :-1], assign_probs[:, num_detections:, num_objects:-1])
    assert_equal(assign_probs_2[:, :, -1], assign_probs[:, num_detections:, -1])
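
These snippets come from a single test module, so several module-level names (INF, assert_equal, assert_finite, grad, logger, and the assignment classes) are used without being defined. A minimal sketch of the assumed shared setup follows; the assert_equal import path and the body of assert_finite are assumptions reconstructed from how they are called, not copied from the source.

import logging

import torch
from torch.autograd import grad

from pyro.contrib.tracking.assignment import MarginalAssignment, MarginalAssignmentPersistent

# assumption: Pyro's test helper that compares tensors up to a `prec` tolerance
from tests.common import assert_equal

INF = float('inf')
logger = logging.getLogger(__name__)


def assert_finite(tensor, name):
    # hypothetical helper reconstructed from its call sites:
    # fail if the tensor contains any NaN or +/-inf entries
    assert torch.isfinite(tensor).all(), 'non-finite values in {}'.format(name)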
Example 2
def test_persistent_bp_vs_exact(num_objects, num_frames, num_detections):
    exists_logits = -2 * torch.rand(num_objects)
    assign_logits = 2 * torch.rand(num_frames, num_detections, num_objects) - 1
    expected = MarginalAssignmentPersistent(exists_logits, assign_logits, None)
    actual = MarginalAssignmentPersistent(exists_logits, assign_logits, 30)
    # these should only approximately agree
    assert_equal(expected.exists_dist.probs, actual.exists_dist.probs, prec=0.05)
    assert_equal(expected.assign_dist.probs, actual.assign_dist.probs, prec=0.05)
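
The test functions above take problem sizes and bp_iters as arguments; in a pytest suite these would normally be supplied by parametrize decorators. A sketch of how that could look (the concrete value lists are illustrative assumptions, not taken from the source):

import pytest

@pytest.mark.parametrize('bp_iters', [None, 30])
@pytest.mark.parametrize('num_detections', [1, 2, 3])
@pytest.mark.parametrize('num_frames', [1, 2, 3])
@pytest.mark.parametrize('num_objects', [1, 2, 3])
def test_persistent_independent_subproblems(num_objects, num_frames, num_detections, bp_iters):
    ...  # body as in Example 1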
Example 3
def test_persistent_smoke(bp_iters):
    exists_logits = torch.tensor([-1., -1., -2.], requires_grad=True)
    assign_logits = torch.tensor(
        [[[-1., -INF, -INF], [-2., -2., -INF]],
         [[-1., -2., -3.], [-2., -2., -1.]],
         [[-1., -2., -3.], [-2., -2., -1.]],
         [[-1., -1., 1.], [1., 1., -1.]]],
        requires_grad=True)

    assignment = MarginalAssignmentPersistent(exists_logits,
                                              assign_logits,
                                              bp_iters=bp_iters)
    assert assignment.num_frames == 4
    assert assignment.num_detections == 2
    assert assignment.num_objects == 3

    assign_dist = assignment.assign_dist
    exists_dist = assignment.exists_dist
    assert_finite(exists_dist.probs, 'exists_probs')
    assert_finite(assign_dist.probs, 'assign_probs')

    # both marginals should have finite log_prob gradients over their entire support
    for exists in exists_dist.enumerate_support():
        log_prob = exists_dist.log_prob(exists).sum()
        e_grad, a_grad = grad(log_prob, [exists_logits, assign_logits],
                              create_graph=True)
        assert_finite(e_grad, 'dexists_probs/dexists_logits')
        assert_finite(a_grad, 'dexists_probs/dassign_logits')

    for assign in assign_dist.enumerate_support():
        log_prob = assign_dist.log_prob(assign).sum()
        e_grad, a_grad = grad(log_prob, [exists_logits, assign_logits],
                              create_graph=True)
        assert_finite(e_grad, 'dassign_probs/dexists_logits')
        assert_finite(a_grad, 'dassign_probs/dassign_logits')
Example 4
def test_persistent_exact_5_4_3(e1, e2, e3, bp_iters, bp_momentum):
    exists_logits = torch.tensor([e1, e2, e3])
    assign_logits = 2 * torch.rand(5, 4, 3) - 1
    # this has tree-shaped connectivity and should lead to exact inference
    mask = torch.tensor([[[1, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]],
                         [[1, 0, 0], [0, 1, 1], [0, 0, 1], [1, 0, 0]],
                         [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]],
                         [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]],
                         [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]]], dtype=torch.bool)
    assign_logits[~mask] = -INF
    expected = MarginalAssignmentPersistent(exists_logits, assign_logits, None)
    actual = MarginalAssignmentPersistent(exists_logits, assign_logits, bp_iters, bp_momentum)
    assert_equal(expected.exists_dist.probs, actual.exists_dist.probs)
    assert_equal(expected.assign_dist.probs, actual.assign_dist.probs)
    logger.debug(actual.exists_dist.probs)
    logger.debug(actual.assign_dist.probs)
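
One way to read the "tree-shaped connectivity" comment above: loopy belief propagation is exact when the factor graph contains no cycles, and here any cycle would have to pass through the bipartite detection-object graph encoded by mask. The check below (is_forest is a hypothetical helper, not part of the library) uses union-find to verify that this graph is acyclic for the mask defined in the snippet above:

def is_forest(mask):
    # union-find over one node per (frame, detection) slot plus one node per object
    num_frames, num_detections, num_objects = mask.shape
    parent = list(range(num_frames * num_detections + num_objects))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for f in range(num_frames):
        for d in range(num_detections):
            for o in range(num_objects):
                if mask[f, d, o]:
                    u = find(f * num_detections + d)
                    v = find(num_frames * num_detections + o)
                    if u == v:
                        return False  # this edge would close a cycle
                    parent[u] = v
    return True


assert is_forest(mask)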
Example 5
def test_flat_vs_persistent(num_objects, num_frames, bp_iters):
    exists_logits = -2 * torch.rand(num_objects)
    assign_logits = -2 * torch.rand(num_frames, num_objects)
    # a flat problem is equivalent to a persistent problem with a single detection per frame
    flat = MarginalAssignment(exists_logits, assign_logits, bp_iters)
    full = MarginalAssignmentPersistent(exists_logits, assign_logits.unsqueeze(1), bp_iters)
    assert_equal(flat.exists_dist.probs, full.exists_dist.probs)
    assert_equal(flat.assign_dist.probs, full.assign_dist.probs.squeeze(1))
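
For context, the distributions compared in these tests behave like ordinary torch.distributions objects: exists_dist is a Bernoulli over object existence and assign_dist is a Categorical over object indices, with the final index (the one the tests above slice off as [..., -1]) standing for a spurious detection. A hedged usage sketch, assuming the constructor signature shown in the tests; the 0.5 threshold is purely illustrative:

num_objects, num_frames = 4, 10
exists_logits = -2 * torch.rand(num_objects)
assign_logits = -2 * torch.rand(num_frames, num_objects)
assignment = MarginalAssignment(exists_logits, assign_logits, bp_iters=20)

exists_map = assignment.exists_dist.probs > 0.5       # shape (num_objects,): which objects likely exist
assign_map = assignment.assign_dist.probs.argmax(-1)  # shape (num_frames,): value num_objects means spurious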