def test_initialize_model_change_point(init_strategy):
    """Check that batched and per-key ``initialize_model`` calls agree on a change-point model.

    The model is initialized once with a pair of PRNG keys (vectorized). It is
    then initialized once per key, and each per-key result must match the
    corresponding slice of the batched result. When the strategy is a
    ``partial`` of ``init_to_value``, the batched ``tau`` must also equal the
    unconstrained image of the requested value, repeated for both keys.
    """
    def model(data):
        num_points = len(data)
        rate = 1 / jnp.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(rate))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(rate))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # Points before the fraction `tau` of the series use lambda1, the rest lambda2.
        before_switch = jnp.arange(num_points) < tau * num_points
        switched_rate = jnp.where(before_switch, lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(switched_rate), obs=data)

    count_data = jnp.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11,
        19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13, 19, 23,
        27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2, 15, 15, 19,
        70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20, 12, 35, 17, 23,
        17, 4, 2, 31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22,
    ])

    keys = random.split(random.PRNGKey(1), 2)
    batched_info, _, _, _ = initialize_model(keys, model, init_strategy=init_strategy,
                                             model_args=(count_data,))

    uses_init_to_value = isinstance(init_strategy, partial) and init_strategy.func is init_to_value
    if uses_init_to_value:
        requested_tau = init_strategy.keywords.get('values')['tau']
        expected = biject_to(constraints.unit_interval).inv(requested_tau)
        assert_allclose(batched_info[0]['tau'], jnp.repeat(expected, 2))

    for i, key in enumerate(keys):
        single_info, _, _, _ = initialize_model(key, model, init_strategy=init_strategy,
                                                model_args=(count_data,))
        for name, batched_value in batched_info[0].items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(batched_value[i], single_info[0][name], atol=1e-6)
def test_initialize_model_change_point(init_strategy):
    """Check that batched and per-key initialization agree (legacy ``initialize_model`` API).

    Here ``initialize_model`` takes the model arguments positionally and
    returns a 3-tuple whose first element is a dict of unconstrained params.
    """
    def model(data):
        n = len(data)
        rate = 1 / np.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(rate))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(rate))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # Switch from lambda1 to lambda2 at the fraction `tau` of the series.
        switched = np.arange(n) < tau * n
        numpyro.sample('obs', dist.Poisson(np.where(switched, lambda1, lambda2)), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11,
        19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13, 19, 23,
        27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2, 15, 15, 19,
        70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20, 12, 35, 17, 23,
        17, 4, 2, 31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22,
    ])

    keys = random.split(random.PRNGKey(1), 2)
    batched, _, _ = initialize_model(keys, model, count_data, init_strategy=init_strategy)
    for i, key in enumerate(keys):
        single, _, _ = initialize_model(key, model, count_data, init_strategy=init_strategy)
        for name, value in batched.items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(value[i], single[name], atol=1e-6)
def test_reparam_log_joint(model, kwargs):
    """Check the NeuTra-reparameterized potential against the original one.

    The potential of the reparameterized model at a point ``x`` must equal the
    original potential at ``y = T(x)`` minus ``log|det dT/dx|``.
    """
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-10), Trace_ELBO(), **kwargs)
    state = svi.init(random.PRNGKey(0))
    neutra = NeuTraReparam(guide, svi.get_params(state))
    reparam_model = neutra.reparam(model)

    _, potential_fn, _, _ = initialize_model(random.PRNGKey(1), model, model_kwargs=kwargs)
    neutra_info, neutra_potential_fn, _, _ = initialize_model(
        random.PRNGKey(2), reparam_model, model_kwargs=kwargs)
    unconstrained = neutra_info[0]
    # Take the first latent site of the reparameterized model.
    x = list(unconstrained.values())[0]

    energy_transformed = neutra_potential_fn(unconstrained)
    y = neutra.transform(x)
    log_det = neutra.transform.log_abs_det_jacobian(x, y)
    energy = potential_fn(guide._unpack_latent(y))
    assert_allclose(energy_transformed, energy - log_det)
