def test_flat_exact_1_2(e1, e2, a11, a12):
    """BP must reproduce the exact marginals on a 1-detection, 2-object problem.

    The exact solver is requested with ``bp_iters=None``; the BP solver runs
    10 iterations. Both marginal distributions must match at default precision.
    """
    exists = torch.tensor([e1, e2])        # one existence logit per object
    assign = torch.tensor([[a11, a12]])    # (1 detection, 2 objects)

    exact = MarginalAssignment(exists, assign, bp_iters=None)
    bp = MarginalAssignment(exists, assign, bp_iters=10)

    assert_equal(exact.exists_dist.probs, bp.exists_dist.probs)
    assert_equal(exact.assign_dist.probs, bp.assign_dist.probs)
def test_flat_exact_2_1(e, a11, a21):
    """BP must reproduce the exact marginals on a 2-detection, 1-object problem.

    The exact solver is requested with ``bp_iters=None``; the BP solver runs
    10 iterations. Both marginal distributions must match at default precision.
    """
    exists = torch.tensor([e])             # a single object
    assign = torch.tensor([[a11], [a21]])  # (2 detections, 1 object)

    exact = MarginalAssignment(exists, assign, bp_iters=None)
    bp = MarginalAssignment(exists, assign, bp_iters=10)

    assert_equal(exact.exists_dist.probs, bp.exists_dist.probs)
    assert_equal(exact.assign_dist.probs, bp.assign_dist.probs)
def test_flat_bp_vs_exact(num_objects, num_detections):
    """10-iteration BP must approximate the exact marginals to within 0.01.

    Logits are drawn uniformly from (-2, 0]. Agreement is only expected to be
    approximate here, hence the loose ``prec``.
    """
    exists = -2 * torch.rand(num_objects)
    assign = -2 * torch.rand(num_detections, num_objects)

    exact = MarginalAssignment(exists, assign, bp_iters=None)
    approx = MarginalAssignment(exists, assign, bp_iters=10)

    tolerance = 0.01
    assert_equal(exact.exists_dist.probs, approx.exists_dist.probs, prec=tolerance)
    assert_equal(exact.assign_dist.probs, approx.assign_dist.probs, prec=tolerance)
def test_sparse_smoke():
    """Smoke-test ``MarginalAssignmentSparse`` and cross-check it with the dense solver.

    The test builds a 2-detection, 4-object problem with six candidate edges and
    checks the shapes of the resulting distributions. It then densifies the same
    logits and requires ``MarginalAssignment`` to give matching marginals (1e-3).
    """
    num_objects, num_detections = 4, 2
    pyro.set_rng_seed(0)

    exists_logits = torch.zeros(num_objects)
    # Two rows of edge endpoints: row 0 values stay below num_detections and
    # row 1 values below num_objects (presumably detection / object indices).
    edges = exists_logits.new_tensor(
        [[0, 0, 1, 0, 1, 0],
         [0, 1, 1, 2, 2, 3]],
        dtype=torch.long,
    )
    edge_logits = logit(torch.tensor([0.99, 0.8, 0.2, 0.2, 0.8, 0.9]))
    assert edge_logits.shape == edges.shape[1:]  # exactly one logit per edge

    sparse = MarginalAssignmentSparse(num_objects, num_detections, edges,
                                      exists_logits, edge_logits, bp_iters=5)
    exists_dist, assign_dist = sparse.exists_dist, sparse.assign_dist
    assert exists_dist.batch_shape == (num_objects, )
    assert exists_dist.event_shape == ()
    assert assign_dist.batch_shape == (num_detections, )
    assert assign_dist.event_shape == ()
    # One category per object plus one extra for a spurious detection.
    assert assign_dist.probs.shape[-1] == num_objects + 1

    # The dense solver run on the densified logits must agree with the sparse one.
    dense_logits = sparse_to_dense(num_objects, num_detections, edges, edge_logits)
    dense = MarginalAssignment(exists_logits, dense_logits, bp_iters=5)
    assert_equal(dense.exists_dist.probs, sparse.exists_dist.probs, prec=1e-3)
    assert_equal(dense.assign_dist.probs, sparse.assign_dist.probs, prec=1e-3)
def test_dense_smoke():
    """Smoke-test ``MarginalAssignment`` and cross-check it with the sparse solver.

    The test builds a dense (2 detections, 4 objects) logit matrix and checks the
    shapes of the resulting distributions. It then sparsifies the same logits and
    requires ``MarginalAssignmentSparse`` to give matching marginals (1e-3).
    """
    num_objects, num_detections = 4, 2
    pyro.set_rng_seed(0)

    exists_logits = torch.zeros(num_objects)
    probs = torch.tensor([
        [0.5, 0.5, 0.0, 0.0],
        [0.0, 0.5, 0.5, 0.5],
    ])
    # Zero-probability entries map to very negative (possibly -inf) logits;
    # presumably dense_to_sparse treats them as missing edges -- confirm.
    dense_logits = logit(probs)
    assert dense_logits.shape == (num_detections, num_objects)

    dense = MarginalAssignment(exists_logits, dense_logits, bp_iters=5)
    exists_dist, assign_dist = dense.exists_dist, dense.assign_dist
    assert exists_dist.batch_shape == (num_objects, )
    assert exists_dist.event_shape == ()
    assert assign_dist.batch_shape == (num_detections, )
    assert assign_dist.event_shape == ()
    # One category per object plus one extra for a spurious detection.
    assert assign_dist.probs.shape[-1] == num_objects + 1

    # The sparse solver run on the sparsified logits must agree with the dense one.
    edges, edge_logits = dense_to_sparse(dense_logits)
    sparse = MarginalAssignmentSparse(num_objects, num_detections, edges,
                                      exists_logits, edge_logits, bp_iters=5)
    assert_equal(sparse.exists_dist.probs, dense.exists_dist.probs, prec=1e-3)
    assert_equal(sparse.assign_dist.probs, dense.assign_dist.probs, prec=1e-3)
