# Example #1
0
def test_sample_prior_predictive_var_names(model_fixture):
    """Check that ``var_names`` restricts which variables are returned.

    Covers three cases: requesting a subset of variables, explicitly
    requesting ``model/x`` while ``sample_from_observed=False`` (its value
    must then equal the observed data), and requesting an unknown name, which
    must raise ``ValueError`` listing the unknown names.
    """
    model, observed = model_fixture

    prior = forward_sampling.sample_prior_predictive(
        model(), var_names=["model/sd"], sample_shape=())
    assert set(prior) == {"model/sd"}

    prior = forward_sampling.sample_prior_predictive(
        model(), var_names=["model/x", "model/y"], sample_shape=())
    assert set(prior) == {"model/x", "model/y"}

    # Values of observed variables are returned when requested explicitly,
    # even with sample_from_observed=False.
    prior = forward_sampling.sample_prior_predictive(
        model(),
        var_names=["model/x", "model/y"],
        sample_shape=(),
        sample_from_observed=False)
    assert set(prior) == {"model/x", "model/y"}
    assert np.all(prior["model/x"] == observed)

    # Unknown names must be rejected. The error message embeds the model
    # object, so build it from the same instance passed to the sampler.
    model_func = model()
    expected_message = (
        "Some of the supplied var_names are not defined in the supplied model "
        "{}.\nList of unknown var_names: {}".format(model_func, ["X"]))
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        forward_sampling.sample_prior_predictive(
            model_func, var_names=["X", "model/y"], sample_shape=())
# Example #2
0
def test_sample_prior_predictive_int_sample_shape(model_fixture,
                                                  n_draws_fixture):
    """Check that an int ``sample_shape`` is equivalent to a 1-tuple.

    Sampling with ``sample_shape=n`` and ``sample_shape=(n,)`` must return
    the same set of variables, each with the same shape.
    """
    model, _ = model_fixture

    prior_int = forward_sampling.sample_prior_predictive(
        model(), sample_shape=n_draws_fixture)
    prior_tuple = forward_sampling.sample_prior_predictive(
        model(), sample_shape=(n_draws_fixture, ))

    # Compare variable sets first: the shape loop below only visits the keys
    # of prior_tuple and would not notice extra variables in prior_int.
    assert set(prior_int) == set(prior_tuple)
    for name, value in prior_tuple.items():
        assert prior_int[name].shape == value.shape, name
# Example #3
0
def test_sample_prior_predictive_on_glm(glm_model_fixture,
                                        use_auto_batching_fixture,
                                        sample_shape_fixture):
    """Sample the GLM prior predictive and verify per-variable shapes.

    Every variable in ``core_shapes`` is expected to have shape
    ``(1,) + sample_shape + core_shape``. When the model is not vectorized,
    auto-batching is off and the sample shape is non-empty, an
    ``AssertionError`` is expected instead, raised either by the sampler or
    by the shape checks.
    """
    model, is_vectorized_model, core_shapes = glm_model_fixture

    def sample_and_check_shapes():
        prior = forward_sampling.sample_prior_predictive(
            model(),
            sample_shape=sample_shape_fixture,
            use_auto_batching=use_auto_batching_fixture,
        ).prior_predictive
        for var_name, core_shape in core_shapes.items():
            # The leading (1,) is the chain axis imposed by trace_to_arviz.
            expected = (1, ) + sample_shape_fixture + core_shape
            assert prior[var_name].shape == expected

    expect_failure = (not use_auto_batching_fixture
                      and not is_vectorized_model
                      and len(sample_shape_fixture) > 0)
    if not expect_failure:
        sample_and_check_shapes()
        return

    # NOTE(review): the expected AssertionError may come from the shape
    # checks above rather than from the sampler itself; confirm intent.
    with pytest.raises(AssertionError):
        sample_and_check_shapes()
# Example #4
0
def test_sample_prior_predictive(model_fixture, sample_shape_fixture,
                                 sample_from_observed_fixture):
    """Check the variables and shapes returned by ``sample_prior_predictive``.

    With ``sample_from_observed`` enabled, ``model/x`` is included and every
    returned value has exactly ``sample_shape``. With it disabled,
    ``model/x`` is omitted, ``model/sd`` and ``model/mu`` have
    ``sample_shape``, ``model/y`` and ``model/dy`` have ``sample_shape``
    followed by the observed data's shape, ``model/y`` equals the observed
    data and ``model/dy`` equals ``2 * model/y``.
    """
    model, observed = model_fixture

    # Keyword arguments, matching the other tests, so the call does not
    # depend on the sampler's positional parameter order.
    prior = forward_sampling.sample_prior_predictive(
        model(),
        sample_shape=sample_shape_fixture,
        sample_from_observed=sample_from_observed_fixture)
    if sample_from_observed_fixture:
        assert set(prior) == {
            "model/sd", "model/x", "model/y", "model/mu", "model/dy"
        }
        assert all(value.shape == sample_shape_fixture
                   for value in prior.values())
    else:
        assert set(prior) == {"model/sd", "model/y", "model/mu", "model/dy"}
        # The key set is pinned above, so indexing these names cannot fail.
        assert all(prior[name].shape == sample_shape_fixture
                   for name in ("model/sd", "model/mu"))
        assert all(prior[name].shape == sample_shape_fixture + observed.shape
                   for name in ("model/y", "model/dy"))
        assert np.allclose(prior["model/y"], observed, rtol=1e-5)
        assert np.allclose(prior["model/y"] * 2, prior["model/dy"])