Esempio n. 1
0
 def test_uncondition(self):
     """Stripping observations with poutine.uncondition frees the "obs" site."""
     # Trace the unconditioned variant first, then the model as-is.
     free_trace = poutine.trace(poutine.uncondition(self.model)).get_trace()
     observed_trace = poutine.trace(self.model).get_trace()
     ones = torch.ones(2)
     # The unmodified model is expected to record ones at "obs" ...
     assert_equal(observed_trace.nodes["obs"]["value"], ones)
     # ... while the unconditioned model must record something else there.
     assert_not_equal(free_trace.nodes["obs"]["value"], ones)
Esempio n. 2
0
def test_EKFState_with_NcvContinuous():
    """Exercise EKFState construction, predict and update with NcvContinuous."""
    dim = 6
    model = NcvContinuous(dimension=dim, sa2=2.0)
    mean0 = torch.rand(dim)
    cov0 = torch.eye(dim)
    time0 = 0.0
    step = 2.0
    state = EKFState(dynamic_model=model, mean=mean0, cov=cov0, time=time0)

    assert type(state.dynamic_model) is NcvContinuous
    assert state.dimension == dim
    assert state.dimension_pv == dim

    # Every stored quantity (and its "_pv" view) must round-trip unchanged.
    expectations = (
        (mean0, state.mean),
        (cov0, state.cov),
        (mean0, state.mean_pv),
        (cov0, state.cov_pv),
        (time0, state.time),
    )
    for expected, actual in expectations:
        assert_equal(expected, actual, prec=1e-5)

    # Build a second state positionally and step it forward in time.
    scaled = EKFState(model, 2 * mean0, 2 * cov0, time0)
    predicted = scaled.predict(step)
    assert type(predicted.dynamic_model) is NcvContinuous

    measurement = PositionMeasurement(
        mean=torch.rand(dim), cov=torch.eye(dim), time=time0 + step)
    assert (predicted.log_likelihood_of_update(measurement) < 0.).all()

    updated, (delta_z, S_matrix) = predicted.update(measurement)
    m_dim = measurement.dimension
    assert delta_z.shape == (m_dim,)
    assert S_matrix.shape == (m_dim, m_dim)
    # Incorporating the measurement must move the mean.
    assert_not_equal(updated.mean, predicted.mean, prec=1e-5)
