def sample_from_ppd(rng_key):
    """Draw one parameter vector from the guide, then data records given it.

    ``rng_key`` is split in two. The first half samples a single parameter
    vector through ``Predictive(guide, params=posterior_params,
    num_samples=1)``. The second half is split into
    ``num_record_samples_per_parameter_sample`` keys, each of which drives one
    draw from ``model`` under that parameter vector.

    Returns a dict mapping site names to the stacked samples.
    """
    param_key, record_key = jax.random.split(rng_key)

    # One draw from the guide; the leading sample axis (size 1) is superfluous
    # and is squeezed away for every site.
    guide_draw = Predictive(guide, params=posterior_params, num_samples=1)(param_key)
    parameters = {}
    for site, value in guide_draw.items():
        parameters[site] = value.squeeze(0)

    # One model draw per key, vectorised over the keys with vmap.
    draw_records = Predictive(model, parameters, batch_ndims=0)
    record_keys = jax.random.split(
        record_key, num_record_samples_per_parameter_sample
    )
    stacked = jax.vmap(draw_records)(record_keys)

    # The model adds a superfluous axis at position 1; squeeze it as well.
    return {site: value.squeeze(1) for site, value in stacked.items()}
def _predict(self, home_team, away_team, dates, num_samples=100, seed=42):
    """Sample home and away goals for the given fixtures.

    Args:
        home_team: A team name, or an iterable of team names.
        away_team: A team name, or an iterable of team names.
        dates: Match dates; must support ``(dates - self.min_date).dt.days``.
        num_samples: Number of samples handed to ``Predictive``.
        seed: Integer seed used to build the PRNG key.

    Returns:
        Tuple ``(home_goals, away_goals)`` read from the predictive samples.

    Side effects:
        Every team that is not yet in ``self.team_to_index`` is registered in
        ``self.team_to_index`` / ``self.index_to_team`` and ``self.n_teams``
        is incremented once per new team.
    """
    predictive = Predictive(
        self.model,
        num_samples=num_samples,
        posterior_samples=self.samples,
        return_sites=("home_goals", "away_goals"),
    )

    home_team = [home_team] if isinstance(home_team, str) else home_team
    away_team = [away_team] if isinstance(away_team, str) else away_team

    # dict.fromkeys de-duplicates while keeping first-appearance order, so new
    # teams receive the same indices on every run. Iterating a set of strings
    # (as before) gives an order that can change between interpreter runs.
    # Unpacking also accepts any iterable, not only lists.
    known = self.team_to_index
    unseen = [
        team
        for team in dict.fromkeys([*home_team, *away_team])
        if team not in known
    ]

    # Computed once instead of once per new team. default=-1 keeps the call
    # safe on an empty mapping (assumes indices are 0-based — TODO confirm).
    next_index = max(known.values(), default=-1) + 1
    for team in unseen:
        self.team_to_index[team] = next_index
        self.index_to_team[next_index] = team
        self.n_teams += 1
        next_index += 1

    # Whole weeks elapsed since self.min_date.
    gameweek = (dates - self.min_date).dt.days // 7

    # NOTE(review): `get_samples` is kept as in the original; newer NumPyro
    # releases expose this as `predictive(...)` — confirm the pinned version.
    predictions = predictive.get_samples(
        random.PRNGKey(seed), home_team, away_team, gameweek
    )
    return predictions["home_goals"], predictions["away_goals"]
def main() -> None:
    """Run NUTS on ``model_noncentered`` and print predictive summaries.

    Prints the MCMC summary, the mean of the negated potential energy, and
    the mean of the ``obs`` site under the prior and under the posterior of
    ``model_pred``.
    """
    # Data
    num = 8
    y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

    # Random key
    rng_key = random.PRNGKey(0)

    # Inference
    nuts_kernel = NUTS(model_noncentered)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
    mcmc.run(rng_key, num, sigma, y=y, extra_fields=("potential_energy",))
    # print_summary() writes the table itself and returns None, so wrapping it
    # in print() only added a stray "None" line to the output.
    mcmc.print_summary()

    # Extra
    pe = mcmc.get_extra_fields()["potential_energy"]
    print(f"Expected log joint density: {np.mean(-pe):.2f}")

    # Prediction: with only num_samples given, sites are drawn from the prior.
    predictive = Predictive(model_pred, num_samples=100)
    samples = predictive(random.PRNGKey(1))
    print("prior", np.mean(samples["obs"]))

    # Prediction conditioned on the MCMC draws.
    predictive = Predictive(model_pred, mcmc.get_samples())
    samples = predictive(random.PRNGKey(1))
    print("posterior", np.mean(samples["obs"]))