def test_reuse_mcmc_pe_gen():
    """Reuse one potential-energy generator (``dynamic_args=True``) across two datasets.

    The model is initialized once on ``y1``, producing a potential-fn
    generator. The same HMC kernel and jitted sampling step are then run on
    ``y1`` (posterior mean of ``mu`` ~ 3) and on ``y2`` (~ -3). Only the model
    arguments change between the two runs.
    """
    # Seeded local generator so the synthetic datasets, and hence the test,
    # are reproducible. The global NumPy RNG would give different data on
    # every run.
    rng = onp.random.RandomState(0)
    y1 = rng.normal(3, 0.1, (100,))
    y2 = rng.normal(-3, 0.1, (100,))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(0), model, y1, dynamic_args=True)
    init_kernel, sample_kernel = hmc(potential_fn_gen=potential_fn)
    init_state = init_kernel(init_params, num_warmup=300, model_args=(y1,))

    @jit
    def _sample(state_and_args):
        # The model args travel alongside the HMC state so the jitted step can
        # be reused for different datasets.
        hmc_state, model_args = state_and_args
        return sample_kernel(hmc_state, (model_args,)), model_args

    samples = fori_collect(0, 500, _sample, (init_state, y1),
                           transform=lambda state: constrain_fn(y1)(state[0].z))
    assert_allclose(samples['mu'].mean(), 3., atol=0.1)

    # Run on new data, re-using the kernel and `_sample` - this should be much faster.
    init_state = init_kernel(init_params, num_warmup=300, model_args=(y2,))
    samples = fori_collect(0, 500, _sample, (init_state, y2),
                           transform=lambda state: constrain_fn(y2)(state[0].z))
    assert_allclose(samples['mu'].mean(), -3., atol=0.1)
def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    """Build the HMC init/sample functions (once) and return the initial params.

    With a model, ``initialize_model`` supplies the initial params, a
    potential-fn generator and a postprocess fn. A warning is emitted if the
    model contains ``param`` sites. Without a model, the kernel is built from
    ``self._potential_fn`` and the caller-provided ``init_params`` is returned
    unchanged.
    """
    if self._model is not None:
        init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
            rng_key,
            self._model,
            dynamic_args=True,
            init_strategy=self._init_strategy,
            model_args=model_args,
            model_kwargs=model_kwargs)
        if any(v['type'] == 'param' for v in model_trace.values()):
            # The example call in this message now has balanced parentheses.
            warnings.warn("'param' sites will be treated as constants during inference. To define "
                          "an improper variable, please use a 'sample' site with log probability "
                          "masked out. For example, `sample('x', dist.LogNormal(0, 1).mask(False))` "
                          "means that `x` has improper distribution over the positive domain.")
        if self._init_fn is None:
            self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn,
                                                 kinetic_fn=self._kinetic_fn,
                                                 algo=self._algo)
        self._postprocess_fn = postprocess_fn
    elif self._init_fn is None:
        self._init_fn, self._sample_fn = hmc(potential_fn=self._potential_fn,
                                             kinetic_fn=self._kinetic_fn,
                                             algo=self._algo)
    return init_params
