示例#1
0
    def sample_from_ppd(rng_key):
        """Draw one parameter vector from the guide, then data records from the model.

        ``rng_key`` is split in two. The first half drives a single draw from
        ``guide`` (parameterised by ``posterior_params``); the second half is
        split into ``num_record_samples_per_parameter_sample`` keys, and the
        model is evaluated once per key with the drawn parameters held fixed.

        Returns:
            dict mapping site names to the arrays sampled from the model.
        """
        param_key, record_key = jax.random.split(rng_key)

        # Single draw from the guide (num_samples=1); axis 0 of every returned
        # array is then removed.
        guide_draw = Predictive(
            guide, params=posterior_params, num_samples=1)(param_key)
        parameter_sample = {
            site: value.squeeze(0)
            for site, value in guide_draw.items()
        }

        # One key per record; the model sampler is mapped over the keys.
        record_sampler = Predictive(model, parameter_sample, batch_ndims=0)
        record_keys = jax.random.split(
            record_key, num_record_samples_per_parameter_sample)
        stacked_records = jax.vmap(record_sampler)(record_keys)

        # Axis 1 of every sampled array is a superfluous batch axis; drop it.
        return {
            site: value.squeeze(1)
            for site, value in stacked_records.items()
        }
示例#2
0
    def _predict(self, home_team, away_team, dates, num_samples=100, seed=42):
        """Sample home and away goals for the given fixtures.

        ``home_team`` and ``away_team`` may each be a single team name or a
        list of names; a bare string is wrapped in a one-element list. Any team
        not yet in ``self.team_to_index`` is registered with the next free
        index (updating ``index_to_team`` and ``n_teams`` as well) before
        sampling.

        Returns:
            tuple of the "home_goals" and "away_goals" samples.
        """
        sampler = Predictive(
            self.model,
            num_samples=num_samples,
            posterior_samples=self.samples,
            return_sites=("home_goals", "away_goals"),
        )

        if isinstance(home_team, str):
            home_team = [home_team]
        if isinstance(away_team, str):
            away_team = [away_team]

        # Register teams that were not seen before under fresh indices.
        unseen_teams = set(home_team + away_team) - set(
            self.team_to_index.keys())
        for name in unseen_teams:
            fresh_index = max(self.team_to_index.values()) + 1
            self.team_to_index[name] = fresh_index
            self.index_to_team[fresh_index] = name
            self.n_teams += 1

        # Whole weeks elapsed since the reference date.
        gameweek = (dates - self.min_date).dt.days // 7

        draws = predictive_draws = sampler.get_samples(
            random.PRNGKey(seed), home_team, away_team, gameweek)

        return draws["home_goals"], predictive_draws["away_goals"]