def predictive(self, n, nu=None, nu_err=None, n_pred=None, **kwargs) -> dict:
    """Draw predictive samples from ``self.model``.

    Args:
        n: Forwarded to the model as its positional argument.
        nu: Forwarded to the model as the ``nu`` keyword.
        nu_err: Forwarded to the model as the ``nu_err`` keyword.
        n_pred: Forwarded to the model as the ``n_pred`` keyword.
        **kwargs: Options for ``Predictive``. ``posterior_samples``,
            ``num_samples``, ``batch_ndims`` (default 2) and ``return_sites``
            are extracted here; whatever is left is passed through unchanged.

    Returns:
        dict: The samples returned by the ``Predictive`` call.

    Side effects:
        ``self._rng_key`` is replaced by a fresh key split from the old one.
    """
    posterior_samples = kwargs.pop("posterior_samples", None)
    num_samples = kwargs.pop("num_samples", None)
    batch_ndims = kwargs.pop("batch_ndims", 2)
    return_sites = kwargs.pop("return_sites", None)

    if return_sites is None:
        # Default selection: every deterministic site, plus every sample site
        # that is neither observed nor already supplied via posterior_samples.
        supplied = posterior_samples if posterior_samples is not None else {}
        return_sites = [
            name
            for name, site in self.get_trace(pred=True).items()
            if site["type"] == "deterministic"
            or (
                site["type"] == "sample"
                and not site["is_observed"]
                and name not in supplied
            )
        ]

    sampler = Predictive(
        self.model,
        posterior_samples=posterior_samples,
        num_samples=num_samples,
        return_sites=return_sites,
        batch_ndims=batch_ndims,
        **kwargs,
    )
    if sampler.batch_ndims == 0:
        # Work around Predictive's batch-shape computation for batch_ndims=0.
        sampler._batch_shape = ()

    # Use the first half of the split now and keep the second for later calls.
    rng_key, self._rng_key = random.split(self._rng_key)
    return sampler(rng_key, n, nu=nu, nu_err=nu_err, n_pred=n_pred)
def fit(self, *args, **kwargs):
    """Run SVI training, then draw posterior samples with ``Predictive``.

    Args:
        *args: Passed to ``self.svi.init``, to every ``self._fit`` call and
            to the final ``Predictive`` call.
        **kwargs: ``num_epochs`` and ``log_freq`` override the instance
            attributes of the same name; the remaining entries are forwarded
            to the ``Predictive`` constructor.

    Side effects:
        Sets ``self.init_state`` (when it was ``None``), ``self.params`` and
        ``self.posterior``; calls ``self._update_state`` after each training
        chunk and ``self._log`` after each full ``log_freq`` chunk.
    """
    num_epochs = kwargs.pop("num_epochs", self.num_epochs)
    log_freq = kwargs.pop("log_freq", self.log_freq)

    if self.init_state is None:
        self.init_state = self.svi.init(self.rng_key, *args)

    # Fallback so `state` is always bound: with num_epochs == 0 and
    # log_freq > 0 no training chunk runs, and get_params(state) below used to
    # fail with UnboundLocalError.
    state = self.init_state

    if log_freq <= 0:
        # No logging: train in one go.
        state, loss = self._fit(num_epochs, *args)
        self._update_state(state, loss)
    else:
        # Train in chunks of log_freq epochs, logging the last loss of each,
        # then run the remainder (if any) without logging.
        steps, rest = divmod(num_epochs, log_freq)
        for step in range(steps):
            state, loss = self._fit(log_freq, *args)
            self._log(log_freq * (step + 1), loss[-1])
            self._update_state(state, loss)
        if rest > 0:
            state, loss = self._fit(rest, *args)
            self._update_state(state, loss)

    self.params = self.svi.get_params(state)
    predictive = Predictive(
        self.model,
        guide=self.guide,
        params=self.params,
        num_samples=self.num_samples,
        **kwargs,
    )
    self.posterior = predictive(self.rng_key, *args)