def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    """Build the TFP kernel's init/sample functions (once) and return the initial params.

    With a model, the initial params are the ``.z`` field returned by
    ``initialize_model``, and the potential fn is obtained by calling the
    generator with the model args. Without a model, ``self._potential_fn`` and
    the caller-provided ``init_params`` are used.
    """
    def _build_kernel_fns(potential_fn, params):
        # Wrap `potential_fn` as a log-prob over the flattened params and build
        # the kernel; returns its (init_fn, sample_fn) pair.
        _, unravel_fn = ravel_pytree(params)
        kernel = self.kernel_class(_make_log_prob_fn(potential_fn, unravel_fn),
                                   **self._kernel_kwargs)
        # Uncalibrated... kernels have to be used inside MetropolisHastings, see
        # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/mcmc/UncalibratedLangevin
        if self.kernel_class.__name__.startswith("Uncalibrated"):
            kernel = tfp.mcmc.MetropolisHastings(kernel)
        return _extract_kernel_functions(kernel)

    if self._model is not None:
        init_params, potential_fn, postprocess_fn, _ = initialize_model(
            rng_key,
            self._model,
            init_strategy=self._init_strategy,
            dynamic_args=True,
            model_args=model_args,
            model_kwargs=model_kwargs)
        init_params = init_params.z
        if self._init_fn is None:
            self._init_fn, self._sample_fn = _build_kernel_fns(
                potential_fn(*model_args, **model_kwargs), init_params)
        self._postprocess_fn = postprocess_fn
    elif self._init_fn is None:
        self._init_fn, self._sample_fn = _build_kernel_fns(self._potential_fn, init_params)
    return init_params
def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    """Create the HMC init/sample functions if needed and return the initial params.

    Without a model, the kernel is built from ``self._potential_fn`` (only if
    not built yet) and ``init_params`` is returned untouched. With a model,
    ``initialize_model`` (honoring ``self._forward_mode_differentiation``)
    supplies the initial params, the potential-fn generator (also stored on
    ``self._potential_fn_gen``) and the postprocess fn.
    """
    if self._model is None:
        if self._init_fn is None:
            self._init_fn, self._sample_fn = hmc(
                potential_fn=self._potential_fn,
                kinetic_fn=self._kinetic_fn,
                algo=self._algo,
            )
        return init_params

    init_params, potential_fn_gen, postprocess_fn, _ = initialize_model(
        rng_key,
        self._model,
        dynamic_args=True,
        init_strategy=self._init_strategy,
        model_args=model_args,
        model_kwargs=model_kwargs,
        forward_mode_differentiation=self._forward_mode_differentiation,
    )
    if self._init_fn is None:
        self._init_fn, self._sample_fn = hmc(
            potential_fn_gen=potential_fn_gen,
            kinetic_fn=self._kinetic_fn,
            algo=self._algo,
        )
    self._potential_fn_gen = potential_fn_gen
    self._postprocess_fn = postprocess_fn
    return init_params
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once to record init locations, plates and postprocessing.

    Sets ``self._postprocess_fn``, ``self.prototype_trace``,
    ``self._init_locs`` (the unconstrained initial values), and per-plate
    bookkeeping. ``self._prototype_frames`` maps frame name to frame for every
    frame seen on a sample site. ``self._prototype_frame_full_sizes`` maps
    plate name to ``site["args"][0]`` of the plate site.
    """
    rng_key = numpyro.prng_key()
    # Block the handlers so this setup trace does not leak sites outward.
    with handlers.block():
        (
            init_params,
            _,
            self._postprocess_fn,
            self.prototype_trace,
        ) = initialize_model(
            rng_key,
            self.model,
            init_strategy=self.init_loc_fn,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs,
        )
    self._init_locs = init_params[0]

    self._prototype_frames = {}
    # The loop below writes plate sizes into `_prototype_frame_full_sizes`, so
    # it must be initialized here; otherwise the first "plate" site raises
    # AttributeError. `_prototype_plate_sizes` is kept as an alias of the same
    # dict for any code that reads the old name.
    self._prototype_frame_full_sizes = {}
    self._prototype_plate_sizes = self._prototype_frame_full_sizes
    for name, site in self.prototype_trace.items():
        if site["type"] == "sample":
            for frame in site["cond_indep_stack"]:
                self._prototype_frames[frame.name] = frame
        elif site["type"] == "plate":
            # NOTE(review): presumably the full (non-subsampled) plate size.
            self._prototype_frame_full_sizes[name] = site["args"][0]
def main(args):
    """Fit the stochastic-volatility model to S&P 500 returns with NUTS and plot the result.

    Prints a posterior summary via ``print_results``. Saves
    ``stochastic_volatility_plot.pdf``, showing the returns together with
    posterior draws of ``exp(s)``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup, rng_key=sample_rng_key)
    # Collect iterations [num_warmup, num_warmup + num_samples), mapped back to
    # the constrained space by the model's postprocess fn.
    hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples,
                              sample_kernel, hmc_state,
                              transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
                              progbar="NUMPYRO_SPHINXBUILD" not in os.environ)
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    # `Legend.legendHandles` was renamed to `legend_handles` in matplotlib 3.7
    # and the old name was removed in 3.9; support both.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    handles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns', title='Volatility of S&P500 over time')

    plt.savefig("stochastic_volatility_plot.pdf")
