Beispiel #1
0
def test_flat_exact_1_2(e1, e2, a11, a12):
    """Check that the exact solver and a 10-iteration run agree on a 1x2 problem.

    The problem has one detection and two objects. Both the existence and the
    assignment marginals must match to ``assert_equal``'s default precision.

    :param e1, e2: existence logits of objects 0 and 1.
    :param a11, a12: logits for assigning the single detection to object 0 / 1.
    """
    object_logits = torch.tensor([e1, e2])      # shape (2,): one per object
    pair_logits = torch.tensor([[a11, a12]])    # shape (1, 2): detections x objects

    reference = MarginalAssignment(object_logits, pair_logits, bp_iters=None)
    iterated = MarginalAssignment(object_logits, pair_logits, bp_iters=10)

    # Compare existence first, then assignment, exactly as before.
    for dist_name in ("exists_dist", "assign_dist"):
        assert_equal(getattr(reference, dist_name).probs,
                     getattr(iterated, dist_name).probs)
Beispiel #2
0
def test_flat_exact_2_1(e, a11, a21):
    """Check that the exact solver and a 10-iteration run agree on a 2x1 problem.

    The problem has two detections and one object. Both the existence and the
    assignment marginals must match to ``assert_equal``'s default precision.

    :param e: existence logit of the only object.
    :param a11, a21: logits for assigning detection 0 / 1 to that object.
    """
    object_logits = torch.tensor([e])             # shape (1,): one per object
    pair_logits = torch.tensor([[a11], [a21]])    # shape (2, 1): detections x objects

    reference = MarginalAssignment(object_logits, pair_logits, bp_iters=None)
    iterated = MarginalAssignment(object_logits, pair_logits, bp_iters=10)

    # Compare existence first, then assignment, exactly as before.
    for dist_name in ("exists_dist", "assign_dist"):
        assert_equal(getattr(reference, dist_name).probs,
                     getattr(iterated, dist_name).probs)
Beispiel #3
0
def test_flat_bp_vs_exact(num_objects, num_detections):
    """Compare the exact solution with the ``bp_iters=10`` one on random logits.

    All logits are drawn from ``(-2, 0]``. The two solutions are only
    expected to agree approximately, so a loose 0.01 tolerance is used.
    """
    # Keep this draw order: existence logits first, then assignment logits.
    exists_logits = torch.rand(num_objects).mul(-2)
    assign_logits = torch.rand(num_detections, num_objects).mul(-2)

    exact = MarginalAssignment(exists_logits, assign_logits, bp_iters=None)
    approx = MarginalAssignment(exists_logits, assign_logits, bp_iters=10)

    tolerance = 0.01  # these should only approximately agree
    assert_equal(exact.exists_dist.probs, approx.exists_dist.probs, prec=tolerance)
    assert_equal(exact.assign_dist.probs, approx.assign_dist.probs, prec=tolerance)
Beispiel #4
0
def test_sparse_smoke():
    """Smoke-test ``MarginalAssignmentSparse`` and cross-check it against the dense solver.

    First checks the batch and event shapes of both output distributions.
    Then it densifies the same problem and requires ``MarginalAssignment`` to
    reproduce the sparse probabilities to within 1e-3.
    """
    n_obj, n_det = 4, 2
    pyro.set_rng_seed(0)
    obj_logits = torch.zeros(n_obj)

    # One column per edge. Row 0 values stay below n_det and row 1 values
    # below n_obj, so presumably row 0 = detection index, row 1 = object
    # index (confirm against sparse_to_dense).
    edges = obj_logits.new_tensor(
        [[0, 0, 1, 0, 1, 0],
         [0, 1, 1, 2, 2, 3]],
        dtype=torch.long,
    )
    edge_probs = torch.tensor([0.99, 0.8, 0.2, 0.2, 0.8, 0.9])
    sparse_logits = logit(edge_probs)  # one logit per edge
    assert sparse_logits.shape == edges.shape[1:]

    solver = MarginalAssignmentSparse(n_obj, n_det, edges, obj_logits,
                                      sparse_logits, bp_iters=5)

    assert solver.exists_dist.batch_shape == (n_obj,)
    assert solver.exists_dist.event_shape == ()
    assert solver.assign_dist.batch_shape == (n_det,)
    assert solver.assign_dist.event_shape == ()
    # Last axis holds one slot per object plus one extra: true + spurious.
    assert solver.assign_dist.probs.shape[-1] == n_obj + 1

    # test dense matches sparse
    dense_logits = sparse_to_dense(n_obj, n_det, edges, sparse_logits)
    reference = MarginalAssignment(obj_logits, dense_logits, bp_iters=5)
    assert_equal(reference.exists_dist.probs, solver.exists_dist.probs,
                 prec=1e-3)
    assert_equal(reference.assign_dist.probs, solver.assign_dist.probs,
                 prec=1e-3)
