Example #1
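Batched initialization for a change-point model: `initialize_model` is called with two stacked PRNG keys, the result is compared against per-key initialization, and for a `partial(init_to_value, ...)` strategy the unconstrained `tau` is checked against the inverse of the unit-interval bijection.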
def test_initialize_model_change_point(init_strategy):
    def model(data):
        alpha = 1 / jnp.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = jnp.array([
        13,  24,   8,  24,   7,  35,  14,  11,  15,  11,  22,  22,  11,  57,
        11,  19,  29,   6,  19,  12,  22,  12,  18,  72,  32,   9,   7,  13,
        19,  23,  27,  20,   6,  17,  13,  10,  14,   6,  16,  15,   7,   2,
        15,  15,  19,  70,  49,   7,  53,  22,  21,  31,  19,  11,  18,  20,
        12,  35,  17,  23,  17,   4,   2,  31,  30,  13,  27,   0,  39,  37,
        5,  14,  13,  22,
    ])

    rng_keys = random.split(random.PRNGKey(1), 2)
    init_params, _, _, _ = initialize_model(rng_keys, model,
                                            init_strategy=init_strategy,
                                            model_args=(count_data,))
    if isinstance(init_strategy, partial) and init_strategy.func is init_to_value:
        expected = biject_to(constraints.unit_interval).inv(init_strategy.keywords.get('values')['tau'])
        assert_allclose(init_params[0]['tau'], jnp.repeat(expected, 2))
    for i in range(2):
        init_params_i, _, _, _ = initialize_model(rng_keys[i], model,
                                                  init_strategy=init_strategy,
                                                  model_args=(count_data,))
        for name, p in init_params[0].items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(p[i], init_params_i[0][name], atol=1e-6)
Example #2
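The same change-point test as Example #1, but apparently against an older NumPyro API: `initialize_model` returns three values, model arguments are passed positionally, and the initial parameters are a plain dict rather than a `ParamInfo` tuple.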
def test_initialize_model_change_point(init_strategy):
    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        lambda12 = np.where(np.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13,  24,   8,  24,   7,  35,  14,  11,  15,  11,  22,  22,  11,  57,
        11,  19,  29,   6,  19,  12,  22,  12,  18,  72,  32,   9,   7,  13,
        19,  23,  27,  20,   6,  17,  13,  10,  14,   6,  16,  15,   7,   2,
        15,  15,  19,  70,  49,   7,  53,  22,  21,  31,  19,  11,  18,  20,
        12,  35,  17,  23,  17,   4,   2,  31,  30,  13,  27,   0,  39,  37,
        5,  14,  13,  22,
    ])

    rng_keys = random.split(random.PRNGKey(1), 2)
    init_params, _, _ = initialize_model(rng_keys, model, count_data,
                                         init_strategy=init_strategy)
    for i in range(2):
        init_params_i, _, _ = initialize_model(rng_keys[i], model, count_data,
                                               init_strategy=init_strategy)
        for name, p in init_params.items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(p[i], init_params_i[name], atol=1e-6)
Example #3
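Verifies NeuTra reparameterization: the potential energy of the reparameterized model should equal the original potential at the transformed point minus the log-determinant of the transform's Jacobian.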
def test_reparam_log_joint(model, kwargs):
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-10), Trace_ELBO(), **kwargs)
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    reparam_model = neutra.reparam(model)
    _, pe_fn, _, _ = initialize_model(random.PRNGKey(1), model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(random.PRNGKey(2), reparam_model, model_kwargs=kwargs)
    latent_x = list(init_params[0].values())[0]
    pe_transformed = pe_fn_neutra(init_params[0])
    latent_y = neutra.transform(latent_x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn(guide._unpack_latent(latent_y))
    assert_allclose(pe_transformed, pe - log_det_jacobian)
Example #4
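With `dynamic_args=True`, `initialize_model` returns a potential-function generator, so the jitted sample kernel can be reused across datasets; the test warms up and samples on `y1`, then re-initializes on `y2` without retracing.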
def test_reuse_mcmc_pe_gen():
    y1 = onp.random.normal(3, 0.1, (100, ))
    y2 = onp.random.normal(-3, 0.1, (100, ))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(0), model, y1, dynamic_args=True)
    init_kernel, sample_kernel = hmc(potential_fn_gen=potential_fn)
    init_state = init_kernel(init_params, num_warmup=300, model_args=(y1, ))

    @jit
    def _sample(state_and_args):
        hmc_state, model_args = state_and_args
        return sample_kernel(hmc_state, (model_args, )), model_args

    samples = fori_collect(0,
                           500,
                           _sample, (init_state, y1),
                           transform=lambda state: constrain_fn(y1)(state[0].z))
    assert_allclose(samples['mu'].mean(), 3., atol=0.1)

    # Re-run on the second dataset, re-using the jitted `_sample` kernel - this should be much faster.
    init_state = init_kernel(init_params, num_warmup=300, model_args=(y2, ))
    samples = fori_collect(0,
                           500,
                           _sample, (init_state, y2),
                           transform=lambda state: constrain_fn(y2)(state[0].z))
    assert_allclose(samples['mu'].mean(), -3., atol=0.1)
Example #5
File: hmc.py  Project: gully/numpyro
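`HMC._init_state`: when a model is supplied, `initialize_model` (with `dynamic_args=True`) provides the potential-function generator and postprocess function, a warning flags `param` sites, and the functional `hmc` init/sample pair is built lazily.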
    def _init_state(self, rng_key, model_args, model_kwargs, init_params):
        if self._model is not None:
            init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
                rng_key,
                self._model,
                dynamic_args=True,
                init_strategy=self._init_strategy,
                model_args=model_args,
                model_kwargs=model_kwargs)
            if any(v['type'] == 'param' for v in model_trace.values()):
                warnings.warn("'param' sites will be treated as constants during inference. To define "
                              "an improper variable, please use a 'sample' site with log probability "
                              "masked out. For example, `sample('x', dist.LogNormal(0, 1).mask(False)` "
                              "means that `x` has improper distribution over the positive domain.")
            if self._init_fn is None:
                self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn,
                                                     kinetic_fn=self._kinetic_fn,
                                                     algo=self._algo)
            self._postprocess_fn = postprocess_fn
        elif self._init_fn is None:
            self._init_fn, self._sample_fn = hmc(potential_fn=self._potential_fn,
                                                 kinetic_fn=self._kinetic_fn,
                                                 algo=self._algo)

        return init_params
Example #6
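A TensorFlow Probability (JAX substrate) kernel adapter: the potential function from `initialize_model` is turned into a log-probability function via `ravel_pytree`, and `Uncalibrated*` kernels are wrapped in `tfp.mcmc.MetropolisHastings`.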
 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
     if self._model is not None:
         init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
             rng_key,
             self._model,
             init_strategy=self._init_strategy,
             dynamic_args=True,
             model_args=model_args,
             model_kwargs=model_kwargs)
         init_params = init_params.z
         if self._init_fn is None:
             _, unravel_fn = ravel_pytree(init_params)
             kernel = self.kernel_class(
                 _make_log_prob_fn(
                     potential_fn(*model_args, **model_kwargs), unravel_fn),
                 **self._kernel_kwargs)
             # Uncalibrated* kernels have to be used inside MetropolisHastings; see
             # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/mcmc/UncalibratedLangevin
             if self.kernel_class.__name__.startswith("Uncalibrated"):
                 kernel = tfp.mcmc.MetropolisHastings(kernel)
             self._init_fn, self._sample_fn = _extract_kernel_functions(
                 kernel)
         self._postprocess_fn = postprocess_fn
     elif self._init_fn is None:
         _, unravel_fn = ravel_pytree(init_params)
         kernel = self.kernel_class(
             _make_log_prob_fn(self._potential_fn, unravel_fn),
             **self._kernel_kwargs)
         if self.kernel_class.__name__.startswith("Uncalibrated"):
             kernel = tfp.mcmc.MetropolisHastings(kernel)
         self._init_fn, self._sample_fn = _extract_kernel_functions(kernel)
     return init_params
Example #7
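A variant of `HMC._init_state` that additionally forwards `forward_mode_differentiation` to `initialize_model` and stores the potential-function generator on the instance.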
    def _init_state(self, rng_key, model_args, model_kwargs, init_params):
        if self._model is not None:
            init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
                rng_key,
                self._model,
                dynamic_args=True,
                init_strategy=self._init_strategy,
                model_args=model_args,
                model_kwargs=model_kwargs,
                forward_mode_differentiation=self._forward_mode_differentiation,
            )
            if self._init_fn is None:
                self._init_fn, self._sample_fn = hmc(
                    potential_fn_gen=potential_fn,
                    kinetic_fn=self._kinetic_fn,
                    algo=self._algo,
                )
            self._potential_fn_gen = potential_fn
            self._postprocess_fn = postprocess_fn
        elif self._init_fn is None:
            self._init_fn, self._sample_fn = hmc(
                potential_fn=self._potential_fn,
                kinetic_fn=self._kinetic_fn,
                algo=self._algo,
            )

        return init_params
Example #8
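An autoguide `_setup_prototype`: `initialize_model` runs under `handlers.block()` to produce initial locations, the postprocess function, and a prototype trace, from which plate frames and full plate sizes are collected.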
    def _setup_prototype(self, *args, **kwargs):
        rng_key = numpyro.prng_key()
        with handlers.block():
            (
                init_params,
                _,
                self._postprocess_fn,
                self.prototype_trace,
            ) = initialize_model(
                rng_key,
                self.model,
                init_strategy=self.init_loc_fn,
                dynamic_args=False,
                model_args=args,
                model_kwargs=kwargs,
            )
        self._init_locs = init_params[0]

        self._prototype_frames = {}
        self._prototype_frame_full_sizes = {}
        for name, site in self.prototype_trace.items():
            if site["type"] == "sample":
                for frame in site["cond_indep_stack"]:
                    self._prototype_frames[frame.name] = frame
            elif site["type"] == "plate":
                self._prototype_frame_full_sizes[name] = site["args"][0]
Example #9
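The stochastic volatility example on S&P 500 returns: `initialize_model` returns a namedtuple whose `param_info`, `potential_fn`, and `postprocess_fn` fields feed the functional `hmc` interface, and the sampled volatility paths are plotted over the return series.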
def main(args):
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup, rng_key=sample_rng_key)
    hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
                              progbar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    legend.legendHandles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns', title='Volatility of S&P500 over time')

    plt.savefig("stochastic_volatility_plot.pdf")
Example #10
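A functional HMC/NUTS test on a beta-Bernoulli model; when x64 mode is enabled, the collected samples are expected to have dtype `float64`.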
def test_functional_beta_bernoulli_x64(algo):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, model_args=(data, ))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params,
                            trajectory_length=1.,
                            num_warmup=warmup_steps)
    samples = fori_collect(0,
                           num_samples,
                           sample_kernel,
                           hmc_state,
                           transform=lambda x: constrain_fn(x.z))
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
Example #11
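The Dirichlet-categorical counterpart of the batched-versus-per-key initialization test, again written against the older three-value API.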
def test_initialize_model_dirichlet_categorical(init_strategy):
    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))

    rng_keys = random.split(random.PRNGKey(1), 2)
    init_params, _, _ = initialize_model(rng_keys, model, data,
                                         init_strategy=init_strategy)
    for i in range(2):
        init_params_i, _, _ = initialize_model(rng_keys[i], model, data,
                                               init_strategy=init_strategy)
        for name, p in init_params.items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(p[i], init_params_i[name], atol=1e-6)
Example #12
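An older `HMC.init`: the PRNG key is split (via `vmap` when batched), the model is initialized if provided, and for batched keys the HMC init and sample functions are vmapped to run vectorized chains.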
 def init(self,
          rng_key,
          num_warmup,
          init_params=None,
          model_args=(),
          model_kwargs={}):
     constrain_fn = None
     if self.model is not None:
         if rng_key.ndim == 1:
             rng_key, rng_key_init_model = random.split(rng_key)
         else:
             rng_key, rng_key_init_model = np.swapaxes(
                 vmap(random.split)(rng_key), 0, 1)
         init_params_, self.potential_fn, constrain_fn = initialize_model(
             rng_key_init_model,
             self.model,
             *model_args,
             init_strategy=self.init_strategy,
             **model_kwargs)
         if init_params is None:
             init_params = init_params_
     else:
         # User needs to provide valid `init_params` if using `potential_fn`.
         if init_params is None:
             raise ValueError(
                 'Valid value of `init_params` must be provided with'
                 ' `potential_fn`.')
     hmc_init, sample_fn = hmc(self.potential_fn,
                               self.kinetic_fn,
                               algo=self.algo)
     hmc_init_fn = lambda init_params, rng_key: hmc_init(  # noqa: E731
         init_params,
         num_warmup=num_warmup,
         step_size=self.step_size,
         adapt_step_size=self.adapt_step_size,
         adapt_mass_matrix=self.adapt_mass_matrix,
         dense_mass=self.dense_mass,
         target_accept_prob=self.target_accept_prob,
         trajectory_length=self.trajectory_length,
         max_tree_depth=self.max_tree_depth,
         run_warmup=False,
         rng_key=rng_key,
     )
     if rng_key.ndim == 1:
         init_state = hmc_init_fn(init_params, rng_key)
         self._sample_fn = sample_fn
     else:
         # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
         # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
         # wa_steps because those variables do not depend on traced args: init_params, rng_key.
         init_state = vmap(hmc_init_fn)(init_params, rng_key)
         self._sample_fn = vmap(sample_fn)
     return init_state, constrain_fn
Example #13
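A bellini integration test: a `Story` graph is compiled to a NumPyro model with `graph_to_numpyro_model`, then sampled through the functional `hmc` interface using `model_info.potential_fn` and `model_info.param_info`.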
def test_construct():
    import bellini
    from bellini import Quantity, Species, Substance, Story
    import pint
    ureg = pint.UnitRegistry()

    s = Story()
    water = Species(name='water')
    s.one_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="first_normal",
    )
    s.another_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="second_normal",
    )
    s.combined_water = s.one_water_quantity + s.another_water_quantity
    # s.combined_water.observed = True
    s.combined_water.name = "combined_water"

    s.combined_water_with_nose = bellini.distributions.Normal(
        loc=s.combined_water,
        scale=Quantity(0.01, ureg.mole),
        name="combined_with_noise")

    s.combined_water_with_nose.observed = True

    from bellini.api._numpyro import graph_to_numpyro_model
    model = graph_to_numpyro_model(s.g)

    from numpyro.infer.util import initialize_model
    import jax
    model_info = initialize_model(
        jax.random.PRNGKey(2666),
        model,
    )
    from numpyro.infer.hmc import hmc
    from numpyro.util import fori_collect

    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info,
                            trajectory_length=10,
                            num_warmup=300)
    samples = fori_collect(
        0,
        500,
        sample_kernel,
        hmc_state,
        transform=lambda state: model_info.postprocess_fn(state.z))

    print(samples)
Example #14
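Checks that an `ImproperUniform` sample site inside a plate yields initial values of shape `(3,) + event_shape` in `model_info.param_info.z`.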
def test_improper_expand(event_shape):
    def model():
        population = jnp.array([1000., 2000., 3000.])
        with numpyro.plate("region", 3):
            d = dist.ImproperUniform(support=constraints.interval(
                0, population),
                                     batch_shape=(3, ),
                                     event_shape=event_shape)
            incidence = numpyro.sample("incidence", d)
            assert d.log_prob(incidence).shape == (3, )

    model_info = initialize_model(random.PRNGKey(0), model)
    assert model_info.param_info.z['incidence'].shape == (3, ) + event_shape
Example #15
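An `_init_state` that immediately specializes the potential-function generator with the current model args and kwargs, keeping only the initial `z` values from the `ParamInfo` tuple.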
 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
     if self._model is not None:
         params_info, potential_fn_gen, self._postprocess_fn, model_trace = initialize_model(
             rng_key,
             self._model,
             dynamic_args=True,
             init_strategy=self._init_strategy,
             model_args=model_args,
             model_kwargs=model_kwargs)
         init_params = params_info[0]
         model_kwargs = {} if model_kwargs is None else model_kwargs
         self._potential_fn = potential_fn_gen(*model_args, **model_kwargs)
     return init_params
Example #16
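An earlier version of the stochastic volatility example, written against the three-value API with model arguments passed positionally.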
def main(args):
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    init_params, potential_fn, constrain_fn = initialize_model(
        init_rng_key, model, returns)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    hmc_state = init_kernel(init_params,
                            args.num_warmup,
                            rng_key=sample_rng_key)
    hmc_states = fori_collect(
        0,
        args.num_samples,
        sample_kernel,
        hmc_state,
        transform=lambda hmc_state: constrain_fn(hmc_state.z))
    print_results(hmc_states, dates)
Example #17
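An `init` method for a custom preconditioned sampler: the model (if any) supplies the potential function, `value_and_grad` evaluates the initial potential energy and gradient, and a `Preconditioner` is set up before constructing the initial `HState`.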
    def init(self,
             rng_key,
             num_warmup,
             init_params=None,
             model_args=(),
             model_kwargs={}):
        # non-vectorized
        if rng_key.ndim == 1:
            rng_key, rng_key_init_model = random.split(rng_key)
        # vectorized
        else:
            rng_key, rng_key_init_model = jnp.swapaxes(
                vmap(random.split)(rng_key), 0, 1)

        # If supplied with a model, then there is a function to get most of the "stuff"
        if self._model is not None:
            init_params, model_potential_fn, postprocess_fn, model_trace = initialize_model(
                rng_key,
                self._model,
                dynamic_args=True,
                init_strategy=self._init_strategy,
                model_args=model_args,
                model_kwargs=model_kwargs)
            # use the keyword arguments for the model to build the potential function
            kwargs = {} if model_kwargs is None else model_kwargs
            self._potential_fn = model_potential_fn(*model_args, **kwargs)
            self._postprocess_fn = postprocess_fn

        if self._potential_fn and init_params is None:
            raise ValueError(
                'Valid value of `init_params` must be provided with'
                ' `potential_fn`.')

        # init state
        if isinstance(init_params, ParamInfo):
            z, pe, z_grad = init_params
        else:
            z, pe, z_grad = init_params, None, None
        pe, z_grad = value_and_grad(self._potential_fn)(z)

        # init preconditioner
        self._preconditioner = Preconditioner(z, self._covar, self._covar_inv)
        self._dimension = self._preconditioner._dimension

        init_state = HState(0, z, pe, z_grad, 0, 0.0, rng_key)
        return device_put(init_state)
Example #18
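An autoguide `_setup_prototype` from an older release: the setup key is drawn from a `PRNGIdentity` sample site, and the initial latents are flattened with `ravel_pytree` into an `UnpackTransform` for batched unpacking.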
    def _setup_prototype(self, *args, **kwargs):
        rng_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix), dist.PRNGIdentity())
        with handlers.block():
            init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model(
                rng_key, self.model,
                init_strategy=self.init_strategy,
                dynamic_args=False,
                model_args=args,
                model_kwargs=kwargs)

        self._init_latent, unpack_latent = ravel_pytree(init_params[0])
        # this is to match the behavior of Pyro, where we can apply
        # unpack_latent for a batch of samples
        self._unpack_latent = UnpackTransform(unpack_latent)
        self.latent_dim = jnp.size(self._init_latent)
        if self.latent_dim == 0:
            raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                               .format(type(self).__name__))
Example #19
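`SA._init_state`: analogous to the HMC variant, but delegating to `_sa`, whose init arguments differ from HMC's (as the inline comment notes).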
    def _init_state(self, rng_key, model_args, model_kwargs, init_params):
        if self._model is not None:
            init_params, potential_fn, postprocess_fn, _ = initialize_model(
                rng_key,
                self._model,
                dynamic_args=True,
                model_args=model_args,
                model_kwargs=model_kwargs)
            init_params = init_params[0]
            # NB: init args is different from HMC
            self._init_fn, sample_fn = _sa(potential_fn_gen=potential_fn)
            if self._postprocess_fn is None:
                self._postprocess_fn = postprocess_fn
        else:
            self._init_fn, sample_fn = _sa(potential_fn=self._potential_fn)

        if self._sample_fn is None:
            self._sample_fn = sample_fn
        return init_params
Example #20
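An excerpt from a football-scores model (the definitions of `alpha`, `attack`, `defend`, and the data frames are not shown): with `dynamic_args=True`, `initialize_model` yields a potential-function generator, and the commented-out lines sketch handing it to an external NUTS implementation.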
    theta2 = jnp.exp(alpha + attack[away_id] - defend[home_id])

    with numpyro.plate("data", len(home_id)):
        numpyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        numpyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)


rng_key = random.PRNGKey(2)

# translate the model into a log-probability function
init_params, potential_fn_gen, *_ = initialize_model(
    rng_key,
    model,
    model_args=(
        train["Home_id"].values,
        train["Away_id"].values,
        train["score1"].values,
        train["score2"].values,
    ),
    dynamic_args=True,
)

# logprob = lambda position: -potential_fn_gen(
#     train["Home_id"].values,
#     train["Away_id"].values,
#     train["score1"].values,
#     train["score2"].values,
# )(position)

# initial_position = init_params.z
# initial_state = nuts.new_state(initial_position, logprob)
Example #21
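The NeuTra BNAF example: a vanilla NUTS run on `dual_moon_model`, an `AutoBNAFNormal` guide trained with SVI, a second NUTS run on the potential transformed through the trained guide, and a grid of diagnostic plots comparing the samplers.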
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(
        dual_moon_model,
        hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(0), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2),
                                                     dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy,
                                       potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))  # noqa: E731

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(np.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = vmap(transformed_constrain_fn)(
        guide_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)'
    )

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
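
For reference, here is a minimal, self-contained sketch of the current-style call, assuming a recent NumPyro in which `initialize_model` returns a `ModelInfo` namedtuple with `param_info`, `potential_fn`, `postprocess_fn`, and `model_trace` fields (the convention visible in Examples #9, #13, and #14). The toy model and data below are illustrative only.

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import initialize_model


def model(data):
    # A toy model: a single latent location with a Normal likelihood.
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=data)


data = jnp.array([0.5, 1.2, -0.3])
model_info = initialize_model(random.PRNGKey(0), model, model_args=(data,))

# param_info.z maps site names to initial values in unconstrained space;
# potential_fn evaluates the negative log joint density at such a point.
z = model_info.param_info.z
print(z)
print(model_info.potential_fn(z))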