def j_summary(samples, ctype='hist', properties=None):
    """Print numeric summaries of posterior samples and display a chart.

    Prints ``print_summary`` statistics (0.89 credible mass) and the
    correlation matrix, then renders either per-variable histograms or
    density plots with Altair.

    :param samples: posterior draws — either a dict mapping site name to
        array, or a DataFrame with one column per variable.
    :param str ctype: chart type, 'hist' (default) or 'density'.
    :param dict properties: extra Altair chart properties applied to each
        sub-chart; defaults to ``{'width': 800}``.
    :raises ValueError: if ``ctype`` is not 'hist' or 'density'.
    """
    # Avoid a shared mutable default argument; build the default per call.
    if properties is None:
        properties = {'width': 800}
    if isinstance(samples, dict):
        print_summary(samples, 0.89, False)
        df = pd.DataFrame(samples).clean_names()
    else:
        print_summary(dict(zip(samples.columns, samples.T.values)), 0.89,
                      False)
        df = samples
    display(df.corr())
    # Subsample very large posteriors so the Altair chart stays responsive.
    df = df if len(df) < 5000 else df.sample(n=4000)
    # Fix: the `properties` parameter was previously accepted but never
    # applied to the chart.
    base = alt.Chart(df).properties(height=30, **properties)
    if ctype == 'density':
        charts = [
            base.mark_line().transform_density(
                col, as_=[col, 'density'],
            ).encode(alt.X(f'{col}:Q'), alt.Y('density:Q'))
            for col in df.columns
        ]
        return_chart = alt.vconcat(*charts)
    elif ctype == 'hist':
        return_chart = base.mark_bar().encode(
            alt.X(bin=alt.Bin(maxbins=20), field=alt.repeat("row"),
                  type='quantitative'),
            y=alt.Y(title=None, aggregate='count',
                    type='quantitative')).repeat(row=list(df.columns))
    else:
        # Previously an unknown ctype crashed later with a NameError on
        # `return_chart`; fail early with a clear message instead.
        raise ValueError(f"unknown ctype: {ctype!r}; expected 'hist' or 'density'")
    display(return_chart)
def print_summary(self, prob=0.9, exclude_deterministic=True):
    """Print statistics of the posterior samples from this MCMC run.

    :param float prob: probability mass of the reported credible interval.
    :param bool exclude_deterministic: if True (default), restrict the
        summary to the latent sites present in the final sampler state,
        dropping deterministic sites.
    """
    samples = self._states[self._sample_field]
    if exclude_deterministic and isinstance(samples, dict):
        latent_sites = attrgetter(self._sample_field)(self._last_state)
        # XXX: `samples` (the postprocessed value) may be a dict even when
        # the raw state field is not.
        # TODO: when both are dicts their key sets may differ for reasons
        # other than deterministic sites; revisit this filter if needed.
        if isinstance(latent_sites, dict):
            samples = {
                name: value
                for name, value in samples.items()
                if name in latent_sites
            }
    print_summary(samples, prob=prob)
    diagnostics = self.get_extra_fields()
    if "diverging" in diagnostics:
        print("Number of divergences: {}".format(
            jnp.sum(diagnostics["diverging"])))
def print_summary(self, prob=0.9, exclude_deterministic=True):
    """Print statistics of the posterior samples from this MCMC run.

    :param float prob: probability mass of the reported credible interval.
    :param bool exclude_deterministic: if True (default), keep only the
        latent sites recorded in ``self._last_state.z``.
    """
    samples = self._states[self._sample_field]
    if exclude_deterministic and isinstance(samples, dict):
        # Deterministic sites are excluded by default: keep only keys that
        # appear among the latent variables of the final state.
        latent = self._last_state.z
        samples = {
            name: value for name, value in samples.items() if name in latent
        }
    print_summary(samples, prob=prob)
    diagnostics = self.get_extra_fields()
    if 'diverging' in diagnostics:
        print("Number of divergences: {}".format(
            jnp.sum(diagnostics['diverging'])))
def analyze_post(post, method):
    """Summarize and plot a posterior for the multicollinearity example.

    Prints summary statistics and saves three figures named after
    ``method``: a forest plot, the joint posterior of ``bl``/``br``, and
    the posterior of their sum.

    :param post: dict of posterior sample arrays; must contain at least
        the sites 'bl' and 'br'.
    :param method: label used in plot titles and output file names.
    """
    print_summary(post, 0.95, False)

    # Forest plot of all parameters.
    fig, ax = plt.subplots()
    az.plot_forest(post, hdi_prob=0.95, figsize=(10, 4), ax=ax)
    plt.title(method)
    pml.savefig(f'multicollinear_forest_plot_{method}.pdf')
    plt.show()

    # Joint posterior of the two collinear coefficients.
    fig, ax = plt.subplots()
    az.plot_pair(post, var_names=["br", "bl"], scatter_kwargs={"alpha": 0.1},
                 ax=ax)
    # Fix: set the title BEFORE saving (it was previously set after
    # pml.savefig, so the saved PDF was missing its title, inconsistent
    # with the other two plots).
    plt.title(method)
    pml.savefig(f'multicollinear_joint_post_{method}.pdf')
    plt.show()

    # Posterior of the sum bl + br.
    sum_blbr = post["bl"] + post["br"]
    fig, ax = plt.subplots()
    az.plot_kde(sum_blbr, label="sum of bl and br", ax=ax)
    plt.title(method)
    pml.savefig(f'multicollinear_sum_post_{method}.pdf')
    plt.show()