Esempio n. 3
0
def test_NcvContinuous():
    """Check the basic API of the NcvContinuous dynamic model."""
    rate_hz = 100
    dt = 1.0 / rate_hz
    d = 6
    model = NcvContinuous(dimension=d, sa2=2.0)
    assert model.dimension == d
    assert model.dimension_pv == d
    assert model.num_process_noise_parameters == 1

    state = torch.rand(d)
    stepped = model(state, dt)
    # Coordinate 0 advances by dt times the coordinate at index d // 2.
    assert_equal(stepped[0], state[0] + dt * state[d // 2])
    assert_not_equal(model.geodesic_difference(state, stepped), torch.zeros(d))

    # For this model the position/velocity views match the inputs.
    state_pv = model.mean2pv(state)
    assert len(state_pv) == d
    assert_equal(state, state_pv)

    cov = torch.eye(d)
    cov_pv = model.cov2pv(cov)
    assert cov_pv.shape == (d, d)
    assert_equal(cov, cov_pv)

    noise_cov = model.process_noise_cov(dt)
    noise_cov_again = model.process_noise_cov(dt)  # same dt: exercises caching
    assert_equal(noise_cov, noise_cov_again)
    assert noise_cov_again.shape == (d, d)
    assert_cov_validity(noise_cov_again)

    sample = model.process_noise_dist(dt).sample()
    assert sample.shape == (model.dimension,)
Esempio n. 4
0
def test_csis_parameter_update():
    """One CSIS step must change every (single-element) guide parameter."""
    pyro.clear_param_store()
    guide = Guide()

    def snapshot():
        # Capture each parameter as a Python scalar so later updates
        # cannot alias the saved values.
        return {name: param.item() for name, param in guide.named_parameters()}

    before = snapshot()
    csis = pyro.infer.CSIS(model, guide, pyro.optim.Adam({"lr": 1e-2}))
    csis.step()
    after = snapshot()
    for name in before:
        assert_not_equal(before[name], after[name])
Esempio n. 5
0
def test_csis_validation_batch():
    """Validation loss is repeatable before training and changes after a step."""
    pyro.clear_param_store()
    guide = Guide()
    csis = pyro.infer.CSIS(model, guide, pyro.optim.Adam({}),
                           validation_batch_size=5)
    first, second = csis.validation_loss(), csis.validation_loss()
    csis.step()
    after_step = csis.validation_loss()
    # Two evaluations on the same validation batch must agree ...
    assert_equal(first, second)
    # ... and a training step must move the loss.
    assert_not_equal(first, after_step)
Esempio n. 6
0
def test_counterfactual_query(intervene, observe, flip):
    """Check poutine.do / poutine.condition combinations on a chain model.

    The chain model x -> y -> z -> w is optionally wrapped with
    ``poutine.do`` (when ``intervene``) and/or ``poutine.condition`` (when
    ``observe``). Without ``flip`` the result is condition(do(model)); with
    ``flip`` (and both flags set) it is do(condition(model)). The test then
    compares each site's trace value (``tr.nodes[name]['value']``) and the
    model's returned value (``actual_values[name]``) against the observation
    and intervention dictionaries.

    NOTE(review): the three flags are presumably pytest-parametrized
    booleans. When ``flip`` is true but ``intervene`` and ``observe`` are
    not both true, no handler is applied. In that case the case-1/case-2
    branches below would run against the bare model. Presumably those
    combinations are excluded by the parametrization; confirm.
    """
    # x -> y -> z -> w

    sites = ["x", "y", "z", "w"]
    # None entries are skipped by the ``is not None`` checks below;
    # presumably they mean "not observed" / "not intervened on".
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        # Each site depends on the previous one. ``_item`` is defined
        # elsewhere; presumably it converts a sample to a Python scalar.
        x = _item(pyro.sample("x", dist.Normal(0, 1)))
        y = _item(pyro.sample("y", dist.Normal(x, 1)))
        z = _item(pyro.sample("z", dist.Normal(y, 1)))
        w = _item(pyro.sample("w", dist.Normal(z, 1)))
        return dict(x=x, y=y, z=z, w=w)

    if not flip:
        # do is applied first, so condition ends up as the outer handler.
        if intervene:
            model = poutine.do(model, data=interventions)
        if observe:
            model = poutine.condition(model, data=observations)
    elif flip and intervene and observe:
        # Reversed nesting: do wraps condition.
        model = poutine.do(poutine.condition(model, data=observations),
                           data=interventions)

    tr = poutine.trace(model).get_trace()
    actual_values = tr.nodes["_RETURN"]["value"]
    for name in sites:
        # case 1: purely observational query like poutine.condition
        if not intervene and observe:
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], actual_values[name])
                assert_equal(observations[name], tr.nodes[name]['value'])
            # Without do, returned values must not take intervention values.
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], actual_values[name])
        # case 2: purely interventional query like old poutine.do
        elif intervene and not observe:
            assert not tr.nodes[name]['is_observed']
            # The returned value takes the intervention ...
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            # ... but the trace node holds a value equal to neither dict.
            assert_not_equal(observations[name], tr.nodes[name]['value'])
            assert_not_equal(interventions[name], tr.nodes[name]['value'])
        # case 3: counterfactual query mixing intervention and observation
        elif intervene and observe:
            # Observed sites record the observation in the trace ...
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], tr.nodes[name]['value'])
            # ... while intervened sites return the intervention value.
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], tr.nodes[name]['value'])