def print_results(coef: jnp.ndarray, interval_size: float = 0.95) -> None:
    """Print HPD intervals for the effect of calling and the effect of gender.

    Args:
        coef: Posterior coefficient samples indexed as ``coef[:, k]``. Column 0 is
            used as the intercept, column 1 as the "calls" coefficient and column 2
            as the gender coefficient (log-odds scale). Presumably shaped
            ``(num_samples, >=3)`` — confirm against the caller.
        interval_size: Probability mass contained in each HPD interval.
    """
    baseline_response = expit(coef[:, 0])
    response_with_calls = expit(coef[:, 0] + coef[:, 1])
    impact_on_probability = hpdi(response_with_calls - baseline_response, prob=interval_size)
    effect_of_gender = hpdi(coef[:, 2], prob=interval_size)

    # Use fixed-point formatting for both bounds (``.2`` would mean 2 significant digits).
    print(
        f"There is a {interval_size * 100}% probability that calling customers "
        "increases the chance they'll make a purchase by "
        f"{(100 * impact_on_probability[0]):.2f} to {(100 * impact_on_probability[1]):.2f} percentage points."
    )

    gender_low, gender_high = effect_of_gender[0], effect_of_gender[1]
    # Only claim "no effect" when the interval actually contains 0.
    if gender_low <= 0 <= gender_high:
        conclusion = " Since this interval contains 0, we can conclude gender does not impact the conversion rate."
    else:
        conclusion = " Since this interval excludes 0, gender appears to impact the conversion rate."
    print(
        f"There is a {interval_size * 100}% probability the effect of gender on the log odds of conversion "
        f"lies in the interval ({gender_low:.2f}, {gender_high:.2f})."
        + conclusion
    )