def main(args):
    """Compare vanilla HMC against NeuTra HMC on the dual-moon target.

    Runs NUTS directly on ``dual_moon_model``, trains an AutoBNAFNormal
    flow guide with SVI, runs NUTS on the flow-reparameterized (warped)
    model, and saves a 2x3 comparison figure to ``neutra.pdf``.

    :param args: parsed CLI arguments; reads num_warmup, num_samples,
        num_chains, hidden_factor and num_iters.
    """
    # --- Baseline: plain NUTS on the original model ---
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        # progress bar is suppressed when NUMPYRO_SPHINXBUILD is set
        # (docs build)
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    # --- Train a BNAF normalizing-flow guide with SVI ---
    guide = AutoBNAFNormal(
        dual_moon_model,
        hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), ELBO())
    svi_state = svi.init(random.PRNGKey(1))
    print("Start training guide...")
    # lax.scan drives the jitted update loop; the zeros array only fixes
    # the number of iterations (its values are unused).
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, jnp.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    # --- NeuTra HMC: NUTS on the flow-warped model ---
    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, params)
    neutra_model = neutra.reparam(dual_moon_model)
    nuts_kernel = NUTS(neutra_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(3))
    mcmc.print_summary()
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    # Map warped draws back to the original parameter space.
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples['x'].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(jnp.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    # Grid for contours of the true dual-moon density.
    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    # Training loss curve (skip the noisy first 1000 steps).
    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    # Guide posterior vs true density.
    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    # Base (latent) samples colored by which moon they map to.
    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)'
    )

    # Vanilla HMC posterior, with the path of the last 50 draws overlaid.
    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    # Draws in the warped space, colored by the unwarped moon.
    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    # NeuTra HMC posterior in the original space.
    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
def main(args):
    """Compare vanilla HMC against NeuTra HMC on the dual-moon target.

    Trains an AutoIAFNormal guide with SVI, builds a transformed potential
    from the learned flow, runs NUTS on it, and saves a 3x2 comparison
    figure to ``neutra.pdf``.

    :param args: parsed CLI arguments; reads num_warmup, num_samples,
        num_hidden and num_iters.
    """
    # --- Baseline: plain NUTS on the original model ---
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    adam = optim.Adam(0.01)
    # TODO: it is hard to find good hyperparameters such that IAF guide can
    # learn this model. We will use BNAF instead!
    guide = AutoIAFNormal(dual_moon_model,
                          num_flows=2,
                          hidden_dims=[args.num_hidden, args.num_hidden])
    svi = SVI(dual_moon_model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    # lax.scan drives the jitted update loop; the zeros array only fixes
    # the number of iterations (its values are unused).
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(0), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    # Build the NeuTra (warped) potential from the learned flow transform.
    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2),
                                                     dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy,
                                       potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))  # noqa: E731

    # --- NeuTra HMC: NUTS on the warped potential ---
    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    # Map warped draws back to the original parameter space.
    samples = vmap(transformed_constrain_fn)(zs)
    # NOTE(review): the added leading axis presumably gives print_summary
    # the chain dimension it expects — confirm against its signature.
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(np.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = vmap(transformed_constrain_fn)(
        guide_base_samples)['x']

    # Grid for contours of the true dual-moon density.
    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    # Training log-loss curve (skip the noisy first 1000 steps).
    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    # Guide posterior vs true density.
    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using AutoIAFNormal guide')

    # Base (latent) samples colored by which moon they map to.
    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    # Vanilla HMC posterior, with the path of the last 50 draws overlaid.
    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using vanilla HMC sampler')

    # Draws in the warped space, colored by the unwarped moon.
    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the warped posterior - p(z)')

    # NeuTra HMC posterior in the original space.
    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
def print_summary(self, prob=0.9):
    """Print statistics of the collected posterior samples.

    :param float prob: probability mass of the reported credible interval.
    """
    print_summary(self._states['z'], prob=prob)
    diagnostics = self.get_extra_fields()
    if 'diverging' in diagnostics:
        # Report how many transitions diverged during sampling.
        print("Number of divergences: {}".format(
            np.sum(diagnostics['diverging'])))
# Fit model m5_3 with SVI (predictors M and A, outcome D — presumably the
# divorce-rate example; confirm against the surrounding notebook) and draw
# 1000 samples from the variational posterior.
svi = SVI(model, m5_3, optim.Adam(1), Trace_ELBO(),
          M=d.M.values, A=d.A.values, D=d.D.values)
p5_3, losses = svi.run(random.PRNGKey(0), 1000)
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (1000, ))

# Posterior
# Print a summary per parameter (set iteration order is arbitrary).
param_names = {'a', 'bA', 'bM', 'sigma'}
for p in param_names:
    print(f'posterior for {p}')
    print_summary(post[p], 0.95, False)

# PPC
# call predictive without specifying new data
# so it uses original data
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (int(1e4), ))
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2),
                                         M=d.M.values, A=d.A.values)
mu = post_pred["mu"]

# summarize samples across cases
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=(5.5, 94.5), axis=0)