示例#3
0
def main() -> None:
    """Fit the non-centered model with NUTS and report predictive means.

    Runs MCMC on the hard-coded data, prints the sampler summary and the mean
    of the negated potential energy, then prints the mean of the "obs" site
    sampled from ``model_pred`` under the prior and under the posterior.
    """

    # Data
    num = 8
    y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

    # Random key
    rng_key = random.PRNGKey(0)

    # Inference
    nuts_kernel = NUTS(model_noncentered)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
    mcmc.run(rng_key, num, sigma, y=y, extra_fields=("potential_energy", ))
    # print_summary() writes the table itself and returns None, so wrapping it
    # in print() would emit a stray "None" line after the summary.
    mcmc.print_summary()

    # Extra
    pe = mcmc.get_extra_fields()["potential_energy"]
    print(f"Expected log joint density: {np.mean(-pe):.2f}")

    # Prediction: without posterior samples the sites are drawn from the prior.
    predictive = Predictive(model_pred, num_samples=100)
    samples = predictive(random.PRNGKey(1))
    print("prior", np.mean(samples["obs"]))

    predictive = Predictive(model_pred, mcmc.get_samples())
    samples = predictive(random.PRNGKey(1))
    print("posterior", np.mean(samples["obs"]))
    def predictive(self,
                   n,
                   nu=None,
                   nu_err=None,
                   n_pred=None,
                   **kwargs) -> dict:
        """Sample from the model through ``Predictive``.

        Args:
            n: Passed positionally to the model.
            nu: Passed to the model as keyword ``nu``.
            nu_err: Passed to the model as keyword ``nu_err``.
            n_pred: Passed to the model as keyword ``n_pred``.
            **kwargs: ``posterior_samples``, ``num_samples``, ``batch_ndims``
                (default 2) and ``return_sites`` are taken out and handed to
                ``Predictive`` explicitly; everything else is forwarded to
                ``Predictive`` unchanged.

        Returns:
            dict: Whatever the ``Predictive`` call returns.
        """
        posterior_samples = kwargs.pop("posterior_samples", None)
        num_samples = kwargs.pop("num_samples", None)
        batch_ndims = kwargs.pop("batch_ndims", 2)
        return_sites = kwargs.pop("return_sites", None)

        known_sites = posterior_samples if posterior_samples is not None else {}
        if return_sites is None:
            # Default selection: every deterministic site, plus every sample
            # site that is neither observed nor already supplied by the caller.
            return_sites = [
                name
                for name, site in self.get_trace(pred=True).items()
                if (site["type"] == "sample" and not site["is_observed"]
                    and name not in known_sites)
                or site["type"] == "deterministic"
            ]

        sampler = Predictive(
            self.model,
            posterior_samples=posterior_samples,
            num_samples=num_samples,
            return_sites=return_sites,
            batch_ndims=batch_ndims,
            **kwargs,
        )

        if sampler.batch_ndims == 0:
            # Fix bug in Predictive for computing batch shape
            sampler._batch_shape = ()

        # Consume one key for this call and keep the other for later calls.
        rng_key, self._rng_key = random.split(self._rng_key)
        return sampler(rng_key, n, nu=nu, nu_err=nu_err, n_pred=n_pred)
示例#5
0
文件: handler.py 项目: sagar87/numgp
    def fit(self, *args, **kwargs):
        """Run the optimisation, then draw posterior samples.

        ``num_epochs`` and ``log_freq`` may be overridden through ``kwargs``;
        any remaining keyword arguments are forwarded to ``Predictive``.
        Positional arguments go to ``svi.init``, ``_fit`` and the final
        ``Predictive`` call. The results are stored on ``self.params`` and
        ``self.posterior``.
        """
        num_epochs = kwargs.pop("num_epochs", self.num_epochs)
        log_freq = kwargs.pop("log_freq", self.log_freq)

        if self.init_state is None:
            self.init_state = self.svi.init(self.rng_key, *args)

        if log_freq > 0:
            # Train in chunks of log_freq epochs, logging the last loss of each
            # chunk, then run whatever is left over without logging.
            full_chunks, leftover = divmod(num_epochs, log_freq)

            for chunk in range(full_chunks):
                state, loss = self._fit(log_freq, *args)
                self._log(log_freq * (chunk + 1), loss[-1])
                self._update_state(state, loss)

            if leftover > 0:
                state, loss = self._fit(leftover, *args)
                self._update_state(state, loss)
        else:
            # No logging requested: a single uninterrupted run.
            state, loss = self._fit(num_epochs, *args)
            self._update_state(state, loss)

        # NOTE(review): with log_freq > 0 and num_epochs == 0 neither branch
        # above assigns `state`, so the next line raises.
        self.params = self.svi.get_params(state)
        sampler = Predictive(
            self.model,
            guide=self.guide,
            params=self.params,
            num_samples=self.num_samples,
            **kwargs,
        )
        self.posterior = sampler(self.rng_key, *args)
示例#6
0
def main(args):
    """Run NUTS on the selected annotation model and print class histograms.

    The model is looked up by ``args.model``. After inference, the discrete
    site "c" is sampled with ``infer_discrete=True`` and, for each item, the
    number of draws falling in each of the four classes is printed.
    """
    annotators, annotations = get_data()
    model = NAME_TO_MODEL[args.model]
    # These two models take the annotations only; the rest also take annotators.
    if model in [multinomial, item_difficulty]:
        data = (annotations, )
    else:
        data = (annotators, annotations)

    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    mcmc = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    mcmc.run(random.PRNGKey(0), *data)
    mcmc.print_summary()

    discrete_sampler = Predictive(
        model, mcmc.get_samples(), infer_discrete=True)
    discrete_samples = discrete_sampler(random.PRNGKey(1), *data)

    def count_classes(draws):
        return jnp.bincount(draws, length=4)

    class_draws = discrete_samples["c"].squeeze(-1)
    item_class = vmap(count_classes, in_axes=1)(class_draws)

    print("Histogram of the predicted class of each item:")
    row_format = "{:>10}" * 5
    column_titles = ["c={}".format(i) for i in range(4)]
    print(row_format.format("", *column_titles))
    for item_index, counts in enumerate(item_class):
        print(row_format.format(f"item[{item_index}]", *counts))
