def test_uncondition(self):
    """Check that ``poutine.uncondition`` removes the conditioning on ``obs``.

    Tracing ``self.model`` directly should give ``obs == torch.ones(2)``.
    Tracing the unconditioned wrapper should yield some other value at that
    site.
    """
    expected_obs = torch.ones(2)
    free_model = poutine.uncondition(self.model)
    free_trace = poutine.trace(free_model).get_trace()
    pinned_trace = poutine.trace(self.model).get_trace()
    assert_equal(pinned_trace.nodes["obs"]["value"], expected_obs)
    assert_not_equal(free_trace.nodes["obs"]["value"], expected_obs)
def test_EKFState_with_NcvContinuous():
    """Exercise EKFState construction, predict and update with NcvContinuous."""
    dim = 6
    tol = 1e-5
    t0, step = 0.0, 2.0
    dynamics = NcvContinuous(dimension=dim, sa2=2.0)
    mean0 = torch.rand(dim)
    cov0 = torch.eye(dim)

    # Construction: the state should report the model's dimensions and echo
    # back the mean/cov/time it was built with. The test also expects the
    # *_pv views to equal the raw mean/cov for this model.
    state = EKFState(dynamic_model=dynamics, mean=mean0, cov=cov0, time=t0)
    assert type(state.dynamic_model) is NcvContinuous
    assert state.dimension == dim
    assert state.dimension_pv == dim
    assert_equal(mean0, state.mean, prec=tol)
    assert_equal(cov0, state.cov, prec=tol)
    assert_equal(mean0, state.mean_pv, prec=tol)
    assert_equal(cov0, state.cov_pv, prec=tol)
    assert_equal(t0, state.time, prec=tol)

    # Prediction from a scaled state keeps the same dynamic model type.
    predicted = EKFState(dynamics, 2 * mean0, 2 * cov0, t0).predict(step)
    assert type(predicted.dynamic_model) is NcvContinuous

    # Update against a random position measurement taken at t0 + step.
    measurement = PositionMeasurement(
        mean=torch.rand(dim), cov=torch.eye(dim), time=t0 + step)
    log_likelihood = predicted.log_likelihood_of_update(measurement)
    assert (log_likelihood < 0.).all()

    updated, (dz, S) = predicted.update(measurement)
    meas_dim = measurement.dimension
    assert dz.shape == (meas_dim,)
    assert S.shape == (meas_dim, meas_dim)
    assert_not_equal(updated.mean, predicted.mean, prec=tol)
def test_NcvContinuous():
    """Check NcvContinuous dimensions, transition, pv conversions and noise."""
    framerate = 100  # Hz
    dt = 1.0 / framerate
    dim = 6
    model = NcvContinuous(dimension=dim, sa2=2.0)
    assert model.dimension == dim
    assert model.dimension_pv == dim
    assert model.num_process_noise_parameters == 1

    # Transition: coordinate 0 should advance by dt times the entry at
    # index dim // 2, and the propagated state should differ from the input.
    state = torch.rand(dim)
    propagated = model(state, dt)
    assert_equal(propagated[0], state[0] + dt * state[dim // 2])
    assert_not_equal(model.geodesic_difference(state, propagated),
                     torch.zeros(dim))

    # Position/velocity views are expected to be identical to the inputs.
    state_pv = model.mean2pv(state)
    assert len(state_pv) == dim
    assert_equal(state, state_pv)

    identity = torch.eye(dim)
    identity_pv = model.cov2pv(identity)
    assert identity_pv.shape == (dim, dim)
    assert_equal(identity, identity_pv)

    # A repeated call with the same dt is meant to go through the cache.
    # Only value equality is asserted, not that the same object is returned.
    Q_first = model.process_noise_cov(dt)
    Q_again = model.process_noise_cov(dt)
    assert_equal(Q_first, Q_again)
    assert Q_again.shape == (dim, dim)
    assert_cov_validity(Q_again)

    noise_sample = model.process_noise_dist(dt).sample()
    assert noise_sample.shape == (model.dimension, )
def test_csis_parameter_update():
    """One CSIS optimisation step should change every guide parameter."""
    pyro.clear_param_store()
    guide = Guide()

    def snapshot():
        # .item() requires each parameter to hold exactly one element.
        return {name: param.item() for name, param in guide.named_parameters()}

    before = snapshot()
    optimizer = pyro.optim.Adam({"lr": 1e-2})
    csis = pyro.infer.CSIS(model, guide, optimizer)
    csis.step()
    after = snapshot()

    for name, old_value in before.items():
        assert_not_equal(old_value, after[name])
def test_csis_validation_batch():
    """Validation loss is repeatable before training and changes after a step.

    Two validation passes before any step are expected to agree. Presumably
    this is because the validation batch of size 5 is drawn once and reused
    (TODO confirm against CSIS).
    """
    pyro.clear_param_store()
    csis = pyro.infer.CSIS(model, Guide(), pyro.optim.Adam({}),
                           validation_batch_size=5)
    first, second = csis.validation_loss(), csis.validation_loss()
    csis.step()
    after_step = csis.validation_loss()
    assert_equal(first, second)
    assert_not_equal(first, after_step)
def test_counterfactual_query(intervene, observe, flip):
    """Check poutine.do / poutine.condition, alone and combined, on a chain.

    The model is the chain x -> y -> z -> w of unit-variance Normals. Its
    return value holds the values each site actually passed downstream.

    :param intervene: wrap the model with ``poutine.do(data=interventions)``.
    :param observe: wrap the model with
        ``poutine.condition(data=observations)``.
    :param flip: when both handlers are requested, nest them as
        ``do(condition(model))`` instead of the default
        ``condition(do(model))``. Nesting order only matters when both
        handlers are applied; with at most one handler, ``flip`` has no effect.
    """
    # x -> y -> z -> w
    sites = ["x", "y", "z", "w"]
    # The assertions below treat None as "no observation / no intervention"
    # for that site.
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        x = _item(pyro.sample("x", dist.Normal(0, 1)))
        y = _item(pyro.sample("y", dist.Normal(x, 1)))
        z = _item(pyro.sample("z", dist.Normal(y, 1)))
        w = _item(pyro.sample("w", dist.Normal(z, 1)))
        return dict(x=x, y=y, z=z, w=w)

    if flip and intervene and observe:
        # Flipped nesting: condition innermost, do outermost.
        model = poutine.do(poutine.condition(model, data=observations),
                           data=interventions)
    else:
        # Default nesting: do innermost, condition outermost. This branch also
        # covers flip=True with only one handler requested. The assertions
        # below still expect that handler's effect, so it must be applied
        # rather than leaving the model unwrapped.
        if intervene:
            model = poutine.do(model, data=interventions)
        if observe:
            model = poutine.condition(model, data=observations)

    tr = poutine.trace(model).get_trace()
    actual_values = tr.nodes["_RETURN"]["value"]
    for name in sites:
        # case 1: purely observational query like poutine.condition
        if not intervene and observe:
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], actual_values[name])
                assert_equal(observations[name], tr.nodes[name]['value'])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], actual_values[name])
        # case 2: purely interventional query like old poutine.do
        elif intervene and not observe:
            assert not tr.nodes[name]['is_observed']
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            assert_not_equal(observations[name], tr.nodes[name]['value'])
            assert_not_equal(interventions[name], tr.nodes[name]['value'])
        # case 3: counterfactual query mixing intervention and observation
        elif intervene and observe:
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], tr.nodes[name]['value'])
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], tr.nodes[name]['value'])