def main(args):
    """Fit the selected annotation model with NUTS and print class histograms.

    ``args`` supplies ``model`` (a key of ``NAME_TO_MODEL``), ``num_warmup``,
    ``num_samples`` and ``num_chains``.
    """
    annotators, annotations = get_data()
    model = NAME_TO_MODEL[args.model]

    # Two of the models take only the annotations; the rest also take the
    # annotator ids.
    if model in [multinomial, item_difficulty]:
        data = (annotations,)
    else:
        data = (annotators, annotations)

    # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set.
    show_progress = False if "NUMPYRO_SPHINXBUILD" in os.environ else True
    mcmc = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    mcmc.run(random.PRNGKey(0), *data)
    mcmc.print_summary()

    # Draw the discrete site "c" given the posterior samples.
    draw_discrete = Predictive(model, mcmc.get_samples(), infer_discrete=True)
    discrete_samples = draw_discrete(random.PRNGKey(1), *data)

    # For each slice along axis 1, count occurrences of the 4 class values.
    count_classes = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)
    item_class = count_classes(discrete_samples["c"].squeeze(-1))

    print("Histogram of the predicted class of each item:")
    row_format = "{:>10}" * 5
    column_titles = ["c={}".format(i) for i in range(4)]
    print(row_format.format("", *column_titles))
    for i, row in enumerate(item_class):
        print(row_format.format(f"item[{i}]", *row))
def main(args):
    """Sample ``model`` and ``reparam_model`` and plot both sets of draws.

    Runs inference on each model with the same key, recomputes the ``x`` and
    ``y`` sites of the reparameterized model with ``Predictive``, and saves a
    two-panel scatter plot to ``funnel_plot.pdf``.
    """
    rng_key = random.PRNGKey(0)

    # do inference with centered parameterization
    print("============================= Centered Parameterization ==============================")
    samples = run_inference(model, args, rng_key)

    # do inference with non-centered parameterization
    print("\n=========================== Non-centered Parameterization ============================")
    reparam_samples = run_inference(reparam_model, args, rng_key)
    # collect deterministic sites
    reparam_samples = Predictive(reparam_model, reparam_samples, return_sites=['x', 'y'])(
        random.PRNGKey(1))

    # make plots
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 8))

    ax1.plot(samples['x'][:, 0], samples['y'], "go", alpha=0.3)
    ax1.set(xlim=(-20, 20), ylim=(-9, 9), ylabel='y',
            title='Funnel samples with centered parameterization')

    ax2.plot(reparam_samples['x'][:, 0], reparam_samples['y'], "go", alpha=0.3)
    ax2.set(xlim=(-20, 20), ylim=(-9, 9), xlabel='x[0]', ylabel='y',
            title='Funnel samples with non-centered parameterization')

    # tight_layout must run before savefig: the original called it afterwards,
    # so the layout adjustment never reached the saved PDF.
    plt.tight_layout()
    plt.savefig('funnel_plot.pdf')
def conditional_from_guide(self, guide, params, *args, **kwargs):
    """Build the conditional mean and variance from guide samples.

    Args:
        guide: Guide passed to ``Predictive``.
        params: Parameters passed to ``Predictive``.
        *args: Forwarded to ``self._get_var_names`` and to the predictive call.
        **kwargs: ``pred_noise`` and ``diag`` (both default ``False``) are
            removed and passed to ``self._build_conditional``; the rest goes
            to ``self._get_var_names``.

    Returns:
        Tuple ``(mu, var)`` as produced by ``self._build_conditional``.

    Side effects:
        Stores the predictive samples on ``self.cond_params``.
    """
    pred_noise = kwargs.pop("pred_noise", False)
    diag = kwargs.pop("diag", False)
    self._get_var_names(*args, **kwargs)

    # Site names whose samples the conditional is built from.
    wanted_sites = (
        self.gp,
        self.mean,
        self.cond,
        self.Kss,
        self.Kns,
        self.Ksx,
        self.Kxx,
        self.Knx,
        self.y,
    )
    sampler = Predictive(
        self.model,
        guide=guide,
        params=params,
        num_samples=self.num_samples,
        return_sites=wanted_sites,
    )
    # self.rng_key is handed to PRNGKey here, i.e. it is used as a seed.
    self.cond_params = sampler(PRNGKey(self.rng_key), *args)
    return self._build_conditional(self.cond_params, pred_noise, diag)
def main(args):
    """Fit the model to the LYNXHARE data and plot the posterior predictive.

    ``args`` supplies ``num_warmup``, ``num_samples`` and ``num_chains``.
    The figure is written to ``ode_plot.pdf``.
    """
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(NUTS(model, dense_mass=True),
                args.num_warmup, args.num_samples, num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    # The model is fitted to the log of the data.
    mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data))
    mcmc.print_summary()

    # predict populations (exponentiate to undo the log taken above)
    y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    pop_pred = jnp.exp(y_pred)
    mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)

    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    # tight_layout must run before savefig: the original called it afterwards,
    # so the layout adjustment never reached the saved PDF.
    plt.tight_layout()
    plt.savefig("ode_plot.pdf")