示例#7
0
def main(args):
    """Compare centered and non-centered parameterizations of the funnel model.

    Runs inference on ``model`` and on ``reparam_model``, recovers the
    deterministic sites 'x' and 'y' for the latter, and saves a two-panel
    scatter plot of x[0] against y to ``funnel_plot.pdf``.
    """
    rng_key = random.PRNGKey(0)

    # do inference with centered parameterization
    print("============================= Centered Parameterization ==============================")
    samples = run_inference(model, args, rng_key)

    # do inference with non-centered parameterization
    print("\n=========================== Non-centered Parameterization ============================")
    reparam_samples = run_inference(reparam_model, args, rng_key)
    # collect deterministic sites
    reparam_samples = Predictive(reparam_model, reparam_samples, return_sites=['x', 'y'])(
        random.PRNGKey(1))

    # make plots
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 8))

    ax1.plot(samples['x'][:, 0], samples['y'], "go", alpha=0.3)
    ax1.set(xlim=(-20, 20), ylim=(-9, 9), ylabel='y',
            title='Funnel samples with centered parameterization')

    ax2.plot(reparam_samples['x'][:, 0], reparam_samples['y'], "go", alpha=0.3)
    ax2.set(xlim=(-20, 20), ylim=(-9, 9), xlabel='x[0]', ylabel='y',
            title='Funnel samples with non-centered parameterization')

    # Adjust the layout before writing the file; doing it afterwards leaves the
    # saved PDF with the unadjusted spacing.
    plt.tight_layout()
    plt.savefig('funnel_plot.pdf')
示例#8
0
    def conditional_from_guide(self, guide, params, *args, **kwargs):
        """Sample the sites needed for the conditional and build it.

        ``pred_noise`` and ``diag`` (both default ``False``) are taken from
        ``kwargs`` and handed to ``_build_conditional``. The remaining
        arguments go to ``_get_var_names``; the positional ones are also passed
        to the model. The sampled sites are kept on ``self.cond_params``.

        Returns:
            tuple ``(mu, var)`` as produced by ``_build_conditional``.
        """
        pred_noise = kwargs.pop("pred_noise", False)
        diag = kwargs.pop("diag", False)
        self._get_var_names(*args, **kwargs)

        wanted_sites = (
            self.gp,
            self.mean,
            self.cond,
            self.Kss,
            self.Kns,
            self.Ksx,
            self.Kxx,
            self.Knx,
            self.y,
        )
        sampler = Predictive(
            self.model,
            guide=guide,
            params=params,
            num_samples=self.num_samples,
            return_sites=wanted_sites,
        )

        self.cond_params = sampler(PRNGKey(self.rng_key), *args)
        cond_mean, cond_var = self._build_conditional(
            self.cond_params, pred_noise, diag)
        return cond_mean, cond_var
示例#9
0
def main(args):
    """Fit the model to the lynx-hare data and plot the posterior predictive.

    Runs NUTS on the log of the population data, samples site "y" from the
    posterior predictive, and saves a plot of the observations, the predicted
    mean and the 10th-90th percentile band to ``ode_plot.pdf``.
    """
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(NUTS(model, dense_mass=True),
                args.num_warmup, args.num_samples, num_chains=args.num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data))
    mcmc.print_summary()

    # predict populations; the model was fitted on log(data), so exponentiate
    y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    pop_pred = jnp.exp(y_pred)
    mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    # Adjust the layout before writing the file; doing it afterwards leaves the
    # saved PDF with the unadjusted spacing.
    plt.tight_layout()
    plt.savefig("ode_plot.pdf")
