def test_none_yield():
    """Yielding ``None`` from a model is rejected by the transformed executor."""

    def model():
        # ``None`` is not a value the executor can process; the Normal below
        # is only here so the generator has a valid follow-up yield.
        yield None
        yield dist.Normal("n", 0, 1)

    with pytest.raises(pm.flow.executor.EvaluationError) as excinfo:
        pm.evaluate_model_transformed(model())
    assert excinfo.match("processed in evaluation")
def test_unable_to_create_duplicate_variable():
    """Two random variables sharing the name ``n`` are rejected by both executors."""

    def invalid_model():
        yield pm.distributions.HalfNormal("n", 1, transform=pm.distributions.transforms.Log())
        # Same name as above -> must be reported as a duplicate.
        yield pm.distributions.Normal("n", 0, 1)

    # Check the plain executor first, then the transformed one.
    for evaluate in (pm.evaluate_model, pm.evaluate_model_transformed):
        with pytest.raises(pm.flow.executor.EvaluationError) as excinfo:
            evaluate(invalid_model())
        assert excinfo.match("duplicate")
def test_transformed_executor_logp_tensorflow(transformed_model):
    """The collected log-prob matches a TFP reference for the log-transformed HalfNormal.

    The value is supplied once through the transformed name ``__log_n`` and
    once through the untransformed name ``n``. Both must yield the same
    log-probability.
    """
    # Reference distribution: HalfNormal(1) pushed through Invert(Exp()), i.e. log.
    norm_log = tfd.TransformedDistribution(tfd.HalfNormal(1), bij.Invert(bij.Exp()))
    expected_log_prob = norm_log.log_prob(-math.pi)

    equivalent_values = (
        dict(__log_n=-math.pi),
        dict(n=math.exp(-math.pi)),
    )
    for values in equivalent_values:
        _, state = pm.evaluate_model_transformed(transformed_model(), values=values)
        np.testing.assert_allclose(
            state.collect_log_prob(), expected_log_prob, equal_nan=False
        )
def test_transformed_model_transformed_executor_with_passed_value(transformed_model):
    """Supplying either representation of ``n`` fills in the other one."""
    # (values passed in, name of the derived counterpart, expected derived value)
    cases = (
        (dict(n=1.0), "__log_n", 0.0),
        (dict(__log_n=0.0), "n", 1.0),
    )
    for values, derived_name, derived_value in cases:
        _, state = pm.evaluate_model_transformed(transformed_model(), values=values)
        assert set(state.all_values) == {"n", "__log_n"}
        assert set(state.transformed_values) == {"__log_n"}
        assert set(state.untransformed_values) == {"n"}
        np.testing.assert_allclose(state.all_values[derived_name], derived_value)
def test_unnamed_return_2():
    """A model declared with ``name=None`` must be given a name at call time."""

    @pm.model(name=None)
    def a_model():
        return (yield pm.HalfNormal("n", 1, transform=pm.distributions.transforms.Log()))

    # With a call-time name, the model's return value is recorded under it.
    _, state = pm.evaluate_model(a_model(name="b_model"))
    assert "b_model" in state.deterministics_values

    # Without a name, both executors refuse the model.
    for evaluate in (pm.evaluate_model, pm.evaluate_model_transformed):
        with pytest.raises(pm.flow.executor.EvaluationError) as excinfo:
            evaluate(a_model())
        assert excinfo.match("unnamed")
def test_observed_do_not_produce_transformed_values_case_override(
    transformed_model_with_observed,
):
    """Overriding an observed variable with ``None`` makes it free again."""
    model = transformed_model_with_observed()
    _, state = pm.evaluate_model_transformed(model, observed={"n": None})

    # Nothing is observed any more, so ``n`` is sampled in both representations.
    assert not state.observed_values
    assert set(state.transformed_values) == {"__log_n"}
    assert set(state.untransformed_values) == {"n"}
def test_observed_do_not_produce_transformed_values_case_programmatic(
    transformed_model,
):
    """Marking ``n`` as observed at call time suppresses its transformed value."""
    model = transformed_model()
    _, state = pm.evaluate_model_transformed(model, observed={"n": 1.0})

    # ``n`` is now observed, so neither representation is sampled.
    assert set(state.observed_values) == {"n"}
    assert not state.transformed_values
    assert not state.untransformed_values