def main(args):
    """Fit ``ss_model`` per amino acid and plot observed vs. predicted angles.

    For every entry of ``args.amino_acids`` the dihedral data is fetched, HMC
    is initialised from k-means cluster centres, and predictive ``phi_psi``
    samples are drawn. Everything is then passed to ``ramachandran_plot``.
    """
    rng_key = random.PRNGKey(args.rng_seed)
    data, pred_datas = {}, {}

    for aa in args.amino_acids:
        # Fresh inference and prediction keys for each amino acid.
        rng_key, inf_key, pred_key = random.split(rng_key, 3)

        data[aa] = fetch_aa_dihedrals(aa)
        num_mix_comp = num_mix_comps(aa)

        # Use kmeans to initialize the chain location: column 0 of the cluster
        # centres seeds phi_loc, column 1 seeds psi_loc.
        kmeans = KMeans(num_mix_comp)
        kmeans.fit(data[aa])
        centres = kmeans.cluster_centers_
        means = {"phi_loc": centres[:, 0], "psi_loc": centres[:, 1]}

        ss_samples = run_hmc(inf_key, ss_model, data[aa], num_mix_comp, args, means)
        sampler = Predictive(ss_model, ss_samples, parallel=True)

        # Flatten every predictive draw into rows of two columns.
        drawn = sampler(pred_key, None, 1, num_mix_comp)["phi_psi"]
        pred_datas[aa] = drawn.reshape(-1, 2)

    ramachandran_plot(data, pred_datas, args.amino_acids)
def main(args):
    """Fit ``glmm`` to the UCBADMIT training split and plot predicted rates.

    Prints a results table via ``print_results`` and writes the posterior
    predictive check figure to ``ucbadmit_plot.pdf``.
    """
    _, fetch_train = load_dataset(UCBADMIT, split="train", shuffle=False)
    dept, male, applications, admit = fetch_train()

    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)

    sampler = Predictive(glmm, zs)
    pred_probs = sampler(rng_key_predict, dept, male, applications)["probs"]

    observed_rate = admit / applications
    header = "=" * 30 + "glmm - TRAIN" + "=" * 30
    print_results(header, pred_probs, dept, male, observed_rate)

    # make plots; the x axis is the fixed range 1..12 used for every series
    cases = range(1, 13)
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    ax.plot(cases, observed_rate, "o", ms=7, label="actual rate")
    ax.errorbar(
        cases,
        jnp.mean(pred_probs, 0),
        jnp.std(pred_probs, 0),
        fmt="o",
        c="k",
        mfc="none",
        ms=7,
        elinewidth=1,
        label=r"mean $\pm$ std",
    )
    # 5th and 95th percentiles bound the 90% interval named in the title.
    for level in (5, 95):
        ax.plot(cases, jnp.percentile(pred_probs, level, 0), "k+")
    ax.set(
        xlabel="cases",
        ylabel="admit rate",
        title="Posterior Predictive Check with 90% CI",
    )
    ax.legend()
    plt.savefig("ucbadmit_plot.pdf")
def svi_predict(model, guide, params, args, X):
    """Return rounded mean predictions of site ``Y`` for inputs ``X``.

    Draws ``args.num_samples`` predictive samples from ``model`` using the
    fitted ``guide`` and ``params`` (with ``Y=None``), averages the ``Y`` site
    over the sample axis and rounds to the nearest integer.
    """
    sampler = Predictive(
        model=model,
        guide=guide,
        params=params,
        num_samples=args.num_samples,
    )
    drawn_y = sampler(PRNGKey(1), X=X, Y=None)["Y"]
    return jnp.rint(drawn_y.mean(0))
def prior(self, num_samples=1000, rng_key=PRNGKey(2), **args):
    """Draw samples with no posterior samples supplied and cache them.

    Args:
        num_samples: Number of draws requested from ``Predictive``.
        rng_key: PRNG key for the draw.
        **args: Model arguments; these override entries of ``self.args``.

    Returns:
        The samples, which are also stored on ``self.prior_samples``.
    """
    # Explicitly passed arguments take precedence over the stored ones.
    model_args = dict(self.args, **args)
    sampler = Predictive(self, posterior_samples={}, num_samples=num_samples)
    self.prior_samples = sampler(rng_key, **model_args)
    return self.prior_samples