示例#10
0
def main(args):
    """Fit ``ss_model`` per amino acid and plot observed vs. predicted angles.

    For every entry of ``args.amino_acids`` the dihedral data is fetched,
    k-means cluster centres are used as initial locations for the sampler, and
    the "phi_psi" site is sampled from the posterior predictive. All results
    are then passed to ``ramachandran_plot``.
    """
    observed = {}
    predicted = {}
    rng_key = random.PRNGKey(args.rng_seed)

    for aa in args.amino_acids:
        rng_key, inf_key, pred_key = random.split(rng_key, 3)
        dihedrals = fetch_aa_dihedrals(aa)
        observed[aa] = dihedrals
        num_mix_comp = num_mix_comps(aa)

        # Use kmeans to initialize the chain location.
        kmeans = KMeans(num_mix_comp)
        kmeans.fit(dihedrals)
        centers = kmeans.cluster_centers_
        means = {"phi_loc": centers[:, 0], "psi_loc": centers[:, 1]}

        ss_samples = run_hmc(inf_key, ss_model, dihedrals, num_mix_comp, args,
                             means)
        sampler = Predictive(ss_model, ss_samples, parallel=True)

        draws = sampler(pred_key, None, 1, num_mix_comp)
        # Flatten to one (phi, psi) pair per row.
        predicted[aa] = draws["phi_psi"].reshape(-1, 2)

    ramachandran_plot(observed, predicted, args.amino_acids)
示例#11
0
def main(args):
    """Fit ``glmm`` on the UCB admissions training split and plot the check.

    After inference, the "probs" site is sampled from the posterior predictive,
    a results table is printed, and a plot of the actual admit rate against the
    predicted mean, standard deviation and 5th/95th percentiles is saved to
    ``ucbadmit_plot.pdf``.
    """
    _, fetch_train = load_dataset(UCBADMIT, split="train", shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)

    sampler = Predictive(glmm, zs)
    pred_probs = sampler(rng_key_predict, dept, male, applications)["probs"]

    header = "=" * 30 + "glmm - TRAIN" + "=" * 30
    print_results(header, pred_probs, dept, male, admit / applications)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    cases = range(1, 13)

    ax.plot(cases, admit / applications, "o", ms=7, label="actual rate")
    ax.errorbar(
        cases,
        jnp.mean(pred_probs, 0),
        jnp.std(pred_probs, 0),
        fmt="o",
        c="k",
        mfc="none",
        ms=7,
        elinewidth=1,
        label=r"mean $\pm$ std",
    )
    # Mark both ends of the 90% interval.
    for percentile in (5, 95):
        ax.plot(cases, jnp.percentile(pred_probs, percentile, 0), "k+")
    ax.set(
        xlabel="cases",
        ylabel="admit rate",
        title="Posterior Predictive Check with 90% CI",
    )
    ax.legend()

    plt.savefig("ucbadmit_plot.pdf")
示例#12
0
 def svi_predict(model, guide, params, args, X):
     """Return the rounded mean of site "Y" sampled through the fitted guide.

     ``args.num_samples`` draws are taken with ``Y=None`` passed to the model;
     they are averaged over axis 0 and rounded to the nearest integer.
     """
     sampler = Predictive(model=model,
                          guide=guide,
                          params=params,
                          num_samples=args.num_samples)
     draws = sampler(PRNGKey(1), X=X, Y=None)
     mean_prediction = draws["Y"].mean(0)
     return jnp.rint(mean_prediction)
示例#13
0
文件: base.py 项目: gcgibson/covid
 def prior(self, num_samples=1000, rng_key=PRNGKey(2), **args):
     """Sample from the model with no posterior samples supplied.

     Keyword arguments are merged over ``self.args`` (passed values win) and
     forwarded to the model. The result is stored on ``self.prior_samples``
     and returned.
     """
     sampler = Predictive(self, posterior_samples={}, num_samples=num_samples)

     # passed args take precedence
     merged_args = dict(self.args, **args)
     self.prior_samples = sampler(rng_key, **merged_args)

     return self.prior_samples
示例#14
0
文件: base.py 项目: elray1/covid
    def forecast(self, num_samples=1000, rng_key=PRNGKey(4), **args):
        """Sample from the model using the stored MCMC samples and observations.

        Keyword arguments are merged over ``self.args`` (passed values win) and
        forwarded to the model together with ``self.obs``.

        Raises:
            RuntimeError: if ``self.mcmc_samples`` is ``None``.
        """
        if self.mcmc_samples is None:
            raise RuntimeError("run inference first")

        sampler = Predictive(self, posterior_samples=self.mcmc_samples)

        merged_args = dict(self.args, **args)
        return sampler(rng_key, **self.obs, **merged_args)