def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:
    """Persist samples and plot the fit/forecast of the seasonal model.

    Writes ``piror_samples.npz``, ``posterior_samples.npz``,
    ``posterior_predictive.npz`` and ``prediction.png`` into ``./data/seasonal``.

    Args:
        x: Observed series; flattened with ``ravel()`` for the ground-truth line.
        prior_samples: Prior draws, saved as-is.
        posterior_samples: Posterior draws, saved as-is.
        posterior_predictive: Predictive draws; its ``"x"`` entry is split along
            axis 1 at ``num_train`` into a fitted part and a forecast part.
        num_train: Number of leading time steps in the training range.
    """
    root = pathlib.Path("./data/seasonal")
    # parents=True: do not fail when ./data itself does not exist yet.
    root.mkdir(parents=True, exist_ok=True)

    # NOTE: "piror" spelling kept so existing consumers of this file still find it.
    jnp.savez(root / "piror_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    x_pred = posterior_predictive["x"]

    # In-sample part of the predictive draws.
    x_pred_trn = x_pred[:, :num_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    t_train = np.arange(num_train)

    # Out-of-sample (forecast) part.
    x_pred_tst = x_pred[:, num_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    num_test = x_pred_tst.shape[1]
    t_test = np.arange(num_train, num_train + num_test)

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    plt.figure(figsize=(12, 6))
    plt.plot(x.ravel(), label="ground truth", color=colors[0])
    # NOTE(review): the [:, 0] / [0, :, 0, 0] indexing assumes extra trailing axes
    # on the predictive draws — confirm against the model.
    plt.plot(t_train, x_pred_trn.mean(0)[:, 0], label="prediction", color=colors[1])
    plt.fill_between(t_train, x_hpdi_trn[0, :, 0, 0], x_hpdi_trn[1, :, 0, 0], alpha=0.3, color=colors[1])
    plt.plot(t_test, x_pred_tst.mean(0)[:, 0], label="forecast", color=colors[2])
    plt.fill_between(t_test, x_hpdi_tst[0, :, 0, 0], x_hpdi_tst[1, :, 0, 0], alpha=0.3, color=colors[2])
    plt.ylim(x.min() - 0.5, x.max() + 0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig(root / "prediction.png")
    plt.close()
def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
) -> None:
    """Persist samples and plot the fit/forecast of the Kalman model.

    Writes ``piror_samples.npz``, ``posterior_samples.npz``,
    ``posterior_predictive.npz`` and ``kalman.png`` into ``./data/kalman``.
    The first ``x.shape[0]`` steps of ``posterior_predictive["x"]`` (axis 1)
    are treated as the fitted range; the remainder is the forecast.
    """
    root = pathlib.Path("./data/kalman")
    # parents=True: do not fail when ./data itself does not exist yet.
    root.mkdir(parents=True, exist_ok=True)

    # NOTE: "piror" spelling kept so existing consumers of this file still find it.
    jnp.savez(root / "piror_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    len_train = x.shape[0]
    x_pred_trn = posterior_predictive["x"][:, :len_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    x_pred_tst = posterior_predictive["x"][:, len_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    len_test = x_pred_tst.shape[1]

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    plt.figure(figsize=(12, 6))
    plt.plot(x.ravel(), label="ground truth", color=colors[0])

    t_train = np.arange(len_train)
    plt.plot(t_train, x_pred_trn.mean(0).ravel(), label="prediction", color=colors[1])
    plt.fill_between(t_train, x_hpdi_trn[0].ravel(), x_hpdi_trn[1].ravel(), alpha=0.3, color=colors[1])

    t_test = np.arange(len_train, len_train + len_test)
    plt.plot(t_test, x_pred_tst.mean(0).ravel(), label="forecast", color=colors[2])
    plt.fill_between(t_test, x_hpdi_tst[0].ravel(), x_hpdi_tst[1].ravel(), alpha=0.3, color=colors[2])

    plt.legend()
    plt.tight_layout()
    plt.savefig(root / "kalman.png")
    plt.close()
def moments(self):
    """Return the lower HPDI bound, mean and upper HPDI bound of ``self.trace``.

    The result is a dict with keys ``'lower'``, ``'mean'`` and ``'upper'``; it is
    computed on first access and memoised in ``self._moments``. ``self.alpha``
    is passed to ``hpdi`` as the interval probability.
    """
    if self._moments is not None:
        return self._moments
    lower, upper = hpdi(self.trace, prob=self.alpha)
    # The mean carries an extra trailing singleton axis; the HPDI bounds do not.
    avg = onp.mean(self.trace, axis=0)[..., None]
    self._moments = {'lower': lower, 'mean': avg, 'upper': upper}
    return self._moments
def plot_results(
    time_series: np.ndarray,
    y: np.ndarray,
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    test_index: int,
    root: pathlib.Path,
) -> None:
    """Plot the forecast mean and HPDI band against ``y`` and save ``root/plot.png``.

    The title reports sMAPE and RMSE of the forecast mean against ``y``.
    ``posterior_samples`` and ``test_index`` are accepted but not used here.
    """
    draws = posterior_predictive["y_forecast"]
    forecast_mean = jnp.mean(draws, axis=0)

    residual = forecast_mean - y
    smape = jnp.mean(jnp.abs(residual) / (forecast_mean + y)) * 200
    rmse = jnp.sqrt(jnp.mean(residual**2))
    lower, upper = hpdi(draws)

    plt.figure(figsize=(8, 4))
    plt.plot(time_series, y)
    plt.plot(time_series, forecast_mean, lw=2)
    plt.fill_between(time_series, lower, upper, alpha=0.3)
    plt.title(
        f"Forecasting lynx dataset with SGT (sMAPE: {smape:.2f}, RMSE: {rmse:.2f})"
    )
    plt.tight_layout()
    plt.savefig(root / "plot.png")
    plt.close()
def sample_posterior_predictive(
    self,
    df: pd.DataFrame,
    hdpi: bool = False,
    hdpi_interval: float = 0.9,
    rng_key: typing.Optional[np.ndarray] = None,
) -> typing.Union[pd.Series, pd.DataFrame]:
    """Obtain the posterior predictive mean, optionally with an HPD interval.

    Parameters
    ----------
    df : pd.DataFrame
        Source dataframe.
    hdpi : bool
        Include the lower/upper bound of the highest posterior density
        interval. Returns a dataframe if true, a series if false. Default False.
    hdpi_interval : float
        Probability mass of the interval. Default 0.9.
    rng_key : two-element ndarray, optional
        RNG key; if not provided, a new one is split off via ``split_rand_key``.

    Returns
    -------
    pd.Series or pd.DataFrame
        Mean of the posterior predictive draws. A series named after the
        dv when ``hdpi`` is False, otherwise a dataframe with extra
        ``hdpi_lower``/``hdpi_upper`` columns.

    Raises
    ------
    exceptions.NullDataFound
        If the transformed ``df`` contains columns with null data.
    """
    rng_key_ = self.split_rand_key() if rng_key is None else rng_key.astype("uint32")

    null_cols = columns_with_null_data(self.transform(df))
    if null_cols:
        raise exceptions.NullDataFound(*null_cols)

    predictions = infer.Predictive(self.model, self.samples_flat)(rng_key_, df=df)[self.dv]
    point_forecast = predictions.mean(axis=0)
    if not hdpi:
        return pd.Series(point_forecast, index=df.index, name=self.dv)

    # Separate name: do not rebind the boolean flag ``hdpi`` to an array.
    interval = diagnostics.hpdi(predictions, hdpi_interval)
    return pd.DataFrame(
        {
            self.dv: point_forecast,
            "hdpi_lower": interval[0, :],
            "hdpi_upper": interval[1, :],
        },
        index=df.index,
    )
def _save_results(
    y: np.ndarray,
    mcmc: infer.MCMC,
    prior: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    *,
    var_names: Optional[List[str]] = None,
) -> None:
    """Save samples, ArviZ trace/PPC plots and a prediction plot to ``./data/boston_pca_reg``.

    Args:
        y: Observed targets plotted against the predictive mean.
        mcmc: Fitted MCMC object, converted with ``az.from_numpyro``.
        prior: Prior draws, passed to ArviZ (not saved to disk).
        posterior_samples: Posterior draws, saved as ``posterior_samples.npz``.
        posterior_predictive: Predictive draws; ``"y"`` is used for the band.
        var_names: Variables to include in the trace plot (all when ``None``).
    """
    root = pathlib.Path("./data/boston_pca_reg")
    # parents=True: do not fail when ./data itself does not exist yet.
    root.mkdir(parents=True, exist_ok=True)

    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    # Arviz
    numpyro_data = az.from_numpyro(
        mcmc,
        prior=prior,
        posterior_predictive=posterior_predictive,
    )
    az.plot_trace(numpyro_data, var_names=var_names)
    plt.savefig(root / "trace.png")
    plt.close()

    az.plot_ppc(numpyro_data)
    plt.legend(loc="upper right")
    plt.savefig(root / "ppc.png")
    plt.close()

    # Prediction
    y_pred = posterior_predictive["y"]
    y_hpdi = diagnostics.hpdi(y_pred)
    # NOTE(review): 0.8 presumably mirrors the train/test split used by the caller — confirm.
    train_len = int(len(y) * 0.8)

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    plt.figure(figsize=(12, 6))
    plt.plot(y, color=colors[0])
    plt.plot(y_pred.mean(axis=0), color=colors[1])
    plt.fill_between(np.arange(len(y)), y_hpdi[0], y_hpdi[1], color=colors[1], alpha=0.3)
    plt.axvline(train_len, linestyle="--", color=colors[2])
    plt.xlabel("Index [a.u.]")
    plt.ylabel("Target [a.u.]")
    plt.savefig(root / "prediction.png")
    plt.close()
def main() -> None:
    """Fit the divorce ~ marriage model and print posterior/predictive summaries."""
    df = load_dataset()
    marriage = df["MarriageScaled"].values
    divorce = df["DivorceScaled"].values

    rng_key = random.PRNGKey(0)
    rng_key, rng_key_ = random.split(rng_key)

    # Inference posterior
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
    mcmc.run(rng_key_, marriage=marriage, divorce=divorce)
    mcmc.print_summary()
    samples_1 = mcmc.get_samples()

    # Compute empirical posterior distribution of the linear predictor a + bM * marriage
    posterior_mu = (
        jnp.expand_dims(samples_1["a"], -1)
        + jnp.expand_dims(samples_1["bM"], -1) * marriage
    )
    mean_mu = jnp.mean(posterior_mu, axis=0)
    hpdi_mu = hpdi(posterior_mu, 0.9)
    print(mean_mu, hpdi_mu)

    # Posterior predictive distribution
    rng_key, rng_key_ = random.split(rng_key)
    predictive = Predictive(model, samples_1)
    predictions = predictive(rng_key_, marriage=marriage)["obs"]
    df["MeanPredictions"] = jnp.mean(predictions, axis=0)
    print(df.head())

    # Predictive utility with effect handlers. Split a fresh key (the previous
    # one was already consumed above) into one key per posterior draw.
    rng_key, rng_key_ = random.split(rng_key)
    num_draws = samples_1["a"].shape[0]
    predict_fn = vmap(
        lambda key, samples: predict(key, samples, model, marriage=marriage)
    )
    predictions_1 = predict_fn(random.split(rng_key_, num_draws), samples_1)
    mean_pred = jnp.mean(predictions_1, axis=0)
    print(mean_pred)

    # Posterior predictive density
    rng_key, rng_key_ = random.split(rng_key)
    lpp_dns = log_pred_density(
        rng_key_,
        samples_1,
        model,
        marriage=marriage,
        divorce=divorce,
    )
    print("Log posterior predictive density", lpp_dns)
def _wide_to_long(values, key_name, id_var, value_name):
    """Melt array-like ``values`` into one column ``value_name`` indexed by ``(id_var, key_name)``.

    The row position becomes the ``id_var`` level and the column label becomes
    the ``key_name`` level.
    """
    wide = pd.DataFrame(values)
    wide[id_var] = wide.index.values
    long_df = pd.melt(wide, var_name=key_name, value_name=value_name, id_vars=id_var)
    long_df.index = pd.MultiIndex.from_frame(long_df[[id_var, key_name]])
    return long_df.drop(columns=[key_name, id_var])


def get_mean_and_ci(x, key_name, id_var, prob=0.9, axis=0):
    """Summarise samples ``x`` as a long table of mean and HPDI bounds.

    Args:
        x: Sample array; statistics are taken along ``axis``.
        key_name: Name of the index level holding the column labels.
        id_var: Name of the index level holding the row positions.
        prob: Probability mass of the HPD interval.
        axis: Sample axis of ``x``.

    Returns:
        DataFrame indexed by ``(id_var, key_name)`` with columns
        ``mean_val``, ``lower_cl`` and ``upper_cl``.
    """
    mean_val = x.mean(axis=axis)
    low_ci, up_ci = hpdi(x, prob=prob, axis=axis)
    mean_long = _wide_to_long(mean_val, key_name, id_var, 'mean_val')
    low_long = _wide_to_long(low_ci, key_name, id_var, 'lower_cl')
    up_long = _wide_to_long(up_ci, key_name, id_var, 'upper_cl')
    return mean_long.join(low_long).join(up_long)
def main(args):
    """Fit Holt-Winters to a synthetic seasonal series and plot the forecast.

    ``args.T`` training units and ``args.future`` forecast units are generated
    with ``N_POINTS_PER_UNIT`` points each; the plot goes to ``holt_winters_plot.pdf``.
    """
    # generate artificial dataset
    rng_key, _ = random.split(random.PRNGKey(0))
    T = args.T
    n_train = T * N_POINTS_PER_UNIT
    t = jnp.linspace(0, T + args.future, (T + args.future) * N_POINTS_PER_UNIT)
    y = jnp.sin(2 * np.pi * t) + 0.3 * t + jax.random.normal(rng_key, t.shape) * 0.1
    n_seasons = N_POINTS_PER_UNIT
    # Slice at the explicit training length: ``y[:-args.future * N]`` would be
    # empty when ``args.future == 0``.
    y_train = y[:n_train]
    t_test = t[n_train:]

    # do inference
    rng_key, _ = random.split(random.PRNGKey(1))
    samples = run_inference(holt_winters, args, rng_key, y_train, n_seasons)

    # do prediction
    rng_key, _ = random.split(random.PRNGKey(2))
    preds = predict(holt_winters, args, samples, rng_key, y_train, n_seasons)
    mean_preds = preds.mean(axis=0)
    # Explicit prob so the band matches the "90% CI" legend label.
    hpdi_preds = hpdi(preds, 0.9)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    # plot true data and predictions
    ax.plot(t, y, color="blue", label="True values")
    ax.plot(t_test, mean_preds, color="orange", label="Mean predictions")
    ax.fill_between(t_test, *hpdi_preds, color="orange", alpha=0.2, label="90% CI")
    ax.set(xlabel="time", ylabel="y", title="Holt-Winters Exponential Smoothing")
    ax.legend()

    fig.savefig("holt_winters_plot.pdf")
    plt.close(fig)
def hpdi(self, param, which="posterior", *args, **kwargs):
    """Compute the HPDI of ``param`` from the sample set chosen by ``which``.

    ``self._select(which)`` picks the samples; remaining positional and keyword
    arguments are forwarded to the module-level ``hpdi``. Because ``which``
    precedes ``*args``, the first extra positional argument binds to ``which`` —
    pass ``prob`` by keyword.
    """
    selected = self._select(which)
    return hpdi(selected[param], *args, **kwargs)
#mn=dist.MultivariateNormal(loc=ave, covariance_matrix=cov) #key,subkey=random.split(key) #mk = numpyro.sample('a',mn,rng_key=key) try: mk = smn(mean=ave, cov=cov, allow_singular=True).rvs(1).T mkgp = smn(mean=gp, cov=cov, allow_singular=True).rvs(1).T marrs_gp.append(mkgp) marrs.append(mk) except: print("SMN not worked") marrs = np.array(marrs) marrs_gp = np.array(marrs_gp) median_mu1 = np.median(marrs, axis=0) hpdi_mu1 = hpdi(marrs, 0.9) median_mu1_gp = np.median(marrs_gp, axis=0) hpdi_mu1_gp = hpdi(marrs_gp, 0.9) red = (1.0 + 28.07 / 300000.0) #for annotation fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 6.0)) ax.plot(wavd1[::-1], median_mu1, color="C0") ax.plot(wavd1[::-1], median_mu1_gp, color="C2") ax.plot(wavd1[::-1], fobs1, "+", color="C1", label="data") #annotation for some lines ax.plot([22913.3 * red, 22913.3 * red], [0.6, 0.75], color="C0", lw=1) ax.plot([22918.07 * red, 22918.07 * red], [0.6, 0.77], color="C1", lw=1) ax.plot([22955.67 * red, 22955.67 * red], [0.6, 0.68], color="C2", lw=1)
num_warmup, num_samples = 500, 1000 #num_warmup, num_samples = 100, 300 kernel = NUTS(model_c, forward_mode_differentiation=True) mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples) mcmc.run(rng_key_, nu1=nusd1, y1=fobs1, e1=err1) print("end HMC") #Post-processing posterior_sample = mcmc.get_samples() np.savez("npz/savepos.npz", [posterior_sample]) pred = Predictive(model_c, posterior_sample, return_sites=["y1"]) nu_1 = nus1 predictions = pred(rng_key_, nu1=nu_1, y1=None, e1=err1) median_mu1 = jnp.median(predictions["y1"], axis=0) hpdi_mu1 = hpdi(predictions["y1"], 0.9) np.savez("npz/saveplotpred.npz", [wavd1, fobs1, err1, median_mu1, hpdi_mu1]) red = (1.0 + 28.07 / 300000.0) #for annotation fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 6.0)) ax.plot(wavd1[::-1], median_mu1, color="C0") ax.plot(wavd1[::-1], fobs1, "+", color="C1", label="data") #annotation for some lines ax.plot([22913.3 * red, 22913.3 * red], [0.6, 0.75], color="C0", lw=1) ax.plot([22918.07 * red, 22918.07 * red], [0.6, 0.77], color="C1", lw=1) ax.plot([22955.67 * red, 22955.67 * red], [0.6, 0.68], color="C2", lw=1) plt.text(22913.3 * red, 0.55, "A", color="C0",
columns=["Name", "Locus_tag", "gene_lookup"]).set_index('locus_tag'), on='locus_tag') #%% for_gini = onp.zeros(samples['new_beta'].shape) for i in range(for_gini.shape[0]): for_gini[i, ...] = h.prep_for_gini(samples['new_beta'][i, ...]) #%% gini_arr = onp.zeros((for_gini.shape[0], for_gini.shape[1])) for i in range(gini_arr.shape[0]): gini_arr[i, :] = h.gini(for_gini[i, :, :]) # %% mean_gini = np.mean(gini_arr, axis=0) gini_low, gini_up = hpdi(gini_arr, prob=0.9, axis=0) gini_df = pd.DataFrame({ 'mean_val': mean_gini, 'lower_cl': gini_low, 'upper_cl': gini_up, 'locus_tag': [k for k in gene_lookup.keys()] }) gini_df['gene'] = gini_df.locus_tag.replace(locus_tag_lookup) gini_df = gini_df.join(gene_info_df.set_index('locus_tag'), on='locus_tag').drop(columns='Name') gini_df = gini_df.sort_values('mean_val') gini_df['x_vals'] = onp.arange(gini_df.shape[0]) # %% ph.plot_ginis(gini_df)
kernel = NUTS(model)
# num_warmup / num_samples / num_chains are keyword-only in current NumPyro releases.
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=num_chains,
    progress_bar=True,
)
mcmc.run(rng_key_, X=X_train_processed, y_obs=y_train, ndims=ndims, ndata=ndata)
mcmc.print_summary()
samples_3 = mcmc.get_samples()

# NOTE(review): ``predictive`` is not used below; kept in case later cells rely on it.
predictive = Predictive(model, samples_3)
predictions_3 = Predictive(model_se, samples_3)(
    rng_key_,
    X=X_test_processed,
    ndims=X_test_processed.shape[1],
    ndata=X_test_processed.shape[0],
)['y']

residuals_4 = y_test - predictions_3
residuals_mean = np.mean(residuals_4, axis=0)
residuals_hpdi = hpdi(residuals_4, 0.9)
# Error bar length taken from the upper HPDI bound only.
err = residuals_hpdi[1] - residuals_mean

fig, ax = plt.subplots(nrows=1, ncols=1)

# Plot Residuals
ax.errorbar(residuals_mean, y_test, xerr=err,
            marker='o', ms=5, mew=4, ls='none', alpha=0.8)
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 300, 600
kernel = NUTS(model_c, forward_mode_differentiation=True)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key_, nu1=nusd, y1=nflux)

# SAMPLING
posterior_sample = mcmc.get_samples()
# Saved as a one-element object array (loading requires allow_pickle=True).
np.savez('npz/savepos.npz', [posterior_sample])

# Use a fresh key: the previous one was already consumed by mcmc.run.
rng_key, rng_key_ = random.split(rng_key)
pred = Predictive(model_c, posterior_sample, return_sites=['y1'])
predictions = pred(rng_key_, nu1=nusd, y1=None)
median_mu1 = jnp.median(predictions['y1'], axis=0)
hpdi_mu1 = hpdi(predictions['y1'], 0.9)
# Placeholder unit errors for the saved plotting bundle.
err = np.ones_like(nflux)
np.savez('npz/saveplotpred.npz', [wavd, nflux, err, median_mu1, hpdi_mu1])

# PLOT
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 6.0))
ax.plot(wavd[::-1], median_mu1, color='C0')
ax.plot(wavd[::-1], nflux, '+', color='black', label='data')
ax.fill_between(wavd[::-1], hpdi_mu1[0], hpdi_mu1[1],
                alpha=0.3, interpolate=True, color='C0', label='90% area')