def forecast(self, num_samples=1000, rng_key=PRNGKey(4), **args):
    """Draw predictive samples given the stored MCMC samples and observations.

    Args:
        num_samples: Accepted but not used by this method.
        rng_key: PRNG key for the draw.
        **args: Model arguments; these override entries of ``self.args``.

    Returns:
        The samples produced by ``Predictive``.

    Raises:
        RuntimeError: If ``self.mcmc_samples`` is ``None``.
    """
    mcmc_samples = self.mcmc_samples
    if mcmc_samples is None:
        raise RuntimeError("run inference first")

    sampler = Predictive(self, posterior_samples=mcmc_samples)
    model_args = dict(self.args, **args)
    # self.obs is unpacked next to the model arguments in the same call.
    return sampler(rng_key, **self.obs, **model_args)
def predict(
    model: Callable,
    at_bats: jnp.ndarray,
    posterior_samples: jnp.ndarray,
    rng_key: jnp.ndarray,
) -> Dict[str, jnp.ndarray]:
    """Draw predictive samples from ``model`` given ``posterior_samples``.

    Args:
        model: Model callable handed to ``Predictive``.
        at_bats: Passed to the model as its single positional argument.
        posterior_samples: Samples handed to ``Predictive``.
        rng_key: PRNG key for the draw.

    Returns:
        The dict of samples returned by the ``Predictive`` call.
    """
    sampler = Predictive(model, posterior_samples=posterior_samples)
    drawn = sampler(rng_key, at_bats)
    return drawn
def predictive(self, rng_key=PRNGKey(3), **args):
    """Draw samples from the in-sample predictive distribution.

    Args:
        rng_key: PRNG key for the draw.
        **args: Model arguments; these override entries of ``self.args``.

    Returns:
        The samples produced by ``Predictive``.

    Raises:
        RuntimeError: If ``self.mcmc_samples`` is ``None``.
    """
    mcmc_samples = self.mcmc_samples
    if mcmc_samples is None:
        raise RuntimeError("run inference first")

    sampler = Predictive(self, posterior_samples=mcmc_samples)
    model_args = dict(self.args, **args)
    return sampler(rng_key, **model_args)
def _init_tau(self, rng_key, tau_prior, num_samples=5000):
    """Build two distributions from sampled ``log_tau`` moments.

    Draws ``num_samples`` samples of the ``log_tau`` site from ``tau_prior``,
    shifts them by -6, and summarises each of the first two columns by its
    mean and sample standard deviation (``ddof=1``).

    Returns:
        Tuple of two ``distribution((loc, scale))`` objects: the first for
        ``tau_he``, the second for ``tau_cz``.
    """
    drawn = Predictive(tau_prior, num_samples=num_samples)(rng_key)

    # Convert from seconds to mega seconds (a shift of 6 in log space).
    shifted = drawn["log_tau"] - 6
    loc = shifted.mean(axis=0)
    scale = shifted.std(axis=0, ddof=1)

    tau_he = distribution((loc[0], scale[0]))
    tau_cz = distribution((loc[1], scale[1]))
    return (tau_he, tau_cz)
def transform(self, views: Iterable[np.ndarray], y=None, **kwargs):
    """
    Predict the latent variables that generate the data in views, using the
    sampled model parameters.

    :param views: list/tuple of numpy arrays or array likes with the same number of rows (samples)
    :param y: accepted but not used by this method
    :param kwargs: accepted but not used by this method
    :return: the samples of site ``"z"``
    """
    check_is_fitted(self, attributes=["posterior_samples"])
    sampler = Predictive(
        self._model,
        self.posterior_samples,
        return_sites=["z"],
    )
    latent = sampler(self.rng_key, views)
    return latent["z"]
def load_numpyro_divorce():
    """Load saved samples from JSON and attach a ``Predictive`` to the context.

    Reads ``numpyro-divorce.json`` from ``numpyro_divorce.details.local_folder``,
    converts every value to a NumPy array and stores
    ``Predictive(model_function, samples)`` on
    ``numpyro_divorce.context.predictive_dist``.
    """
    folder = numpyro_divorce.details.local_folder
    model_uri = os.path.join(folder, "numpyro-divorce.json")

    with open(model_uri) as model_file:
        raw_samples = json.load(model_file)

    # JSON yields plain lists; convert each entry to an array.
    samples = {name: np.array(values) for name, values in raw_samples.items()}

    numpyro_divorce.context.predictive_dist = Predictive(model_function, samples)