def test_functional_beta_bernoulli_x64(algo):
    """Check that the functional ``hmc`` API recovers Beta-Bernoulli success probabilities.

    When ``JAX_ENABLE_X64`` is set in the environment, the samples must also
    be float64.
    """
    warmup_steps = 500
    num_samples = 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, model_args=(data,))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params, trajectory_length=1., num_warmup=warmup_steps)

    def to_constrained(state):
        return constrain_fn(state.z)

    samples = fori_collect(0, num_samples, sample_kernel, hmc_state, transform=to_constrained)
    p_samples = samples['p_latent']
    assert_allclose(np.mean(p_samples, 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert p_samples.dtype == np.float64
def test_initialize_model_dirichlet_categorical(init_strategy):
    """Check that batched and per-key initialization agree on a Dirichlet-Categorical model.

    Uses the legacy API: ``initialize_model`` takes data positionally and
    returns a 3-tuple whose first element is a dict of unconstrained params.
    """
    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))

    keys = random.split(random.PRNGKey(1), 2)
    batched, _, _ = initialize_model(keys, model, data, init_strategy=init_strategy)
    for i, key in enumerate(keys):
        single, _, _ = initialize_model(key, model, data, init_strategy=init_strategy)
        for name, value in batched.items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(value[i], single[name], atol=1e-6)
def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs=None):
    """Initialize the HMC/NUTS state (warmup is not run here).

    :param rng_key: PRNG key. A key with ``ndim != 1`` is treated as a batch of
        keys, and initialization is vectorized over it with ``vmap``.
    :param int num_warmup: number of warmup steps forwarded to ``hmc_init``.
    :param init_params: initial unconstrained params. Required when the kernel
        was built from a ``potential_fn``. With a model, the model's
        initialization is used when this is ``None``.
    :param tuple model_args: positional arguments for the model.
    :param dict model_kwargs: keyword arguments for the model; ``None`` (the
        default) means no keyword arguments.
    :return: ``(init_state, constrain_fn)``; ``constrain_fn`` is ``None`` when
        there is no model.
    :raises ValueError: if there is no model and ``init_params`` is ``None``.
    """
    # A `None` default avoids one mutable dict shared across all calls.
    if model_kwargs is None:
        model_kwargs = {}
    constrain_fn = None
    if self.model is not None:
        # Reserve a separate key for model initialization.
        if rng_key.ndim == 1:
            rng_key, rng_key_init_model = random.split(rng_key)
        else:
            # Batched keys: split each one, then move the split axis to the front.
            rng_key, rng_key_init_model = np.swapaxes(
                vmap(random.split)(rng_key), 0, 1)
        init_params_, self.potential_fn, constrain_fn = initialize_model(
            rng_key_init_model,
            self.model,
            *model_args,
            init_strategy=self.init_strategy,
            **model_kwargs)
        if init_params is None:
            init_params = init_params_
    elif init_params is None:
        # User needs to provide valid `init_params` if using `potential_fn`.
        raise ValueError(
            'Valid value of `init_params` must be provided with'
            ' `potential_fn`.')

    hmc_init, sample_fn = hmc(self.potential_fn, self.kinetic_fn, algo=self.algo)

    def hmc_init_fn(init_params, rng_key):
        return hmc_init(
            init_params,
            num_warmup=num_warmup,
            step_size=self.step_size,
            adapt_step_size=self.adapt_step_size,
            adapt_mass_matrix=self.adapt_mass_matrix,
            dense_mass=self.dense_mass,
            target_accept_prob=self.target_accept_prob,
            trajectory_length=self.trajectory_length,
            max_tree_depth=self.max_tree_depth,
            run_warmup=False,
            rng_key=rng_key,
        )

    if rng_key.ndim == 1:
        init_state = hmc_init_fn(init_params, rng_key)
        self._sample_fn = sample_fn
    else:
        # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
        # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
        # wa_steps because those variables do not depend on traced args: init_params, rng_key.
        init_state = vmap(hmc_init_fn)(init_params, rng_key)
        self._sample_fn = vmap(sample_fn)
    return init_state, constrain_fn
