def test_persistent_independent_subproblems(num_objects, num_frames, num_detections, bp_iters):
    """Jointly solving two disjoint problems must match solving each one alone.

    Two random persistent assignment problems are solved separately. They are
    then merged into one problem by concatenating their existence logits and
    placing their assignment logits on the block diagonal of a larger tensor.
    Every cross-problem entry of that tensor is ``-INF``, i.e. forbidden. The
    marginals of the merged problem are then compared against the separate
    solutions.
    """
    n_det, n_obj = num_detections, num_objects

    def random_problem():
        # Draw random existence and assignment logits for one sub-problem.
        exists = -2 * torch.rand(n_obj)
        assign = 2 * torch.rand(num_frames, n_det, n_obj) - 1
        return exists, assign

    def marginals(exists, assign):
        # Solve one problem and return its (exists_probs, assign_probs).
        solution = MarginalAssignmentPersistent(exists, assign, bp_iters)
        return solution.exists_dist.probs, solution.assign_dist.probs

    # Draw and solve each sub-problem in turn (same RNG order as drawing
    # problem 1, solving it, then drawing problem 2).
    exists_a, assign_a = random_problem()
    exists_probs_a, assign_probs_a = marginals(exists_a, assign_a)
    exists_b, assign_b = random_problem()
    exists_probs_b, assign_probs_b = marginals(exists_b, assign_b)

    # Block-diagonal union: detections of one sub-problem can never be
    # assigned to objects of the other.
    joint_exists = torch.cat([exists_a, exists_b])
    joint_assign = torch.full((num_frames, n_det * 2, n_obj * 2), -INF)
    joint_assign[:, :n_det, :n_obj] = assign_a
    joint_assign[:, n_det:, n_obj:] = assign_b
    exists_probs, assign_probs = marginals(joint_exists, joint_assign)

    # Existence marginals split cleanly along the object axis.
    assert_equal(exists_probs_a, exists_probs[:n_obj])
    assert_equal(exists_probs_b, exists_probs[n_obj:])
    # Assignment marginals: the per-object columns come from the matching
    # diagonal block. The trailing column of every assign_probs tensor
    # (presumably the "no object" slot; verify against the solver) is
    # compared separately.
    assert_equal(assign_probs_a[:, :, :-1], assign_probs[:, :n_det, :n_obj])
    assert_equal(assign_probs_a[:, :, -1], assign_probs[:, :n_det, -1])
    assert_equal(assign_probs_b[:, :, :-1], assign_probs[:, n_det:, n_obj:-1])
    assert_equal(assign_probs_b[:, :, -1], assign_probs[:, n_det:, -1])
def test_persistent_bp_vs_exact(num_objects, num_frames, num_detections):
    """BP with 30 iterations should roughly reproduce the ``bp_iters=None`` solution.

    The ``None`` result is treated as the exact reference, per this test's name.
    """
    exists_logits = -2 * torch.rand(num_objects)
    assign_logits = 2 * torch.rand(num_frames, num_detections, num_objects) - 1
    reference = MarginalAssignmentPersistent(exists_logits, assign_logits, None)
    approximation = MarginalAssignmentPersistent(exists_logits, assign_logits, 30)

    # The two solutions are only expected to agree approximately, hence the
    # loose tolerance.
    for dist_name in ("exists_dist", "assign_dist"):
        assert_equal(getattr(reference, dist_name).probs,
                     getattr(approximation, dist_name).probs,
                     prec=0.05)