def get_inference_data(self, data, eight_schools_params):
    """Assemble inference data from ``data.obj`` via ``from_numpyro``.

    Posterior predictive samples (from the stored posterior samples) and 500
    prior samples are drawn from ``data.obj.sampler.model`` with the ``J`` and
    ``sigma`` entries of ``eight_schools_params`` as model arguments.
    """
    n_schools = eight_schools_params["J"]
    model_args = (n_schools, eight_schools_params["sigma"])

    posterior_samples = data.obj.get_samples()
    model = data.obj.sampler.model

    posterior_predictive = Predictive(model, posterior_samples)(
        PRNGKey(1), *model_args
    )
    prior = Predictive(model, num_samples=500)(PRNGKey(2), *model_args)

    return from_numpyro(
        posterior=data.obj,
        prior=prior,
        posterior_predictive=posterior_predictive,
        coords={"school": np.arange(n_schools)},
        dims={"theta": ["school"], "eta": ["school"]},
    )
def get_posterior_predictive(self, *args, **kwargs):
    """Draw predictive samples and store them on ``self.posterior_predictive``.

    Args:
        *args: Forwarded to the predictive call (after ``self.rng_key``).
        **kwargs: Forwarded to the ``Predictive`` constructor, except
            ``num_samples``, which defaults to ``self.num_samples``.
    """
    n_draws = kwargs.pop("num_samples", self.num_samples)
    sampler = Predictive(
        self.model,
        guide=self.guide,
        params=self.params,
        num_samples=n_draws,
        **kwargs,
    )
    self.posterior_predictive = sampler(self.rng_key, *args)
def main(args):
    """Train the model with SteinVI and report training time and test RMSE.

    ``args`` supplies ``rng_key`` (seed), ``repulsion``, ``num_particles``,
    ``max_iter``, ``hidden_dim``, ``subsample_size`` and ``progress_bar``.
    """
    data = load_data()

    # The third key of the split is not used further in this function.
    inf_key, pred_key, data_key = random.split(random.PRNGKey(args.rng_key), 3)

    # normalize data and labels to zero mean unit variance!
    x, xtr_mean, xtr_std = normalize(data.xtr)
    y, ytr_mean, ytr_std = normalize(data.ytr)

    rng_key, inf_key = random.split(inf_key)

    guide = AutoDelta(model, init_loc_fn=partial(init_to_uniform, radius=0.1))
    optimizer = Adagrad(0.05)
    # estimate elbo with 20 particles (not stein particles!)
    loss = Trace_ELBO(20)
    kernel = RBFKernel()
    stein = SteinVI(
        model,
        guide,
        optimizer,
        loss,
        kernel,
        repulsion_temperature=args.repulsion,
        num_particles=args.num_particles,
    )

    start = time()
    # use keyword params for static (shape etc.)!
    result = stein.run(
        rng_key,
        args.max_iter,
        x,
        y,
        hidden_dim=args.hidden_dim,
        subsample_size=args.subsample_size,
        progress_bar=args.progress_bar,
    )
    time_taken = time() - start

    pred = Predictive(
        model,
        guide=stein.guide,
        params=stein.get_params(result.state),
        num_samples=1,
        batch_ndims=1,  # stein particle dimension
    )

    # use train data statistics when accessing generalization
    xte, _, _ = normalize(data.xte, xtr_mean, xtr_std)
    n_test = xte.shape[0]
    preds = pred(pred_key, xte, subsample_size=n_test)["y"].reshape(-1, n_test)

    # Average over draws, then map back to the original label scale.
    y_pred = jnp.mean(preds, 0) * ytr_std + ytr_mean
    rmse = jnp.sqrt(jnp.mean((y_pred - data.yte) ** 2))

    print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}")
    print(rf"RMSE: {rmse:.2f}")
async def load(self) -> bool:
    """Load saved samples from the configured URI and mark the model ready.

    Reads the JSON file at ``self._settings.parameters.uri``, converts each
    value to a NumPy array, and builds ``self._predictive`` from
    ``self._model`` and those samples.

    Returns:
        ``self.ready``, which is set to ``True`` before returning.
    """
    model_uri = self._settings.parameters.uri
    with open(model_uri) as model_file:
        raw_samples = json.load(model_file)

    # Start from an empty dict, then fill it key by key with arrays.
    self._samples = {}
    self._samples.update(
        (name, np.array(values)) for name, values in raw_samples.items()
    )

    self._predictive = Predictive(self._model, self._samples)
    self.ready = True
    return self.ready