def test_construct():
    """Smoke test: build a bellini Story, convert it to a NumPyro model and run NUTS.

    NOTE(review): there is no assertion; the test only checks that the pipeline
    runs end to end, and prints the collected samples.
    """
    import bellini
    from bellini import Quantity, Species, Substance, Story
    import pint
    ureg = pint.UnitRegistry()
    s = Story()
    # NOTE(review): `water` is never used below.
    water = Species(name='water')
    # Two Normal quantities of water: loc 3.0 mol, scale 0.01 mol each.
    s.one_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="first_normal",
    )
    s.another_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="second_normal",
    )
    # Sum of the two quantities.
    s.combined_water = s.one_water_quantity + s.another_water_quantity
    # s.combined_water.observed = True
    s.combined_water.name = "combined_water"
    # Normal centred on the combined quantity, marked as observed.
    # NOTE(review): "nose" in the attribute name looks like a typo for "noise";
    # left as-is in case Story treats attribute names as significant.
    s.combined_water_with_nose = bellini.distributions.Normal(
        loc=s.combined_water,
        scale=Quantity(0.01, ureg.mole),
        name="combined_with_noise")
    s.combined_water_with_nose.observed = True
    # Convert the story's graph into a NumPyro model function.
    from bellini.api._numpyro import graph_to_numpyro_model
    model = graph_to_numpyro_model(s.g)
    from numpyro.infer.util import initialize_model
    import jax
    model_info = initialize_model(
        jax.random.PRNGKey(2666),
        model,
    )
    from numpyro.infer.hmc import hmc
    from numpyro.util import fori_collect
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, trajectory_length=10, num_warmup=300)
    # Draw 500 samples, mapping each state back to constrained space.
    samples = fori_collect(
        0, 500, sample_kernel, hmc_state,
        transform=lambda state: model_info.postprocess_fn(state.z))
    print(samples)
def test_improper_expand(event_shape):
    """Check that an ImproperUniform site keeps its batch and event shape through initialization.

    Inside a size-3 plate, ``log_prob`` must have shape ``(3,)``. The
    initialized ``incidence`` value must have shape ``(3,) + event_shape``.
    """
    def model():
        population = jnp.array([1000., 2000., 3000.])
        with numpyro.plate("region", 3):
            support = constraints.interval(0, population)
            d = dist.ImproperUniform(support=support, batch_shape=(3,), event_shape=event_shape)
            incidence = numpyro.sample("incidence", d)
            assert d.log_prob(incidence).shape == (3,)

    initial_values = initialize_model(random.PRNGKey(0), model).param_info.z
    assert initial_values['incidence'].shape == (3,) + event_shape