Beispiel #5
0
def test_dense_smoke():
    """Smoke-test ``MarginalAssignment`` and cross-check it against the sparse solver.

    First checks the batch and event shapes of both output distributions.
    Then it sparsifies the same problem and requires
    ``MarginalAssignmentSparse`` to reproduce the dense probabilities to
    within 1e-3.
    """
    n_obj, n_det = 4, 2
    pyro.set_rng_seed(0)
    obj_logits = torch.zeros(n_obj)

    # Rows = detections, columns = objects. Zero probabilities become -inf
    # logits; presumably these mark pairs that dense_to_sparse drops
    # (NOTE(review): confirm).
    pair_probs = torch.tensor([
        [0.5, 0.5, 0.0, 0.0],
        [0.0, 0.5, 0.5, 0.5],
    ])
    dense_logits = logit(pair_probs)
    assert dense_logits.shape == (n_det, n_obj)

    solver = MarginalAssignment(obj_logits, dense_logits, bp_iters=5)

    assert solver.exists_dist.batch_shape == (n_obj,)
    assert solver.exists_dist.event_shape == ()
    assert solver.assign_dist.batch_shape == (n_det,)
    assert solver.assign_dist.event_shape == ()
    # Last axis holds one slot per object plus one extra: true + spurious.
    assert solver.assign_dist.probs.shape[-1] == n_obj + 1

    # test dense matches sparse
    edges, sparse_logits = dense_to_sparse(dense_logits)
    reference = MarginalAssignmentSparse(n_obj, n_det, edges, obj_logits,
                                         sparse_logits, bp_iters=5)
    assert_equal(reference.exists_dist.probs, solver.exists_dist.probs,
                 prec=1e-3)
    assert_equal(reference.assign_dist.probs, solver.assign_dist.probs,
                 prec=1e-3)
Beispiel #6
0
def guide(detections, args):
    """Guide over object existence and detection-to-object assignment.

    Object locations come from the ``objects_loc`` param, which is trained by
    the M-step of EM. The noise scale comes from ``noise_scale``, which is
    trained by SVI. Soft assignments are computed with ``MarginalAssignment``
    (``bp_iters=10``). They are exposed as two sample sites marked for
    parallel enumeration: ``exists`` (one per object) and ``assign`` (one per
    detection).

    :param detections: 1-D tensor of detections.
    :param args: options object. ``args.assignment_grad`` toggles autograd
        while the assignment is computed. The object is also forwarded to
        ``compute_exists_logits`` / ``compute_assign_logits``.
    """
    noise = pyro.param('noise_scale')  # trained by SVI
    object_locs = pyro.param('objects_loc').squeeze(-1)  # trained by M-step of EM
    [num_detections] = detections.shape  # unpacking also enforces 1-D
    [max_num_objects] = object_locs.shape

    with torch.set_grad_enabled(args.assignment_grad):
        # Log-likelihood scores. TODO make this more pyronic.
        exists_logits = compute_exists_logits(object_locs, args)
        assign_logits = compute_assign_logits(
            object_locs, detections.unsqueeze(-1), noise, args)
        assert exists_logits.shape == (max_num_objects,)
        assert assign_logits.shape == (num_detections, max_num_objects)

        # Soft (marginal) assignment of detections to objects.
        assignment = MarginalAssignment(exists_logits, assign_logits,
                                        bp_iters=10)

    with pyro.plate('objects_plate', max_num_objects):
        pyro.sample('exists', assignment.exists_dist,
                    infer={'enumerate': 'parallel'})
    with pyro.plate('detections_plate', num_detections):
        pyro.sample('assign', assignment.assign_dist,
                    infer={'enumerate': 'parallel'})
Beispiel #7
0
def guide(detections, args):
    """Propose existence and assignment variables for the tracking model.

    Reads two params in this order:

    * ``noise_scale``, trained by SVI.
    * ``objects_loc``, trained by the M-step of EM.

    It scores the objects and detections, and turns the scores into soft
    assignments with ``MarginalAssignment(..., bp_iters=10)``. Gradients
    through that computation are enabled only if ``args.assignment_grad``
    is truthy.

    Two sample sites are then emitted, both marked for parallel enumeration:

    * ``exists`` inside ``objects_plate`` (size = number of objects).
    * ``assign`` inside ``detections_plate`` (size = number of detections).

    :param detections: 1-D tensor of detections.
    :param args: options object, forwarded to the logit helpers.
    """
    scale = pyro.param("noise_scale")  # trained by SVI
    locs = pyro.param("objects_loc").squeeze(-1)  # trained by M-step of EM
    n_detections, = detections.shape  # raises if not 1-D
    n_objects, = locs.shape

    with torch.set_grad_enabled(args.assignment_grad):
        # Evaluate log likelihoods. TODO make this more pyronic.
        object_scores = compute_exists_logits(locs, args)
        pair_scores = compute_assign_logits(locs, detections.unsqueeze(-1),
                                            scale, args)
        assert object_scores.shape == (n_objects,)
        assert pair_scores.shape == (n_detections, n_objects)

        soft = MarginalAssignment(object_scores, pair_scores, bp_iters=10)

    with pyro.plate("objects_plate", n_objects):
        pyro.sample("exists", soft.exists_dist,
                    infer={"enumerate": "parallel"})
    with pyro.plate("detections_plate", n_detections):
        pyro.sample("assign", soft.assign_dist,
                    infer={"enumerate": "parallel"})
Beispiel #8
0
def test_flat_vs_persistent(num_objects, num_frames, bp_iters):
    """Check that the persistent solver matches the flat one when given a singleton axis.

    The flat ``(num_frames, num_objects)`` logits get a size-1 axis inserted
    at dim 1, which is presumably the per-frame detection axis (confirm
    against ``MarginalAssignmentPersistent``). That axis is squeezed back out
    of the persistent assignment probabilities before comparing with
    ``assert_equal``'s default precision.
    """
    # Keep this draw order: existence logits first, then assignment logits.
    exists_logits = torch.rand(num_objects).mul(-2)
    frame_logits = torch.rand(num_frames, num_objects).mul(-2)

    flat = MarginalAssignment(exists_logits, frame_logits, bp_iters)
    persistent = MarginalAssignmentPersistent(exists_logits,
                                              frame_logits[:, None],
                                              bp_iters)

    assert_equal(flat.exists_dist.probs, persistent.exists_dist.probs)
    assert_equal(flat.assign_dist.probs,
                 persistent.assign_dist.probs.squeeze(1))