示例#15
0
def predict(
    model: Callable,
    at_bats: jnp.ndarray,
    posterior_samples: jnp.ndarray,
    rng_key: jnp.ndarray,
) -> Dict[str, jnp.ndarray]:
    """Sample from ``model`` given ``posterior_samples``.

    ``at_bats`` is passed to the model as its only positional argument.
    """
    sampler = Predictive(model, posterior_samples=posterior_samples)
    draws = sampler(rng_key, at_bats)
    return draws
示例#16
0
    def predictive(self, rng_key=PRNGKey(3), **args):
        """Draw samples from the in-sample predictive distribution.

        Keyword arguments are merged over ``self.args`` (passed values win) and
        forwarded to the model.

        Raises:
            RuntimeError: if ``self.mcmc_samples`` is ``None``.
        """
        if self.mcmc_samples is None:
            raise RuntimeError("run inference first")

        sampler = Predictive(self, posterior_samples=self.mcmc_samples)

        merged_args = dict(self.args, **args)
        return sampler(rng_key, **merged_args)
示例#17
0
 def _init_tau(self, rng_key, tau_prior, num_samples=5000):
     """Build two distributions from moments of samples of ``tau_prior``.

     Samples the "log_tau" site, shifts it by 6, and passes the mean and the
     sample standard deviation (ddof=1) along axis 0 of the first two
     components to ``distribution``.

     Returns:
         tuple of two distributions, for tau_he and tau_cz respectively.
     """
     draws = Predictive(tau_prior, num_samples=num_samples)(rng_key)
     # Convert from seconds to mega seconds
     shifted_log_tau = draws["log_tau"] - 6
     loc = shifted_log_tau.mean(axis=0)
     scale = shifted_log_tau.std(axis=0, ddof=1)

     tau_he = distribution((loc[0], scale[0]))
     tau_cz = distribution((loc[1], scale[1]))
     return (tau_he, tau_cz)
示例#18
0
    def transform(self, views: Iterable[np.ndarray], y=None, **kwargs):
        """
        Sample the latent site "z" for the data in views using the stored posterior samples

        :param views: list/tuple of numpy arrays or array likes with the same number of rows (samples)
        :param y: not used by this method
        :param kwargs: not used by this method
        """
        check_is_fitted(self, attributes=["posterior_samples"])
        latent_sampler = Predictive(
            self._model, self.posterior_samples, return_sites=["z"]
        )
        draws = latent_sampler(self.rng_key, views)
        return draws["z"]
示例#19
0
    def load_numpyro_divorce():
        """Load saved samples from JSON and attach a ``Predictive`` to the context.

        Reads ``numpyro-divorce.json`` from the model's local folder, converts
        every value to a numpy array, and stores the resulting ``Predictive``
        on ``numpyro_divorce.context.predictive_dist``.
        """
        model_uri = os.path.join(numpyro_divorce.details.local_folder, "numpyro-divorce.json")

        with open(model_uri) as model_file:
            raw_samples = json.load(model_file)

        samples = {name: np.array(values) for name, values in raw_samples.items()}

        numpyro_divorce.context.predictive_dist = Predictive(model_function, samples)
示例#20
0
 def get_inference_data(self, data, eight_schools_params):
     """Assemble inference data from the fitted sampler in ``data.obj``.

     Posterior-predictive samples are drawn with the sampler's posterior
     samples, and 500 prior samples are drawn from the same model; both are
     passed to ``from_numpyro`` together with school coordinates and dims.
     """
     num_schools = eight_schools_params["J"]
     sigma = eight_schools_params["sigma"]

     posterior_samples = data.obj.get_samples()
     model = data.obj.sampler.model

     posterior_predictive = Predictive(model, posterior_samples)(
         PRNGKey(1), num_schools, sigma)
     prior = Predictive(model, num_samples=500)(
         PRNGKey(2), num_schools, sigma)

     school_dims = {"theta": ["school"], "eta": ["school"]}
     return from_numpyro(
         posterior=data.obj,
         prior=prior,
         posterior_predictive=posterior_predictive,
         coords={"school": np.arange(num_schools)},
         dims=school_dims,
     )