def main():
    """Simulate an AR signal, fit it with NUTS, forecast the held-out tail and plot it."""
    # Ground truth values
    ground_truth = {
        'beta': [0.5, 0.5, 0.0, -0.1],
        'z_init': [1, 2, 0, 1],
        'tau': 0.2,
    }
    sample = 2000
    init = len(ground_truth['z_init'])
    adj_sample = sample - init
    # Fixed initial values followed by Gaussian innovations scaled by tau.
    epsilon = onp.concatenate([
        onp.array(ground_truth['z_init']),
        ground_truth['tau'] * onp.random.randn(adj_sample),
    ], axis=0)
    X = onp.random.randn(sample)
    obs = onp.array(ar_signal(ground_truth['beta'], sample, epsilon))
    test_sample = 1000
    y_train = np.array(obs[:sample - test_sample], dtype=np.float32)
    y_test = obs[sample - test_sample:]
    data_ = {'obs': y_train, 'X': X, 'n_coefs': 4}

    # Inference
    num_samples = 5000
    nuts_kernel = NUTS(ar_k)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=num_samples)
    rng_key = random.PRNGKey(0)
    mcmc.run(rng_key, **data_, extra_fields=('potential_energy',))
    mcmc.print_summary()
    samples = mcmc.get_samples()
    plot_inference(ground_truth, samples)

    # Forecast: one PRNG key per posterior draw.
    rng_keys = random.split(random.PRNGKey(3), samples["tau"].shape[0])
    forecast_marginal = vmap(
        lambda key, draw: forecast(y_test.shape[0], key, draw, y_train, data_['n_coefs'])
    )(rng_keys, samples)

    y_pred = np.mean(forecast_marginal, axis=0)
    smape = np.mean(np.abs(y_pred - y_test) / (y_pred + y_test)) * 200
    msqrt = np.sqrt(np.mean((y_pred - y_test)**2))
    print("sMAPE: {:.2f}, rmse: {:.2f}".format(smape, msqrt))

    # Single source of truth for the interval mass so the title matches the band.
    interval_prob = 0.95
    plt.figure(figsize=(8, 4))
    plt.plot(range(sample), obs)
    t_future = range(sample - test_sample, sample)
    hpd_low, hpd_high = hpdi(forecast_marginal, prob=interval_prob)
    plt.plot(t_future, y_pred, lw=2)
    plt.fill_between(t_future, hpd_low, hpd_high, alpha=0.3)
    plt.title(f"Forecasting AR model ({interval_prob:.0%} HPDI)")
    plt.show()