def test_persistent_smoke(bp_iters):
    """Smoke test: marginals and their gradients stay finite with -INF logits.

    Builds a small 4-frame, 2-detection, 3-object problem, some of whose
    assignment logits are ``-INF``. It checks the reported sizes and that the
    marginal probabilities are finite. It also checks that the gradient of
    every support point's log-probability, taken w.r.t. both logit tensors,
    is finite.
    """
    exists_logits = torch.tensor([-1., -1., -2.], requires_grad=True)
    assign_logits = torch.tensor(
        [
            [[-1., -INF, -INF], [-2., -2., -INF]],
            [[-1., -2., -3.], [-2., -2., -1.]],
            [[-1., -2., -3.], [-2., -2., -1.]],
            [[-1., -1., 1.], [1., 1., -1.]],
        ],
        requires_grad=True,
    )
    solution = MarginalAssignmentPersistent(exists_logits, assign_logits, bp_iters=bp_iters)
    assert solution.num_frames == 4
    assert solution.num_detections == 2
    assert solution.num_objects == 3

    assign_dist = solution.assign_dist
    exists_dist = solution.exists_dist
    assert_finite(exists_dist.probs, 'exists_probs')
    assert_finite(assign_dist.probs, 'assign_probs')

    # For each distribution, every support point's log-prob must have finite
    # gradients w.r.t. both logit tensors (create_graph=True keeps the graph
    # available for higher-order derivatives).
    wrt = [exists_logits, assign_logits]
    cases = (
        (exists_dist, 'dexists_probs/dexists_logits', 'dexists_probs/dassign_logits'),
        (assign_dist, 'dassign_probs/dexists_logits', 'dassign_probs/dassign_logits'),
    )
    for dist, exists_label, assign_label in cases:
        for value in dist.enumerate_support():
            total_log_prob = dist.log_prob(value).sum()
            d_exists, d_assign = grad(total_log_prob, wrt, create_graph=True)
            assert_finite(d_exists, exists_label)
            assert_finite(d_assign, assign_label)
def test_persistent_exact_5_4_3(e1, e2, e3, bp_iters, bp_momentum):
    """BP should match the ``bp_iters=None`` solution on a sparse 5x4x3 problem.

    Assignment logits have shape (5 frames, 4 detections, 3 objects). Every
    entry outside ``allowed`` is set to ``-INF``.
    """
    exists_logits = torch.tensor([e1, e2, e3])
    assign_logits = 2 * torch.rand(5, 4, 3) - 1

    # Allowed detection->object pairs, one row of 4 detections per frame.
    # This has tree-shaped connectivity and should lead to exact inference.
    allowed = torch.tensor(
        [
            [[1, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]],
            [[1, 0, 0], [0, 1, 1], [0, 0, 1], [1, 0, 0]],
            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]],
            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]],
        ],
        dtype=torch.bool,
    )
    assign_logits[~allowed] = -INF

    reference = MarginalAssignmentPersistent(exists_logits, assign_logits, None)
    bp_solution = MarginalAssignmentPersistent(exists_logits, assign_logits, bp_iters, bp_momentum)
    assert_equal(reference.exists_dist.probs, bp_solution.exists_dist.probs)
    assert_equal(reference.assign_dist.probs, bp_solution.assign_dist.probs)
    logger.debug(bp_solution.exists_dist.probs)
    logger.debug(bp_solution.assign_dist.probs)
def test_flat_vs_persistent(num_objects, num_frames, bp_iters):
    """With one detection per frame, the persistent solver must match the flat one.

    The flat problem's (num_frames, num_objects) assignment logits are lifted
    to (num_frames, 1, num_objects) for the persistent solver. The persistent
    result is then compared after dropping that singleton detection axis.
    """
    exists_logits = -2 * torch.rand(num_objects)
    assign_logits = -2 * torch.rand(num_frames, num_objects)

    flat = MarginalAssignment(exists_logits, assign_logits, bp_iters)
    persistent = MarginalAssignmentPersistent(exists_logits, assign_logits[:, None], bp_iters)

    assert_equal(flat.exists_dist.probs, persistent.exists_dist.probs)
    persistent_assign_probs = persistent.assign_dist.probs.squeeze(1)
    assert_equal(flat.assign_dist.probs, persistent_assign_probs)