示例#21
0
文件: handler.py 项目: sagar87/numgp
    def get_posterior_predictive(self, *args, **kwargs):
        """Sample through the fitted guide and store the result.

        Keyword arguments go to ``Predictive`` (``num_samples`` defaults to
        ``self.num_samples``); positional arguments go to the model. The
        samples are stored on ``self.posterior_predictive``.
        """
        sample_count = kwargs.pop("num_samples", self.num_samples)

        sampler = Predictive(
            self.model,
            guide=self.guide,
            params=self.params,
            num_samples=sample_count,
            **kwargs,
        )
        self.posterior_predictive = sampler(self.rng_key, *args)
示例#22
0
def main(args):
    """Train the model with SteinVI and report timing and test RMSE.

    Training inputs and labels are normalised, SteinVI is run for
    ``args.max_iter`` iterations, and predictions on the test inputs
    (normalised with the training statistics) are mapped back to the label
    scale before the RMSE is computed.
    """
    data = load_data()

    inf_key, pred_key, data_key = random.split(random.PRNGKey(args.rng_key), 3)
    # normalize data and labels to zero mean unit variance!
    x, xtr_mean, xtr_std = normalize(data.xtr)
    y, ytr_mean, ytr_std = normalize(data.ytr)

    rng_key, inf_key = random.split(inf_key)

    guide = AutoDelta(model, init_loc_fn=partial(init_to_uniform, radius=0.1))
    optimizer = Adagrad(0.05)
    # estimate elbo with 20 particles (not stein particles!)
    loss = Trace_ELBO(20)
    kernel = RBFKernel()
    stein = SteinVI(
        model,
        guide,
        optimizer,
        loss,
        kernel,
        repulsion_temperature=args.repulsion,
        num_particles=args.num_particles,
    )

    start = time()
    # use keyword params for static (shape etc.)!
    result = stein.run(
        rng_key,
        args.max_iter,
        x,
        y,
        hidden_dim=args.hidden_dim,
        subsample_size=args.subsample_size,
        progress_bar=args.progress_bar,
    )
    time_taken = time() - start

    pred = Predictive(
        model,
        guide=stein.guide,
        params=stein.get_params(result.state),
        num_samples=1,
        batch_ndims=1,  # stein particle dimension
    )
    # use train data statistics when accessing generalization
    xte, _, _ = normalize(data.xte, xtr_mean, xtr_std)
    num_test = xte.shape[0]
    draws = pred(pred_key, xte, subsample_size=num_test)
    preds = draws["y"].reshape(-1, num_test)

    # Undo the label normalisation before scoring.
    y_pred = jnp.mean(preds, 0) * ytr_std + ytr_mean
    squared_error = (y_pred - data.yte)**2
    rmse = jnp.sqrt(jnp.mean(squared_error))

    print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}")
    print(rf"RMSE: {rmse:.2f}")
示例#23
0
    async def load(self) -> bool:
        """Load saved samples from the configured URI and build the predictor.

        The JSON file is read, every value is converted to a numpy array and
        stored on ``self._samples``, and a ``Predictive`` over ``self._model``
        is stored on ``self._predictive``.

        Returns:
            ``True`` once loading has completed (also stored on ``self.ready``).
        """
        model_uri = self._settings.parameters.uri
        with open(model_uri) as model_file:
            raw_samples = json.load(model_file)

        self._samples = {}
        self._samples.update(
            (name, np.array(values)) for name, values in raw_samples.items()
        )

        self._predictive = Predictive(self._model, self._samples)

        self.ready = True
        return self.ready