from exojax.spec import rtransfer as rt
from numpyro.diagnostics import hpdi

#ATMOSPHERE
NP = 100
Parr, dParr, k = rt.pressure_layer(NP=NP)
T0 = 1295.0  #K
alpha = 0.099
# Reference power-law profile, computed before T0 is rebound in the loop below.
Tarrc = T0 * (Parr)**alpha

# NOTE(review): allow_pickle=True unpickles the file; only load files this pipeline wrote.
p = np.load("npz/savepos.npz", allow_pickle=True)["arr_0"][0]
Tsample = p["Tarr"]
T0sample = p["T0"]

# Per-draw temperature profile = layer profile + per-draw offset, the same
# quantity that is plotted below. (Previously the mean used Tsample * T0.)
profiles = Tsample + T0sample[:, None]
mean_muy = jnp.mean(profiles, axis=0)
hpdi_muy = hpdi(profiles, 0.90, axis=0)

fig = plt.figure(figsize=(5, 7))
# NOTE: this loop rebinds T0 and Tarr; afterwards they hold the last draw.
for T0, Tarr in zip(T0sample, Tsample):
    plt.plot(Tarr + T0, Parr, alpha=0.05, color="green", rasterized=True)
plt.plot(Tarrc, Parr, alpha=1.0, color="black", lw=1, label="best-fit power law")
#plt.fill_betweenx(Parr, hpdi_muy[0], hpdi_muy[1], alpha=0.3, interpolate=True,color="C0")
# Predictions for Model 3.
rng_key, rng_key_ = random.split(rng_key)
predictions_3 = Predictive(model, samples_3)(
    rng_key_,
    marriage=dset.MarriageScaled.values,
    age=dset.AgeScaled.values,
)['obs']