def test_observed_cant_mix_with_transformed_and_raises_an_error(
    transformed_model_with_observed,
):
    """Passing a transformed value for an observed variable raises EvaluationError.

    The error message must mention both the observed mapping ``{'n': None}``
    and the offending ``'__log_n'`` transformed value.
    """
    import re

    with pytest.raises(pm.flow.executor.EvaluationError) as excinfo:
        pm.evaluate_model_transformed(
            transformed_model_with_observed(), values=dict(__log_n=0.0)
        )
    # ``ExceptionInfo.match`` performs ``re.search``; escape the expected text
    # so characters such as ``{`` and ``}`` are matched literally.
    assert excinfo.match(re.escape("{'n': None}"))
    assert excinfo.match(re.escape("'__log_n' from transformed values"))
def test_uncatched_exception_works():
    """A model that swallows the executor's error still ends in StopExecution.

    ``yield 1`` is not a valid model statement. The model below catches
    whatever is raised at that point and keeps going. Both executors must
    then abort with ``StopExecution`` mentioning ``something_bad``.
    """

    @pm.model
    def a_model():
        try:
            # Invalid yield; presumably the executor throws its error back
            # into the generator here (TODO confirm against the executor).
            yield 1
        except Exception:
            # Deliberately swallow the error: this is the misbehaving model
            # under test. ``Exception`` (not a bare ``except``) keeps
            # ``GeneratorExit`` from being swallowed if the generator is closed.
            pass
        yield pm.distributions.HalfNormal("n", 1, transform=pm.distributions.transforms.Log())

    for evaluate in (pm.evaluate_model, pm.evaluate_model_transformed):
        with pytest.raises(pm.flow.executor.StopExecution) as excinfo:
            evaluate(a_model())
        assert excinfo.match("something_bad")
def test_as_sampling_state_works_if_transformed_exec(
    complex_model_with_observed,
):
    """``as_sampling_state`` keeps observed values and drops transformed ones."""
    model = complex_model_with_observed()
    _, state = pm.evaluate_model_transformed(model)
    sampling_state = state.as_sampling_state()

    assert not sampling_state.transformed_values
    assert set(sampling_state.observed_values) == {"complex_model/a/n"}
    assert set(sampling_state.untransformed_values) == {"complex_model/n"}
def test_as_sampling_state_works_observed_is_set_to_none(complex_model_with_observed):
    """Un-observing a nested variable exposes its transformed value for sampling."""
    overrides = {"complex_model/a/n": None}
    _, state = pm.evaluate_model_transformed(
        complex_model_with_observed(), observed=overrides
    )
    sampling_state = state.as_sampling_state()

    # The formerly observed variable now appears through its log transform.
    assert set(sampling_state.transformed_values) == {"complex_model/a/__log_n"}
    assert not sampling_state.observed_values
    assert set(sampling_state.untransformed_values) == {"complex_model/n"}
def test_observed_do_not_produce_transformed_values_case_override_with_set_value(
    transformed_model_with_observed,
):
    """An un-observed variable accepts a value in either representation."""
    # (values passed in, name of the derived counterpart, expected derived value)
    cases = (
        (dict(n=1.0), "__log_n", 0.0),
        (dict(__log_n=0.0), "n", 1.0),
    )
    for values, derived_name, derived_value in cases:
        _, state = pm.evaluate_model_transformed(
            transformed_model_with_observed(), values=values, observed=dict(n=None)
        )
        assert not state.observed_values
        assert set(state.transformed_values) == {"__log_n"}
        assert set(state.untransformed_values) == {"n"}
        np.testing.assert_allclose(state.all_values[derived_name], derived_value)