def guide(detections, args):
    """Guide that samples object existence and detection assignment from BP marginals.

    Reads the ``noise_scale`` and ``objects_loc`` params and builds existence and
    assignment logits. It then runs 10 iterations of ``MarginalAssignment`` and
    samples ``exists`` (one per object) and ``assign`` (one per detection), both
    with parallel enumeration.

    ``detections`` and the squeezed ``objects_loc`` must be 1-D; the unpacking
    below raises otherwise.

    NOTE(review): another ``guide`` with the same body appears nearby. If both
    live in one module, the later definition shadows this one -- confirm.
    """
    noise_scale = pyro.param('noise_scale')          # learned by SVI
    objects = pyro.param('objects_loc').squeeze(-1)  # learned by the EM M-step
    [num_detections] = detections.shape
    [max_num_objects] = objects.shape

    # Gradients flow through the likelihood terms only if args.assignment_grad.
    with torch.set_grad_enabled(args.assignment_grad):
        # Evaluate log likelihoods. TODO make this more pyronic.
        exists_logits = compute_exists_logits(objects, args)
        assign_logits = compute_assign_logits(objects, detections.unsqueeze(-1),
                                              noise_scale, args)
        assert exists_logits.shape == (max_num_objects, )
        assert assign_logits.shape == (num_detections, max_num_objects)

        # Soft assignments via belief propagation.
        assignment = MarginalAssignment(exists_logits, assign_logits, bp_iters=10)

    # Each sample site lives in its own plate; plates are entered one at a time.
    sites = (
        ('objects_plate', max_num_objects, 'exists', assignment.exists_dist),
        ('detections_plate', num_detections, 'assign', assignment.assign_dist),
    )
    for plate_name, plate_size, site_name, site_dist in sites:
        with pyro.plate(plate_name, plate_size):
            pyro.sample(site_name, site_dist, infer={'enumerate': 'parallel'})
def guide(detections, args):
    """Guide that samples object existence and detection assignment from BP marginals.

    The ``noise_scale`` and ``objects_loc`` params define the existence and
    assignment logits. Ten iterations of ``MarginalAssignment`` turn them into
    soft assignments. ``exists`` is sampled per object and ``assign`` per
    detection, both with parallel enumeration.

    ``detections`` and the squeezed ``objects_loc`` must be 1-D; the unpacking
    below raises otherwise.
    """
    noise_scale = pyro.param("noise_scale")          # fitted by SVI
    objects_loc = pyro.param("objects_loc")          # fitted by the EM M-step
    objects = objects_loc.squeeze(-1)
    num_detections, = detections.shape
    max_num_objects, = objects.shape

    # Only track gradients through the likelihoods when args.assignment_grad is set.
    grad_ctx = torch.set_grad_enabled(args.assignment_grad)
    with grad_ctx:
        # Evaluate log likelihoods. TODO make this more pyronic.
        exists_logits = compute_exists_logits(objects, args)
        detection_column = detections.unsqueeze(-1)
        assign_logits = compute_assign_logits(objects, detection_column, noise_scale, args)
        assert exists_logits.shape == (max_num_objects, )
        assert assign_logits.shape == (num_detections, max_num_objects)

        # Compute soft assignments.
        assignment = MarginalAssignment(exists_logits, assign_logits, bp_iters=10)

    exists_dist = assignment.exists_dist
    assign_dist = assignment.assign_dist
    with pyro.plate("objects_plate", max_num_objects):
        pyro.sample("exists", exists_dist, infer={"enumerate": "parallel"})
    with pyro.plate("detections_plate", num_detections):
        pyro.sample("assign", assign_dist, infer={"enumerate": "parallel"})
def test_flat_vs_persistent(num_objects, num_frames, bp_iters):
    """With one detection per frame, the persistent solver must match the flat one.

    The flat solver treats each frame's row as one detection. The persistent
    solver receives the same logits with a singleton axis inserted at dim 1
    (presumably (frames, detections, objects) -- confirm). Its assignment
    marginals are squeezed back before comparison.
    """
    exists = -2 * torch.rand(num_objects)
    per_frame = -2 * torch.rand(num_frames, num_objects)

    flat = MarginalAssignment(exists, per_frame, bp_iters)
    one_detection_per_frame = per_frame.unsqueeze(1)  # (num_frames, 1, num_objects)
    persistent = MarginalAssignmentPersistent(exists, one_detection_per_frame, bp_iters)

    assert_equal(flat.exists_dist.probs, persistent.exists_dist.probs)
    assert_equal(flat.assign_dist.probs, persistent.assign_dist.probs.squeeze(1))