# One row per observation (state) instead of a hard-coded 50.
num_states = len(dset)
y = np.arange(num_states)
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 16))
pred_mean = np.mean(predictions_3, axis=0)
pred_hpdi = hpdi(predictions_3, 0.9)
residuals_3 = dset.DivorceScaled.values - predictions_3
residuals_mean = np.mean(residuals_3, axis=0)
residuals_hpdi = hpdi(residuals_3, 0.9)
# Order states by mean residual.
idx = np.argsort(residuals_mean)

# Plot posterior predictive
ax[0].plot(np.zeros(num_states), y, '--')
ax[0].errorbar(pred_mean[idx], y, xerr=pred_hpdi[1, idx] - pred_mean[idx],
               marker='o', ms=5, mew=4, ls='none', alpha=0.8)
ax[0].plot(dset.DivorceScaled.values[idx], y, marker='o',
           ls='none', color='gray')
ax[0].set(xlabel='Posterior Predictive (red) vs. Actuals (gray)', ylabel='State',
          title='Posterior Predictive with 90% CI')
ax[0].set_yticks(y)
ax[0].set_yticklabels(dset.Loc.values[idx], fontsize=10);
def hpdi(self, param, *args, **kwargs):
    """Return the highest density interval (HPDI) of the samples for ``param``.

    Positional extras are forwarded to ``self.dist``; keyword arguments are
    forwarded to the module-level ``hpdi``.
    """
    samples = self.dist(param, *args)
    return hpdi(samples, **kwargs)