def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    """Return the initial params, deriving them and the potential fn from the model if set.

    Without a model, ``init_params`` is returned unchanged. With a model, the
    postprocess fn is stored on ``self``. ``self._potential_fn`` is set by
    calling the potential-fn generator with the model args (``None`` kwargs
    are treated as empty).
    """
    if self._model is None:
        return init_params

    param_info, potential_fn_gen, self._postprocess_fn, _ = initialize_model(
        rng_key,
        self._model,
        dynamic_args=True,
        init_strategy=self._init_strategy,
        model_args=model_args,
        model_kwargs=model_kwargs)
    kwargs = model_kwargs if model_kwargs is not None else {}
    self._potential_fn = potential_fn_gen(*model_args, **kwargs)
    return param_info[0]
def main(args):
    """Fit the stochastic-volatility model to S&P 500 returns with NUTS and print a summary.

    Uses the legacy ``initialize_model`` API, which takes the data
    positionally and returns ``(init_params, potential_fn, constrain_fn)``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_key, sample_key = random.split(random.PRNGKey(args.rng_seed))

    init_params, potential_fn, constrain_fn = initialize_model(init_key, model, returns)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    # `num_warmup` is handed to init_kernel, so collection starts at iteration 0
    # (presumably warmup has already been run by init_kernel).
    state = init_kernel(init_params, args.num_warmup, rng_key=sample_key)

    def to_constrained(hmc_state):
        return constrain_fn(hmc_state.z)

    posterior = fori_collect(0, args.num_samples, sample_kernel, state, transform=to_constrained)
    print_results(posterior, dates)
def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs=None):
    """Build the initial ``HState`` and the preconditioner.

    :param rng_key: PRNG key; ``ndim == 1`` for a single key, otherwise a batch
        of keys that are split with ``vmap``.
    :param num_warmup: accepted for interface compatibility; not used here.
    :param init_params: a ``ParamInfo`` (``z, pe, z_grad``) or raw params
        ``z``. If it is raw params, the potential energy and its gradient are
        computed from ``self._potential_fn``. Taken from the model when one
        is set.
    :param tuple model_args: positional arguments for the model.
    :param dict model_kwargs: keyword arguments for the model; ``None`` means none.
    :return: the initial ``HState``, placed with ``device_put``.
    :raises ValueError: if a potential fn is set but no ``init_params`` is available.
    """
    kwargs = {} if model_kwargs is None else model_kwargs

    # Split off a dedicated key for model initialization. The remaining
    # `rng_key` goes into the returned state, so it must not be the key that
    # `initialize_model` consumed.
    if rng_key.ndim == 1:
        # non-vectorized
        rng_key, rng_key_init_model = random.split(rng_key)
    else:
        # vectorized: split each key, then move the split axis to the front
        rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)

    # If supplied with a model, then there is a function to get most of the "stuff"
    if self._model is not None:
        init_params, model_potential_fn, postprocess_fn, _ = initialize_model(
            rng_key_init_model,
            self._model,
            dynamic_args=True,
            init_strategy=self._init_strategy,
            model_args=model_args,
            model_kwargs=kwargs)
        # use the keyword arguments for the model to build the potential function
        self._potential_fn = model_potential_fn(*model_args, **kwargs)
        self._postprocess_fn = postprocess_fn

    if self._potential_fn and init_params is None:
        raise ValueError(
            'Valid value of `init_params` must be provided with'
            ' `potential_fn`.')

    # init state
    if isinstance(init_params, ParamInfo):
        z, pe, z_grad = init_params
    else:
        z = init_params
        pe, z_grad = value_and_grad(self._potential_fn)(z)

    # init preconditioner
    self._preconditioner = Preconditioner(z, self._covar, self._covar_inv)
    self._dimension = self._preconditioner._dimension

    init_state = HState(0, z, pe, z_grad, 0, 0.0, rng_key)
    return device_put(init_state)
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once and flatten its initial latent values.

    Sets ``self._postprocess_fn``, ``self.prototype_trace``,
    ``self._init_latent`` (the raveled initial params),
    ``self._unpack_latent`` and ``self.latent_dim``.

    :raises RuntimeError: if the model has no latent variables.
    """
    setup_site = "_{}_rng_key_setup".format(self.prefix)
    rng_key = numpyro.sample(setup_site, dist.PRNGIdentity())
    with handlers.block():
        init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model(
            rng_key,
            self.model,
            init_strategy=self.init_strategy,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs)

    flat_init, unravel = ravel_pytree(init_params[0])
    self._init_latent = flat_init
    # this is to match the behavior of Pyro, where we can apply
    # unpack_latent for a batch of samples
    self._unpack_latent = UnpackTransform(unravel)
    self.latent_dim = jnp.size(flat_init)
    if self.latent_dim == 0:
        message = '{} found no latent variables; Use an empty guide instead'
        raise RuntimeError(message.format(type(self).__name__))