def main() -> None:
    """Fit the regression with NUTS and print several predictive summaries.

    Prints the MCMC summary, the posterior mean and 90% HPDI of ``mu``, the
    head of the data frame with a ``MeanPredictions`` column added, the mean
    prediction from the effect-handler ``predict`` helper, and the log
    posterior predictive density.
    """
    df = load_dataset()
    marriage = df["MarriageScaled"].values
    divorce = df["DivorceScaled"].values

    rng_key = random.PRNGKey(0)
    rng_key, rng_key_ = random.split(rng_key)

    # Inference posterior
    mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2000)
    mcmc.run(rng_key_, marriage=marriage, divorce=divorce)
    mcmc.print_summary()
    samples_1 = mcmc.get_samples()

    # Compute empirical posterior distribution: one row of mu per draw.
    intercept = jnp.expand_dims(samples_1["a"], -1)
    slope = jnp.expand_dims(samples_1["bM"], -1)
    posterior_mu = intercept + slope * marriage
    mean_mu = jnp.mean(posterior_mu, axis=0)
    hpdi_mu = hpdi(posterior_mu, 0.9)
    print(mean_mu, hpdi_mu)

    # Posterior predictive distribution
    rng_key, rng_key_ = random.split(rng_key)
    predictions = Predictive(model, samples_1)(rng_key_, marriage=marriage)["obs"]
    df["MeanPredictions"] = jnp.mean(predictions, axis=0)
    print(df.head())

    # Predictive utility with effect handlers, vectorised over draws.
    # The 2000 keys match num_samples above; rng_key_ is the same key that
    # was used for the Predictive call.
    def predict_one(rng_key, samples):
        return predict(rng_key, samples, model, marriage=marriage)

    predictions_1 = vmap(predict_one)(random.split(rng_key_, 2000), samples_1)
    mean_pred = jnp.mean(predictions_1, axis=0)
    print(mean_pred)

    # Posterior predictive density
    rng_key, rng_key_ = random.split(rng_key)
    lpp_dns = log_pred_density(
        rng_key_,
        samples_1,
        model,
        marriage=marriage,
        divorce=divorce,
    )
    print("Log posterior predictive density", lpp_dns)
def predict(model, at_bats, hits, z, rng_key, player_names, train=True):
    """Print predictive results for ``model`` and, on test data, its density.

    Args:
        model: Model callable; its ``__name__`` appears in the header.
        at_bats: Passed to the model and to ``print_results``.
        hits: Passed to ``print_results`` and ``log_likelihood``.
        z: Posterior samples.
        rng_key: PRNG key for the predictive draw.
        player_names: Passed to ``print_results``.
        train: Selects the ``TRAIN``/``TEST`` header suffix; the log pointwise
            predictive density is only printed when this is false.
    """
    suffix = " - TRAIN" if train else " - TEST"
    header = model.__name__ + suffix

    sampler = Predictive(model, posterior_samples=z)
    predictions = sampler(rng_key, at_bats)["obs"]
    print_results("=" * 30 + header + "=" * 30, predictions, player_names, at_bats, hits)

    if train:
        return

    post_loglik = log_likelihood(model, z, at_bats, hits)["obs"]
    # computes expected log predictive density at each data point:
    # log-mean-exp of the log likelihood over the sample axis
    n_draws = jnp.shape(post_loglik)[0]
    exp_log_density = logsumexp(post_loglik, axis=0) - jnp.log(n_draws)
    # reports log predictive density of all test points
    print("\nLog pointwise predictive density: {:.2f}\n".format(exp_log_density.sum()))
def predict(self, *args, **kwargs):
    """Draw predictive samples and store them wrapped in ``Posterior``.

    Args:
        *args: Forwarded to the predictive call (after the PRNG key).
        **kwargs: Forwarded to the ``Predictive`` constructor, except
            ``num_samples`` (default ``self.num_samples``) and ``rng_key``
            (default ``self.rng_key``), which are consumed here.

    Side effects:
        Sets ``self.predictive`` to ``Posterior(samples, self.to_numpy)``.
    """
    n_draws = kwargs.pop("num_samples", self.num_samples)
    key = kwargs.pop("rng_key", self.rng_key)

    sampler = Predictive(
        self.model,
        guide=self.guide,
        params=self.params,
        num_samples=n_draws,
        **kwargs,
    )
    drawn = sampler(key, *args)
    self.predictive = Posterior(drawn, self.to_numpy)
def predict(self, X: DeviceArray, **kwargs) -> DeviceArray:
    """Draw predictive samples for ``X`` from the model and guide.

    Args:
        X: input data
        kwargs: keyword arguments for NumPyro's `Predictive`
            (for example `return_sites`)

    Returns:
        the samples returned by the `Predictive` call
    """
    # dummy initialization with a zero learning rate
    self.init_svi(X, lr=0.)

    sampler = Predictive(
        self.model,
        guide=self.guide,
        params=self.model_params,
        **kwargs,
    )
    return sampler(self.rng_key, X)