def test_hpdi():
    """hpdi should match known intervals for a normal and an exponential sample."""
    # Seeded generator: the tolerance checks below are statistical and were flaky unseeded.
    rng = np.random.default_rng(0)

    # Symmetric unimodal: the 80% HPDI coincides with the central 10%-90% quantiles.
    x = rng.normal(size=20000)
    assert_allclose(hpdi(x, prob=0.8), np.quantile(x, [0.1, 0.9]), atol=0.01)

    # Exp(1): the shortest 20% interval is [0, -log(0.8)] ~= [0, 0.223].
    x = rng.exponential(size=20000)
    assert_allclose(hpdi(x, prob=0.2), np.array([0.0, -np.log(0.8)]), atol=0.01)
def _save_results(
    x: jnp.ndarray,
    betas: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:
    """Persist samples and plot the DLM fit/forecast plus each regression weight.

    Writes ``piror_samples.npz``, ``posterior_samples.npz``,
    ``posterior_predictive.npz`` and ``prediction.png`` into ``./data/dlm``.
    The first subplot shows ``x`` against the predictive ``"x"`` draws split at
    ``num_train``; the remaining ``betas.shape[-1]`` subplots compare each true
    coefficient path with the predictive ``"weight"`` draws.
    """
    root = pathlib.Path("./data/dlm")
    # parents=True: do not fail when ./data itself does not exist yet.
    root.mkdir(parents=True, exist_ok=True)

    # NOTE: "piror" spelling kept so existing consumers of this file still find it.
    jnp.savez(root / "piror_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    x_pred = posterior_predictive["x"]

    x_pred_trn = x_pred[:, :num_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    t_train = np.arange(num_train)

    x_pred_tst = x_pred[:, num_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    num_test = x_pred_tst.shape[1]
    t_test = np.arange(num_train, num_train + num_test)

    t_axis = np.arange(num_train + num_test)
    w_pred = posterior_predictive["weight"]
    w_hpdi = diagnostics.hpdi(w_pred)

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    beta_dim = betas.shape[-1]
    plt.figure(figsize=(8, 12))
    for i in range(beta_dim + 1):
        plt.subplot(beta_dim + 1, 1, i + 1)
        if i == 0:
            # Observed series: fitted range and forecast range.
            plt.plot(x[:, 0], label="ground truth", color=colors[0])
            plt.plot(t_train, x_pred_trn.mean(0)[:, 0], label="prediction", color=colors[1])
            plt.fill_between(t_train, x_hpdi_trn[0, :, 0, 0], x_hpdi_trn[1, :, 0, 0], alpha=0.3, color=colors[1])
            plt.plot(t_test, x_pred_tst.mean(0)[:, 0], label="forecast", color=colors[2])
            plt.fill_between(t_test, x_hpdi_tst[0, :, 0, 0], x_hpdi_tst[1, :, 0, 0], alpha=0.3, color=colors[2])
            plt.title("ground truth", fontsize=16)
        else:
            # Coefficient i-1 over the whole time axis.
            plt.plot(betas[:, 0, i - 1], label="ground truth", color=colors[0])
            plt.plot(t_axis, w_pred.mean(0)[:, i - 1], label="prediction", color=colors[1])
            plt.fill_between(t_axis, w_hpdi[0, :, i - 1, 0], w_hpdi[1, :, i - 1, 0], alpha=0.3, color=colors[1])
            plt.title(f"coef_{i - 1}", fontsize=16)
    # Legend only on the last subplot, as before.
    plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(root / "prediction.png")
    plt.close()