def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    """Build the SA init function, set the sample fn if unset, and return the initial params.

    ``self._init_fn`` is always rebuilt. ``self._sample_fn`` and (when a model
    is set) ``self._postprocess_fn`` are only filled in if still ``None``.
    """
    if self._model is None:
        self._init_fn, sample_fn = _sa(potential_fn=self._potential_fn)
    else:
        param_info, potential_fn_gen, postprocess_fn, _ = initialize_model(
            rng_key,
            self._model,
            dynamic_args=True,
            model_args=model_args,
            model_kwargs=model_kwargs)
        init_params = param_info[0]
        # NB: init args is different from HMC
        self._init_fn, sample_fn = _sa(potential_fn_gen=potential_fn_gen)
        if self._postprocess_fn is None:
            self._postprocess_fn = postprocess_fn

    if self._sample_fn is None:
        self._sample_fn = sample_fn
    return init_params
theta2 = jnp.exp(alpha + attack[away_id] - defend[home_id]) with numpyro.plate("data", len(home_id)): numpyro.sample("s1", dist.Poisson(theta1), obs=score1_obs) numpyro.sample("s2", dist.Poisson(theta2), obs=score2_obs) rng_key = random.PRNGKey(2) # translate the model into a log-probability function init_params, potential_fn_gen, *_ = initialize_model( rng_key, model, model_args=( train["Home_id"].values, train["Away_id"].values, train["score1"].values, train["score2"].values, ), dynamic_args=True, ) # logprob = lambda position: -potential_fn_gen( # train["Home_id"].values, # train["Away_id"].values, # train["score1"].values, # train["score2"].values, # )(position) # initial_position = init_params.z # initial_state = nuts.new_state(initial_position, logprob)
def main(args):
    """Compare vanilla NUTS with NeuTra-reparameterized NUTS on the dual-moon target.

    1. Run NUTS directly on ``dual_moon_model``.
    2. Train an ``AutoBNAFNormal`` guide with SVI.
    3. Run NUTS on the potential warped by the guide's transform and map the
       samples back.
    4. Plot the losses, densities and samples to ``neutra.pdf``.
    """
    progress_bar = "NUMPYRO_SPHINXBUILD" not in os.environ

    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples, progress_bar=progress_bar)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(dual_moon_model,
                           hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state,
                                  np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))['x'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy, potential_fn, transform)

    def transformed_constrain_fn(x):
        # Push a warped-space point through the guide transform, then constrain it.
        return constrain_fn(transform(x))

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples, progress_bar=progress_bar)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(4), (1000,))
    guide_trans_samples = vmap(transformed_constrain_fn)(guide_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    # NOTE: seaborn >= 0.12 accepts x/y data only as keywords (positional data
    # vectors raise TypeError), and `levels` replaces the deprecated `n_levels`.
    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(x=guide_samples[:, 0], y=guide_samples[:, 1], levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(x=guide_base_samples[:, 0], y=guide_base_samples[:, 1], ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(x=vanilla_samples[:, 0], y=vanilla_samples[:, 1], levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(x=zs[:, 0], y=zs[:, 1], ax=ax5, hue=samples[:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5],
            xlabel='x0', ylabel='x1', title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(x=samples[:, 0], y=samples[:, 1], levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")