def test_pickle_autoguide(guide_class):
    """A guide that went through a pickle round trip still works in Predictive."""
    x = np.random.poisson(1.0, size=(100,))

    guide = guide_class(poisson_regression)
    svi = SVI(
        poisson_regression,
        guide,
        numpyro.optim.Adam(1e-2),
        numpyro.infer.Trace_ELBO(),
    )
    svi_result = svi.run(random.PRNGKey(1), 3, x, len(x))

    # Round-trip the guide through pickle after it has been used for fitting.
    restored_guide = pickle.loads(pickle.dumps(guide))

    sampler = Predictive(
        poisson_regression,
        guide=restored_guide,
        params=svi_result.params,
        num_samples=1,
        return_sites=["param", "x"],
    )
    samples = sampler(random.PRNGKey(1), None, 1)

    # Exactly the requested sites must come back.
    assert set(samples.keys()) == {"param", "x"}
def sample_posterior_with_predictive(rng_key: random.PRNGKey, model, data: np.ndarray,
                                     Nsamples: int = 1000, alpha: float = 1,
                                     sigma: float = 0, T: int = 10):
    """Run NUTS on ``model`` and return predictive samples of site ``z``.

    Args:
        rng_key: PRNG key, used for the MCMC run and for the predictive draw.
        model: Model callable.
        data: Passed to the model as ``data``.
        Nsamples: Number of MCMC samples (warmup uses ``NUM_WARMUP``).
        alpha: Passed to the model as ``alpha``.
        sigma: Passed to the model as ``sigma``.
        T: Passed to the model as ``T``.

    Returns:
        The ``"z"`` entry of the predictive samples.
    """
    model_kwargs = dict(data=data, alpha=alpha, sigma=sigma, T=T)

    mcmc = MCMC(NUTS(model), num_samples=Nsamples, num_warmup=NUM_WARMUP)
    mcmc.run(rng_key, **model_kwargs)

    # NOTE(review): the same key drives both the sampler and the predictive
    # draw; JAX recommends splitting keys rather than reusing them — confirm
    # whether this is intentional.
    sampler = Predictive(
        model,
        posterior_samples=mcmc.get_samples(),
        return_sites=["z"],
    )
    return sampler(rng_key, **model_kwargs)["z"]
def test_scan():
    """Predictive over a ``scan`` model with ``condition`` and ``parallel=True``.

    Fits a model built on ``scan`` with NUTS, then predicts ``future`` extra
    steps while conditioning ``x`` on a single posterior draw, and checks the
    result shapes and that the first ``T`` steps reproduce the given values.
    """

    def model(T=10, q=1, r=1, phi=0.0, beta=0.0):
        def transition(state, i):
            x_prev, mu_prev = state
            x_next = numpyro.sample("x", dist.Normal(phi * x_prev, q))
            mu_next = beta * mu_prev + x_next
            y_next = numpyro.sample("y", dist.Normal(mu_next, r))
            numpyro.deterministic("y2", y_next * 2)
            return (x_next, mu_next), (x_next, y_next)

        # The initial step is sampled outside the scan.
        mu0 = x0 = numpyro.sample("x_0", dist.Normal(0, q))
        y0 = numpyro.sample("y_0", dist.Normal(mu0, r))

        _, (x, y) = scan(transition, (x0, mu0), jnp.arange(T))
        return jnp.append(x0, x), jnp.append(y0, y)

    T = 10
    num_samples = 100

    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), T=T)
    assert set(mcmc.get_samples()) == {"x", "y", "y2", "x_0", "y_0"}
    mcmc.print_summary()

    samples = mcmc.get_samples()
    x = samples.pop("x")[0]  # take 1 sample of x

    # this tests for the composition of condition and substitute
    # this also tests if we can use `vmap` for predictive.
    future = 5
    conditioned_model = numpyro.handlers.condition(model, {"x": x})
    predictive = Predictive(
        conditioned_model,
        samples,
        return_sites=["x", "y", "y2"],
        parallel=True,
    )
    result = predictive(random.PRNGKey(1), T=T + future)

    expected_shape = (num_samples, T + future)
    assert result["x"].shape == expected_shape
    assert result["y"].shape == expected_shape
    assert result["y2"].shape == expected_shape

    # The first T steps must equal the conditioned x and the posterior y.
    assert_allclose(result["x"][:, :T], jnp.broadcast_to(x, (num_samples, T)))
    assert_allclose(result["y"][:, :T], samples["y"])