def test_sample_prior_predictive_var_names(model_fixture):
    model, observed = model_fixture

    prior = forward_sampling.sample_prior_predictive(
        model(), var_names=["model/sd"], sample_shape=()
    )
    assert set(prior) == set(["model/sd"])

    prior = forward_sampling.sample_prior_predictive(
        model(), var_names=["model/x", "model/y"], sample_shape=()
    )
    assert set(prior) == set(["model/x", "model/y"])

    # Assert we can get the values of observeds if we ask for them explicitly
    # even if sample_from_observed_fixture is False
    prior = forward_sampling.sample_prior_predictive(
        model(), var_names=["model/x", "model/y"], sample_shape=(), sample_from_observed=False
    )
    assert set(prior) == set(["model/x", "model/y"])
    assert np.all(prior["model/x"] == observed)

    # Assert an exception is raised if we pass wrong names
    model_func = model()
    expected_message = (
        "Some of the supplied var_names are not defined in the supplied model {}.\n"
        "List of unknown var_names: {}"
    ).format(model_func, ["X"])
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        prior = forward_sampling.sample_prior_predictive(
            model_func, var_names=["X", "model/y"], sample_shape=()
        )
def test_sample_prior_predictive_int_sample_shape(model_fixture, n_draws_fixture):
    model, observed = model_fixture
    prior_int = forward_sampling.sample_prior_predictive(
        model(), sample_shape=n_draws_fixture
    )
    prior_tuple = forward_sampling.sample_prior_predictive(
        model(), sample_shape=(n_draws_fixture,)
    )
    assert all(prior_int[k].shape == v.shape for k, v in prior_tuple.items())
def test_sample_prior_predictive_on_glm(
    glm_model_fixture, use_auto_batching_fixture, sample_shape_fixture
):
    model, is_vectorized_model, core_shapes = glm_model_fixture
    if (
        not use_auto_batching_fixture
        and not is_vectorized_model
        and len(sample_shape_fixture) > 0
    ):
        # Without auto batching, a non-vectorized model cannot broadcast across
        # the extra sample dimensions, so sampling is expected to fail. No draws
        # are produced here, so there are no shapes to check in this branch.
        with pytest.raises(AssertionError):
            forward_sampling.sample_prior_predictive(
                model(),
                sample_shape=sample_shape_fixture,
                use_auto_batching=use_auto_batching_fixture,
            ).prior_predictive
    else:
        prior = forward_sampling.sample_prior_predictive(
            model(),
            sample_shape=sample_shape_fixture,
            use_auto_batching=use_auto_batching_fixture,
        ).prior_predictive
        for k, v in core_shapes.items():
            # The (1,) comes from the trace_to_arviz imposed chain axis
            assert prior[k].shape == (1,) + sample_shape_fixture + v
def test_sample_prior_predictive(
    model_fixture, sample_shape_fixture, sample_from_observed_fixture
):
    model, observed = model_fixture

    prior = forward_sampling.sample_prior_predictive(
        model(), sample_shape_fixture, sample_from_observed_fixture
    )
    if sample_from_observed_fixture:
        assert set(["model/sd", "model/x", "model/y", "model/mu", "model/dy"]) == set(prior)
        assert all(value.shape == sample_shape_fixture for value in prior.values())
    else:
        assert set(["model/sd", "model/y", "model/mu", "model/dy"]) == set(prior)
        assert all(
            value.shape == sample_shape_fixture
            for name, value in prior.items()
            if name in {"model/sd", "model/mu"}
        )
        assert all(
            value.shape == sample_shape_fixture + observed.shape
            for name, value in prior.items()
            if name in {"model/x", "model/y", "model/dy"}
        )
        assert np.allclose(prior["model/y"], observed, rtol=1e-5)
        assert np.allclose(prior["model/y"] * 2, prior["model/dy"])
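

# ------------------------------------------------------------------------------
# Illustrative sketch only: the real `model_fixture` is defined elsewhere in
# this test module. The hypothetical fixture below shows one generative
# structure that is consistent with the assertions above ("model/x" is
# observed, "model/y" tracks x tightly, and "model/dy" equals 2 * y). All
# names, priors, and shapes here are assumptions added as a reading aid, not
# the project's actual fixture code.
# ------------------------------------------------------------------------------
import numpy as np  # likely duplicates a module-level import
import pymc4 as pm  # likely duplicates a module-level import


def _hypothetical_model_fixture():
    # Hypothetical stand-in for the data the tests compare against `observed`
    observed = np.random.randn(10).astype("float32")

    @pm.model
    def model():
        sd = yield pm.HalfNormal("sd", 1.0)
        mu = yield pm.Normal("mu", 0.0, 1.0)
        # Observed variable: only returned by sample_prior_predictive when
        # sample_from_observed is True or it is requested via var_names
        x = yield pm.Normal("x", mu, sd, observed=observed)
        # Tiny scale keeps y numerically close to x, matching the allclose check
        y = yield pm.Normal("y", x, 1e-9)
        dy = yield pm.Deterministic("dy", 2 * y)

    return model, observed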