示例#24
0
def main() -> None:
    """Fit the model with NUTS and print several posterior summaries.

    Prints the MCMC summary, the mean and 90% HPDI of the regression mean
    computed from sites "a" and "bM", the head of the data frame after adding
    mean predictions, the mean of predictions made through ``predict``, and the
    log posterior predictive density.
    """
    df = load_dataset()
    marriage = df["MarriageScaled"].values
    divorce = df["DivorceScaled"].values

    rng_key = random.PRNGKey(0)
    rng_key, rng_key_ = random.split(rng_key)

    # Inference posterior
    mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2000)
    mcmc.run(rng_key_, marriage=marriage, divorce=divorce)
    mcmc.print_summary()
    samples_1 = mcmc.get_samples()

    # Compute empirical posterior distribution
    intercept = jnp.expand_dims(samples_1["a"], -1)
    slope = jnp.expand_dims(samples_1["bM"], -1)
    posterior_mu = intercept + slope * marriage

    mean_mu = jnp.mean(posterior_mu, axis=0)
    hpdi_mu = hpdi(posterior_mu, 0.9)
    print(mean_mu, hpdi_mu)

    # Posterior predictive distribution
    rng_key, rng_key_ = random.split(rng_key)
    sampler = Predictive(model, samples_1)
    predictions = sampler(rng_key_, marriage=marriage)["obs"]
    df["MeanPredictions"] = jnp.mean(predictions, axis=0)
    print(df.head())

    # Predictive utility with effect handlers
    def predict_one(key, sample):
        return predict(key, sample, model, marriage=df["MarriageScaled"].values)

    predict_fn = vmap(predict_one)
    # One key per posterior sample.
    predictions_1 = predict_fn(random.split(rng_key_, 2000), samples_1)
    mean_pred = jnp.mean(predictions_1, axis=0)
    print(mean_pred)

    # Posterior predictive density
    rng_key, rng_key_ = random.split(rng_key)
    lpp_dns = log_pred_density(
        rng_key_,
        samples_1,
        model,
        marriage=marriage,
        divorce=divorce,
    )
    print("Log posterior predictive density", lpp_dns)
示例#25
0
文件: baseball.py 项目: ucals/numpyro
def predict(model, at_bats, hits, z, rng_key, player_names, train=True):
    """Print predictions of site 'obs' and, for test data, the log density.

    Samples 'obs' from ``model`` given the posterior samples ``z`` and prints a
    results table under a TRAIN/TEST banner. When ``train`` is false, the log
    pointwise predictive density summed over all data points is printed too.
    """
    suffix = ' - TRAIN' if train else ' - TEST'
    banner = '=' * 30 + model.__name__ + suffix + '=' * 30

    sampler = Predictive(model, posterior_samples=z)
    predictions = sampler(rng_key, at_bats)['obs']
    print_results(banner, predictions, player_names, at_bats, hits)

    if train:
        return

    post_loglik = log_likelihood(model, z, at_bats, hits)['obs']
    # computes expected log predictive density at each data point
    num_draws = jnp.shape(post_loglik)[0]
    exp_log_density = logsumexp(post_loglik, axis=0) - jnp.log(num_draws)
    # reports log predictive density of all test points
    print('\nLog pointwise predictive density: {:.2f}\n'.format(exp_log_density.sum()))
示例#26
0
    def predict(self, *args, **kwargs):
        """Sample through the fitted guide and wrap the result in ``Posterior``.

        Keyword arguments go to ``Predictive``, except ``num_samples``
        (default ``self.num_samples``) and ``rng_key`` (default
        ``self.rng_key``), which are taken out first. Positional arguments go
        to the model. The wrapped samples are stored on ``self.predictive``.
        """
        sample_count = kwargs.pop("num_samples", self.num_samples)
        key = kwargs.pop("rng_key", self.rng_key)

        sampler = Predictive(
            self.model,
            guide=self.guide,
            params=self.params,
            num_samples=sample_count,
            **kwargs,
        )

        draws = sampler(key, *args)
        self.predictive = Posterior(draws, self.to_numpy)
示例#27
0
    def predict(self, X: DeviceArray, **kwargs) -> DeviceArray:
        """Sample the model's sites through the guide for input ``X``.

        Args:
            X: input data
            kwargs: keyword arguments for numpyro `Predictive`
                (for example `return_sites`)

        Returns:
            samples returned by `Predictive`
        """
        # dummy initialization
        self.init_svi(X, lr=0.)
        sampler = Predictive(self.model,
                             guide=self.guide,
                             params=self.model_params,
                             **kwargs)
        return sampler(self.rng_key, X)
