def test_sample_posterior_predictive_on_glm(glm_model_fixture,
                                            use_auto_batching_fixture,
                                            sample_shape_fixture):
    model, is_vectorized_model, core_shapes = glm_model_fixture
    trace = pm.inference.utils.trace_to_arviz({
        # The transposition of the first two axes comes from trace_to_arviz,
        # which does this to the output of `sample` to get (num_chains, num_samples, ...)
        # instead of (num_samples, num_chains, ...)
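        # For example, with a (hypothetical) sample_shape_fixture == (4, 10) the
        # zeros below get leading dims (10, 4) so that the stored posterior ends
        # up with shape (4, 10) + the variable's core shape.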
        # Slicing keeps this valid when sample_shape_fixture has fewer than two dims
        k: tf.zeros(sample_shape_fixture[1:2] + sample_shape_fixture[:1] +
                    sample_shape_fixture[2:] + v)
        for k, v in core_shapes.items() if k not in ["model/y"]
    })
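    # Without auto batching, a model that was not written in vectorized form
    # cannot consume a batched trace, so any non-trivial sample shape is
    # expected to either raise (see the exception notes below) or silently
    # produce wrongly shaped draws.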
    if (not use_auto_batching_fixture and not is_vectorized_model
            and sample_shape_fixture not in [(), (1,), (1, 1)]):
        with pytest.raises(Exception):
            # This can raise many types of Exceptions.
            # For example, ValueError when tfp distributions complain about
            # the parameter shapes being incompatible, or it can raise an
            # EvaluationError because the distribution shape is not compatible
            # with the supplied observations. When run inside a @tf.function,
            # it can also raise InvalidArgumentError.
            # Furthermore, in some cases, sampling may exit without errors, but
            # the resulting shapes will be wrong
            ppc = forward_sampling.sample_posterior_predictive(
                model(),
                trace=trace,
                use_auto_batching=use_auto_batching_fixture).posterior_predictive
            for k, v in ppc.items():
                assert v.shape == sample_shape_fixture + core_shapes[k]
    else:
        ppc = forward_sampling.sample_posterior_predictive(
            model(), trace=trace,
            use_auto_batching=use_auto_batching_fixture).posterior_predictive
        for k, v in ppc.items():
            assert v.shape == sample_shape_fixture + core_shapes[k]


def test_vectorized_sample_posterior_predictive(vectorized_model_fixture,
                                                use_auto_batching_fixture,
                                                sample_shape_fixture):
    model, is_vectorized_model, core_shapes = vectorized_model_fixture
    trace = pm.inference.utils.trace_to_arviz({
        # The transposition of the first two axes comes from trace_to_arviz,
        # which does this to the output of `sample` to get (num_chains, num_samples, ...)
        # instead of (num_samples, num_chains, ...)
        k: tf.zeros(sample_shape_fixture[1:2] + sample_shape_fixture[:1] +
                    sample_shape_fixture[2:] + v)
        for k, v in core_shapes.items() if k not in ["model/x"]
    })
    if (not use_auto_batching_fixture and not is_vectorized_model
            and len(sample_shape_fixture) > 0):
        with pytest.raises((ValueError, EvaluationError)):
            # This can raise ValueError when tfp distributions complain about
            # the parameter shapes being incompatible, or it can raise an
            # EvaluationError because the distribution shape is not compatible
            # with the supplied observations.
            forward_sampling.sample_posterior_predictive(
                model(),
                trace=trace,
                use_auto_batching=use_auto_batching_fixture)
    else:
        ppc = forward_sampling.sample_posterior_predictive(
            model(), trace=trace,
            use_auto_batching=use_auto_batching_fixture).posterior_predictive
        for k, v in ppc.items():
            assert v.shape == sample_shape_fixture + core_shapes[k]


def test_sample_ppc_var_names(model_fixture):
    model, observed = model_fixture
    trace = pm.inference.utils.trace_to_arviz({
        "model/sd":
        tf.ones((10, 1), dtype="float32"),
        "model/y":
        tf.convert_to_tensor(observed[:, None]),
    })

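    # Requesting an empty list of variables should be rejected outright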
    with pytest.raises(ValueError):
        forward_sampling.sample_posterior_predictive(model(),
                                                     trace,
                                                     var_names=[])

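    # Requesting a name that does not exist in the model should raise KeyError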
    with pytest.raises(KeyError):
        forward_sampling.sample_posterior_predictive(
            model(), trace, var_names=["name not in model!"])

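    # A trace entry that does not correspond to any model variable should make
    # posterior predictive sampling fail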
    with pytest.raises(TypeError):
        trace.posterior["name not in model!"] = tf.constant(1.0)
        pm.sample_posterior_predictive(model(), trace)
    del trace.posterior["name not in model!"]

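    # A valid subset of variables: model/sd is also present in the trace, so its
    # posterior predictive samples should keep the trace's (chain, draw) shape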
    var_names = ["model/sd", "model/x", "model/dy"]
    ppc = pm.sample_posterior_predictive(
        model(), trace, var_names=var_names).posterior_predictive
    assert set(var_names) == set(ppc)
    assert ppc["model/sd"].shape == trace.posterior["model/sd"].shape
def test_sample_ppc_corrupt_trace():
    @pm.model
    def model():
        x = yield pm.Normal("x", tf.ones(5), 1, reinterpreted_batch_ndims=1)
        y = yield pm.Normal("y", x, 1)

    trace1 = pm.inference.utils.trace_to_arviz(
        {"model/x": tf.ones((7, 1), dtype="float32")})

    trace2 = pm.inference.utils.trace_to_arviz({
        "model/x":
        tf.ones((1, 5), dtype="float32"),
        "model/y":
        tf.zeros((1, 1), dtype="float32")
    })
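    # In trace2 the model/y values do not line up with the five model/x values,
    # so evaluating the model against them should fail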
    with pytest.raises(EvaluationError):
        forward_sampling.sample_posterior_predictive(model(), trace1)
    with pytest.raises(EvaluationError):
        forward_sampling.sample_posterior_predictive(model(), trace2)


def test_posterior_predictive_on_root_variable(use_auto_batching_fixture):
    n_obs = 5
    n_samples = 6
    n_chains = 4

    @pm.model
    def model():
        x = yield pm.Normal(
            "x",
            np.zeros(n_obs, dtype="float32"),
            1,
            observed=np.zeros(n_obs, dtype="float32"),
            conditionally_independent=True,
            reinterpreted_batch_ndims=1,
        )
        beta = yield pm.Normal("beta", 0, 1, conditionally_independent=True)
        bias = yield pm.Normal("bias", 0, 1, conditionally_independent=True)
        mu = beta[..., None] * x + bias[..., None]
        yield pm.Normal(
            "obs",
            mu,
            1,
            observed=np.ones(n_obs, dtype="float32"),
            reinterpreted_batch_ndims=1,
        )

    trace = pm.inference.utils.trace_to_arviz({
        "model/beta":
        tf.zeros((n_samples, n_chains), dtype="float32"),
        "model/bias":
        tf.zeros((n_samples, n_chains), dtype="float32"),
    })
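    # The trace only carries beta and bias; x (an observed root variable) and
    # obs should both be resampled with the trace's (chain, draw) dims prepended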
    ppc = forward_sampling.sample_posterior_predictive(
        model(), trace=trace,
        use_auto_batching=use_auto_batching_fixture).posterior_predictive
    if not use_auto_batching_fixture:
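        # Without auto batching, evaluate the model's posterior predictive
        # directly and check that the root variable x gets the requested
        # sample shape prepended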
        _, state = pm.evaluate_model_posterior_predictive(
            model(), sample_shape=(n_chains, n_samples))
        assert state.untransformed_values["model/x"].numpy().shape == (
            n_chains,
            n_samples,
            n_obs,
        )
    assert ppc["model/obs"].shape == (n_chains, n_samples, n_obs)
    assert ppc["model/x"].shape == (n_chains, n_samples, n_obs)