def test_transformed_model_transformed_executor(transformed_model):
    """The transformed executor samples ``__log_n`` and derives ``n = exp(__log_n)``."""
    _, state = pm.evaluate_model_transformed(transformed_model())
    assert set(state.all_values) == {"n", "__log_n"}
    assert set(state.transformed_values) == {"__log_n"}
    assert set(state.untransformed_values) == {"n"}
    assert not state.observed_values
    # Use np.testing for a useful failure message (as elsewhere in this file).
    # The tolerances reproduce np.allclose's defaults, which this check used before.
    np.testing.assert_allclose(
        state.untransformed_values["n"],
        np.exp(state.transformed_values["__log_n"]),
        rtol=1e-5,
        atol=1e-8,
    )
def test_deterministics_in_nested_model(deterministics_in_nested_models):
    """Deterministics in nested models are collected and computed from their inputs.

    ``deterministic_mapping`` maps each deterministic name to a pair
    ``(input_names, op)``. The stored deterministic must equal ``op`` applied
    to the untransformed values of ``input_names``.
    """
    model = deterministics_in_nested_models[0]
    expected_untransformed = deterministics_in_nested_models[1]
    expected_transformed = deterministics_in_nested_models[2]
    expected_deterministics = deterministics_in_nested_models[3]
    deterministic_mapping = deterministics_in_nested_models[4]

    _, state = pm.evaluate_model_transformed(model())
    assert set(state.untransformed_values) == expected_untransformed
    assert set(state.transformed_values) == expected_transformed
    assert set(state.deterministics) == expected_deterministics

    for det_name, (input_names, op) in deterministic_mapping.items():
        actual = state.deterministics[det_name]
        inputs = [state.untransformed_values[name] for name in input_names]
        np.testing.assert_allclose(actual, op(*inputs))
def test_posterior_predictive_executor(model_with_observed_fixture):
    """Compare prior evaluation with posterior-predictive evaluation of the same model.

    A normal (transformed) evaluation must keep every observed variable at its
    observed value. A posterior-predictive evaluation must have no observed
    values and instead produce samples whose shapes match ``core_ppc_shapes``.
    The samples must differ from the original observations.
    """
    model, observed, core_ppc_shapes, _ = model_with_observed_fixture
    _, prior_state = pm.evaluate_model_transformed(model(), observed=observed)
    _, ppc_state = pm.evaluate_model_posterior_predictive(model(), observed=observed)

    # A normal evaluation has all observeds, and their values match the observations.
    assert len(prior_state.observed_values) == 3
    for var, val in observed.items():
        assert np.all(prior_state.all_values[var] == val)

    # A posterior-predictive evaluation has no observed values, but the
    # sample shapes match the supplied shapes.
    assert len(ppc_state.observed_values) == 0
    # Look up sampled variables first, then deterministics (built once, outside the loop).
    ppc_values = collections.ChainMap(ppc_state.all_values, ppc_state.deterministics)
    for var, shape in core_ppc_shapes.items():
        assert ppc_values[var].numpy().shape == shape
        if var in observed:
            # Compare against this variable's own observation. The previous code
            # reused ``val`` leaked from the loop above, i.e. the *last* observation.
            assert np.any(ppc_state.all_values[var] != observed[var])
# Test shapes, should be all 3: def print_dist_shapes(st): for name, dist in itertools.chain( st.discrete_distributions.items(), st.continuous_distributions.items(), ): if dist.log_prob(st.all_values[name]).shape != (3, ): log.warning( f"False shape: {dist.log_prob(st.all_values[name]).shape}, {name}" ) for p in st.potentials: if p.value.shape != (3, ): log.warning(f"False shape: {p.value.shape} {p.name}") _, sample_state = pm.evaluate_model_transformed(this_model, sample_shape=(3, )) print_dist_shapes(sample_state) """ # 2. MCMC Sampling """ begin_time = time.time() log.info("start") num_chains = 3 trace_tuning, trace = pm.sample( this_model, num_samples=60, num_samples_binning=10, burn_in_min=10, burn_in=100, use_auto_batching=False,
def test_observed_do_not_produce_transformed_values(
    transformed_model_with_observed,
):
    """A variable observed in the model itself yields no (un)transformed values."""
    model = transformed_model_with_observed()
    _, state = pm.evaluate_model_transformed(model)

    assert set(state.observed_values) == {"n"}
    # Observed variables are not sampled, in either representation.
    assert not state.transformed_values
    assert not state.untransformed_values