示例#28
0
def test_pickle_autoguide(guide_class):
    """Check that a guide still works with ``Predictive`` after a pickle round trip."""
    x = np.random.poisson(1.0, size=(100,))

    guide = guide_class(poisson_regression)
    svi = SVI(
        poisson_regression,
        guide,
        numpyro.optim.Adam(1e-2),
        numpyro.infer.Trace_ELBO(),
    )
    svi_result = svi.run(random.PRNGKey(1), 3, x, len(x))

    # Round-trip the guide through pickle and predict with the restored copy.
    restored_guide = pickle.loads(pickle.dumps(guide))

    sampler = Predictive(
        poisson_regression,
        guide=restored_guide,
        params=svi_result.params,
        num_samples=1,
        return_sites=["param", "x"],
    )
    samples = sampler(random.PRNGKey(1), None, 1)

    expected_sites = {"param", "x"}
    assert set(samples.keys()) == expected_sites
示例#29
0
def sample_posterior_with_predictive(rng_key: random.PRNGKey,
                                     model,
                                     data: np.ndarray,
                                     Nsamples: int = 1000,
                                     alpha: float = 1,
                                     sigma: float = 0,
                                     T: int = 10):
    """Run NUTS on ``model`` and return site "z" from the posterior predictive.

    Args:
        rng_key: PRNG key; split internally so that inference and prediction
            each get their own key.
        model: model callable accepting ``data``, ``alpha``, ``sigma``, ``T``.
        data: passed to the model as ``data``.
        Nsamples: number of MCMC samples to draw.
        alpha: passed to the model as ``alpha``.
        sigma: passed to the model as ``sigma``.
        T: passed to the model as ``T``.

    Returns:
        The "z" entry of the ``Predictive`` output.
    """
    # A PRNG key should be consumed only once; using the same key for MCMC and
    # for Predictive would tie the two random streams together.
    mcmc_key, predict_key = random.split(rng_key)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=Nsamples, num_warmup=NUM_WARMUP)

    mcmc.run(mcmc_key, data=data, alpha=alpha, sigma=sigma, T=T)
    samples = mcmc.get_samples()

    predictive = Predictive(model,
                            posterior_samples=samples,
                            return_sites=["z"])
    return predictive(predict_key, data=data, alpha=alpha, sigma=sigma, T=T)["z"]
示例#30
0
def test_scan():
    """Check MCMC and ``Predictive`` on a model whose time loop uses ``scan``.

    The model is fitted for ``T`` steps and then run through ``Predictive`` for
    ``T + future`` steps with "x" conditioned on one posterior draw.
    """
    def model(T=10, q=1, r=1, phi=0.0, beta=0.0):
        def transition(state, i):
            # One step: sample "x" and "y", record "y2", carry (x1, mu1) forward.
            x0, mu0 = state
            x1 = numpyro.sample("x", dist.Normal(phi * x0, q))
            mu1 = beta * mu0 + x1
            y1 = numpyro.sample("y", dist.Normal(mu1, r))
            numpyro.deterministic("y2", y1 * 2)
            return (x1, mu1), (x1, y1)

        # Initial state, sampled outside the scan under separate site names.
        mu0 = x0 = numpyro.sample("x_0", dist.Normal(0, q))
        y0 = numpyro.sample("y_0", dist.Normal(mu0, r))

        _, xy = scan(transition, (x0, mu0), jnp.arange(T))
        x, y = xy

        return jnp.append(x0, x), jnp.append(y0, y)

    T = 10
    num_samples = 100
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), T=T)
    assert set(mcmc.get_samples()) == {"x", "y", "y2", "x_0", "y_0"}
    mcmc.print_summary()

    samples = mcmc.get_samples()
    x = samples.pop("x")[0]  # take 1 sample of x
    # this tests for the composition of condition and substitute
    # this also tests if we can use `vmap` for predictive.
    future = 5
    predictive = Predictive(
        numpyro.handlers.condition(model, {"x": x}),
        samples,
        return_sites=["x", "y", "y2"],
        parallel=True,
    )
    result = predictive(random.PRNGKey(1), T=T + future)
    expected_shape = (num_samples, T + future)
    assert result["x"].shape == expected_shape
    assert result["y"].shape == expected_shape
    assert result["y2"].shape == expected_shape
    # The first T steps must reproduce the conditioned "x" and the posterior "y".
    assert_allclose(result["x"][:, :T], jnp.broadcast_to(x, (num_samples, T)))
    assert_allclose(result["y"][:, :T], samples["y"])