Example 1
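These snippets are drawn from NumPyro's tests and examples across several releases, so module paths and argument names vary between them. A representative import block for the functional HMC examples is sketched below (older releases exposed `hmc` and `initialize_model` under `numpyro.mcmc` and `numpyro.hmc_util`, newer ones under `numpyro.infer`; `np` denotes `jax.numpy` in the older snippets):

import os
import numpy as onp                  # classic NumPy, used for test data
import jax.numpy as np               # aliased `jnp` in the newer snippets
from jax import jit, random, vmap
from numpy.testing import assert_allclose
import numpyro
import numpyro.distributions as dist
from numpyro.infer.hmc import hmc
from numpyro.infer.util import initialize_model
from numpyro.util import fori_collect

This first example calls `initialize_model` with `dynamic_args=True`, so the returned `potential_fn` and `constrain_fn` are generators: applied to the model arguments, they yield the actual potential and postprocessing functions, which is what lets the compiled kernels be reused on a second dataset below.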
def test_reuse_mcmc_pe_gen():
    y1 = onp.random.normal(3, 0.1, (100, ))
    y2 = onp.random.normal(-3, 0.1, (100, ))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(0), model, y1, dynamic_args=True)
    init_kernel, sample_kernel = hmc(potential_fn_gen=potential_fn)
    init_state = init_kernel(init_params, num_warmup=300, model_args=(y1, ))

    @jit
    def _sample(state_and_args):
        hmc_state, model_args = state_and_args
        return sample_kernel(hmc_state, (model_args, )), model_args

    samples = fori_collect(0, 500, _sample, (init_state, y1),
                           transform=lambda state: constrain_fn(y1)(state[0].z))
    assert_allclose(samples['mu'].mean(), 3., atol=0.1)

    # Run on the second dataset, re-using the compiled kernels - this should be much faster.
    init_state = init_kernel(init_params, num_warmup=300, model_args=(y2, ))
    samples = fori_collect(0, 500, _sample, (init_state, y2),
                           transform=lambda state: constrain_fn(y2)(state[0].z))
    assert_allclose(samples['mu'].mean(), -3., atol=0.1)
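The same pattern extends to any further dataset without retracing; e.g., appending these lines inside the test above (y3 is a hypothetical third sample):

    y3 = onp.random.normal(0., 0.1, (100,))
    init_state = init_kernel(init_params, num_warmup=300, model_args=(y3,))
    samples = fori_collect(0, 500, _sample, (init_state, y3),
                           transform=lambda state: constrain_fn(y3)(state[0].z))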
Example 2
def test_functional_beta_bernoulli_x64(algo):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, model_args=(data, ))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params,
                            trajectory_length=1.,
                            num_warmup=warmup_steps)
    samples = fori_collect(0,
                           num_samples,
                           sample_kernel,
                           hmc_state,
                           transform=lambda x: constrain_fn(x.z))
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
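The dtype assertion above only fires when 64-bit mode was requested. One common way to enable it (an assumption about the surrounding test harness, not shown in the excerpt):

import os
os.environ['JAX_ENABLE_X64'] = '1'   # must be set before JAX is first imported
# or, once JAX is already importable:
from jax import config
config.update('jax_enable_x64', True)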
Example 3
    def run(
        self,
        rng_key,
        num_steps,
        *args,
        progress_bar=True,
        init_state=None,
        collect_fn=lambda val: val[1],  # TODO: refactor
        **kwargs,
    ):
        def bodyfn(_i, info):
            body_state = info[0]
            return (*self.update(body_state, *info[2:], **kwargs), *info[2:])

        if init_state is None:
            state = self.init(rng_key, *args, **kwargs)
        else:
            state = init_state
        loss = self.evaluate(state, *args, **kwargs)
        auxiliaries, last_res = fori_collect(
            0,
            num_steps,
            lambda info: bodyfn(0, info),
            (state, loss, *args),
            progbar=progress_bar,
            transform=collect_fn,
            return_last_val=True,
        )
        state = last_res[0]
        return SteinVIRunResult(self.get_params(state), state, auxiliaries)
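For reference, `fori_collect` with `return_last_val=True` returns a `(collected, last_value)` pair, which is what `run` unpacks above. A minimal sketch of that contract:

from numpyro.util import fori_collect
import jax.numpy as jnp

trace, last = fori_collect(0, 3, lambda x: x + 1.0, jnp.zeros(()),
                           return_last_val=True, progbar=False)
# trace == jnp.array([1., 2., 3.]) and last == jnp.array(3.)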
Example 4
def test_dirichlet_categorical(algo, dense_mass):
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = sample('p_latent', dist.Dirichlet(concentration))
        sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params,
                            trajectory_length=1.,
                            num_warmup=warmup_steps,
                            progbar=False,
                            dense_mass=dense_mass)
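    # In this older NumPyro revision, fori_collect took a single iteration count
    # (equivalent to lower=0, upper=num_samples in later versions).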
    hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
                              transform=lambda x: constrain_fn(x.z),
                              progbar=False)
    assert_allclose(np.mean(hmc_states['p_latent'], 0), true_probs, atol=0.02)

    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states['p_latent'].dtype == np.float64
Example 5
 def single_chain_mcmc(rng, init_params):
     hmc_state = init_kernel(init_params, num_warmup, run_warmup=False, rng=rng,
                             **sampler_kwargs)
     samples = fori_collect(num_warmup, num_warmup + num_samples, sample_kernel, hmc_state,
                            transform=lambda x: constrain_fn(x.z),
                            progbar=progbar)
     return samples
Example 6
def test_binomial_stable(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs',
                           dist.Binomial(data['n'], logits=logits),
                           obs=data['x'])
        else:
            numpyro.sample('obs',
                           dist.Binomial(data['n'], probs=p),
                           obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn)
    hmc_state = init_kernel(init_params, num_warmup=warmup_steps)
    samples = fori_collect(0,
                           num_samples,
                           sample_kernel,
                           hmc_state,
                           transform=lambda x: constrain_fn(x.z))

    assert_allclose(np.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p'].dtype == np.float64
Example 7
 def _single_chain_mcmc(self,
                        init,
                        collect_fields=('z', ),
                        collect_warmup=False,
                        args=(),
                        kwargs={}):
     rng_key, init_params = init
     init_state, constrain_fn = self.sampler.init(rng_key,
                                                  self.num_warmup,
                                                  init_params,
                                                  model_args=args,
                                                  model_kwargs=kwargs)
     if self.constrain_fn is None:
         constrain_fn = identity if constrain_fn is None else constrain_fn
     else:
         constrain_fn = self.constrain_fn
     collect_fn = attrgetter(*collect_fields)
     lower = 0 if collect_warmup else self.num_warmup
     states = fori_collect(
         lower,
         self.num_warmup + self.num_samples,
         self.sampler.sample,
         init_state,
         transform=collect_fn,
         progbar=self.progress_bar,
         progbar_desc=functools.partial(get_progbar_desc_str,
                                        self.num_warmup),
         diagnostics_fn=get_diagnostics_str if rng_key.ndim == 1 else None)
     if len(collect_fields) == 1:
         states = (states, )
     states = dict(zip(collect_fields, states))
     states['z'] = vmap(constrain_fn)(
         states['z']) if len(tree_flatten(states)[0]) > 0 else states['z']
     return states
Example 8
def main(args):
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
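    # initialize_model here returns a ModelInfo namedtuple; the fields used below
    # are param_info, potential_fn and postprocess_fn.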
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup, rng_key=sample_rng_key)
    hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
                              progbar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    legend.legendHandles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns', title='Volatility of S&P500 over time')

    plt.savefig("stochastic_volatility_plot.pdf")
Example 9
def test_change_point():
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = sample('lambda1', dist.Exponential(alpha))
        lambda2 = sample('lambda2', dist.Exponential(alpha))
        tau = sample('tau', dist.Uniform(0, 1))
        lambda12 = np.where(np.arange(len(data)) < tau * len(data), lambda1, lambda2)
        sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13,  24,   8,  24,   7,  35,  14,  11,  15,  11,  22,  22,  11,  57,
        11,  19,  29,   6,  19,  12,  22,  12,  18,  72,  32,   9,   7,  13,
        19,  23,  27,  20,   6,  17,  13,  10,  14,   6,  16,  15,   7,   2,
        15,  15,  19,  70,  49,   7,  53,  22,  21,  31,  19,  11,  18,  20,
        12,  35,  17,  23,  17,   4,   2,  31,  30,  13,  27,   0,  39,  37,
        5,  14,  13,  22,
    ])
    init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(4), model, count_data)
    init_kernel, sample_kernel = hmc(potential_fn)
    hmc_state = init_kernel(init_params, num_warmup=warmup_steps)
    samples = fori_collect(num_samples, sample_kernel, hmc_state,
                           transform=lambda x: constrain_fn(x.z))
    tau_posterior = (samples['tau'] * len(count_data)).astype("int")
    tau_values, counts = onp.unique(tau_posterior, return_counts=True)
    mode_ind = np.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['lambda1'].dtype == np.float64
        assert samples['lambda2'].dtype == np.float64
        assert samples['tau'].dtype == np.float64
Example 10
def test_beta_bernoulli(algo):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = sample('p_latent', dist.Beta(alpha, beta))
        sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params,
                            trajectory_length=1.,
                            num_warmup=warmup_steps,
                            progbar=False)
    hmc_states = fori_collect(num_samples,
                              sample_kernel,
                              hmc_state,
                              transform=lambda x: constrain_fn(x.z),
                              progbar=False)
    assert_allclose(np.mean(hmc_states['p_latent'], 0), true_probs, atol=0.05)
Example 11
 def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
     rng_key, init_state, init_params = init
     if init_state is None:
         init_state = self.sampler.init(
             rng_key,
             self.num_warmup,
             init_params,
             model_args=args,
             model_kwargs=kwargs,
         )
     sample_fn, postprocess_fn = self._get_cached_fns()
     diagnostics = (
         lambda x: self.sampler.get_diagnostics_str(x[0])
         if rng_key.ndim == 1
         else ""
     )  # noqa: E731
     init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
     lower_idx = self._collection_params["lower"]
     upper_idx = self._collection_params["upper"]
     phase = self._collection_params["phase"]
     collection_size = self._collection_params["collection_size"]
     collection_size = (
         collection_size
         if collection_size is None
         else collection_size // self.thinning
     )
     collect_vals = fori_collect(
         lower_idx,
         upper_idx,
         sample_fn,
         init_val,
         transform=_collect_fn(collect_fields),
         progbar=self.progress_bar,
         return_last_val=True,
         thinning=self.thinning,
         collection_size=collection_size,
         progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
         diagnostics_fn=diagnostics,
         num_chains=self.num_chains if self.chain_method == "parallel" else 1,
     )
     states, last_val = collect_vals
     # Get first argument of type `HMCState`
     last_state = last_val[0]
     if len(collect_fields) == 1:
         states = (states,)
     states = dict(zip(collect_fields, states))
     # Apply constraints if number of samples is non-zero
     site_values = tree_flatten(states[self._sample_field])[0]
     # XXX: lax.map still works if some arrays have 0 size
     # so we only need to filter out the case site_value.shape[0] == 0
     # (which happens when lower_idx==upper_idx)
     if len(site_values) > 0 and jnp.shape(site_values[0])[0] > 0:
         if self._jit_model_args:
             states[self._sample_field] = postprocess_fn(
                 states[self._sample_field], args, kwargs
             )
         else:
             states[self._sample_field] = postprocess_fn(states[self._sample_field])
     return states, last_state
Example 12
def test_fori_collect():
    def f(x):
        return {"i": x["i"] + x["j"], "j": x["i"] - x["j"]}

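    # fori_collect runs `upper` iterations in total and collects those with
    # index >= lower, so the (1, 3) call below keeps the 2nd and 3rd iterates of f.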
    a = {"i": jnp.array([0.0]), "j": jnp.array([1.0])}
    expected_tree = {"i": jnp.array([[0.0], [2.0]])}
    actual_tree = fori_collect(1, 3, f, a, transform=lambda a: {"i": a["i"]})
    check_eq(actual_tree, expected_tree)
Example 13
def run_inference(dept, male, applications, admit, rng, args):
    init_params, potential_fn, constrain_fn = initialize_model(
        rng, glmm, dept, male, applications, admit)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    hmc_state = init_kernel(init_params, args.num_warmup_steps)
    hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: constrain_fn(hmc_state.z))
    return hmc_states
Example 14
def test_fori_collect_thinning():
    def f(x):
        return x + 1.0

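    # thinning keeps every `thinning`-th collected value, aligned so the final
    # iterate is always retained (each expected array below ends at the last value).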
    actual2 = fori_collect(0, 9, f, jnp.array([-1]), thinning=2)
    expected2 = jnp.array([[2], [4], [6], [8]])
    check_eq(actual2, expected2)

    actual3 = fori_collect(0, 9, f, jnp.array([-1]), thinning=3)
    expected3 = jnp.array([[2], [5], [8]])
    check_eq(actual3, expected3)

    actual4 = fori_collect(0, 9, f, jnp.array([-1]), thinning=4)
    expected4 = jnp.array([[4], [8]])
    check_eq(actual4, expected4)

    actual5 = fori_collect(12, 37, f, jnp.array([-1]), thinning=5)
    expected5 = jnp.array([[16], [21], [26], [31], [36]])
    check_eq(actual5, expected5)
Example 15
def mcmc(num_warmup,
         num_samples,
         init_params,
         sampler='hmc',
         constrain_fn=None,
         print_summary=True,
         **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints a
    diagnostic summary, and returns a collection of samples
    from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that
        lie within the support of the sample sites.
    :param print_summary: Whether to print diagnostics summary for
        each sample site. Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

         - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
           :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
           that all arguments must be provided as keywords.

    :return: collection of samples from the posterior.
    """
    if sampler == 'hmc':
        if constrain_fn is None:
            constrain_fn = identity
        potential_fn = sampler_kwargs.pop('potential_fn')
        kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
        algo = sampler_kwargs.pop('algo', 'NUTS')
        progbar = sampler_kwargs.pop('progbar', True)

        init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
        hmc_state = init_kernel(init_params,
                                num_warmup,
                                progbar=progbar,
                                **sampler_kwargs)
        samples = fori_collect(num_samples,
                               sample_kernel,
                               hmc_state,
                               transform=lambda x: constrain_fn(x.z),
                               progbar=progbar,
                               diagnostics_fn=get_diagnostics_str,
                               progbar_desc='sample')
        if print_summary:
            summary(samples)
        return samples
    else:
        raise ValueError('sampler: {} not recognized'.format(sampler))
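A minimal call, mirroring the doctest attached to the later revision of this wrapper shown below (`init_params`, `potential_fn` and `constrain_fn` as produced by `initialize_model`):

samples = mcmc(1000, 1000, init_params,
               potential_fn=potential_fn, constrain_fn=constrain_fn)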
Example 16
def main(args):
    jax_config.update('jax_platform_name', args.device)
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng, sample_rng = random.split(random.PRNGKey(args.rng))
    init_params, potential_fn, constrain_fn = initialize_model(init_rng, model, returns)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    hmc_state = init_kernel(init_params, args.num_warmup, rng=sample_rng)
    hmc_states = fori_collect(0, args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: constrain_fn(hmc_state.z))
    print_results(hmc_states, dates)
Example 17
def test_construct():
    import bellini
    from bellini import Quantity, Species, Substance, Story
    import pint
    ureg = pint.UnitRegistry()

    s = Story()
    water = Species(name='water')
    s.one_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="first_normal",
    )
    s.another_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="second_normal",
    )
    s.combined_water = s.one_water_quantity + s.another_water_quantity
    # s.combined_water.observed = True
    s.combined_water.name = "combined_water"

    s.combined_water_with_noise = bellini.distributions.Normal(
        loc=s.combined_water,
        scale=Quantity(0.01, ureg.mole),
        name="combined_with_noise")

    s.combined_water_with_noise.observed = True

    from bellini.api._numpyro import graph_to_numpyro_model
    model = graph_to_numpyro_model(s.g)

    from numpyro.infer.util import initialize_model
    import jax
    model_info = initialize_model(
        jax.random.PRNGKey(2666),
        model,
    )
    from numpyro.infer.hmc import hmc
    from numpyro.util import fori_collect

    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info,
                            trajectory_length=10,
                            num_warmup=300)
    samples = fori_collect(
        0,
        500,
        sample_kernel,
        hmc_state,
        transform=lambda state: model_info.postprocess_fn(state.z))

    print(samples)
Example 18
def test_fori_collect_return_last(progbar):
    def f(x):
        x['i'] = x['i'] + 1
        return x

    tree, last_state = fori_collect(2, 4, f, {'i': 0},
                                    transform=lambda a: {'i': a['i']},
                                    return_last_val=True,
                                    progbar=progbar)
    expected_tree = {'i': jnp.array([3, 4])}
    expected_last_state = {'i': jnp.array(4)}
    check_eq(last_state, expected_last_state)
    check_eq(tree, expected_tree)
Example 19
def test_unnormalized_normal(algo):
    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std)**2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = np.array(0.)
    hmc_state = init_kernel(init_params,
                            trajectory_length=9,
                            num_warmup=warmup_steps)
    hmc_states = fori_collect(num_samples,
                              sample_kernel,
                              hmc_state,
                              transform=lambda x: x.z)
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)
Example 20
def run_inference(transition_prior, emission_prior, supervised_categories,
                  supervised_words, unsupervised_words, rng, args):
    init_params, potential_fn, constrain_fn = initialize_model(
        rng,
        semi_supervised_hmm,
        transition_prior,
        emission_prior,
        supervised_categories,
        supervised_words,
        unsupervised_words,
    )
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    hmc_state = init_kernel(init_params, args.num_warmup)
    hmc_states = fori_collect(args.num_samples,
                              sample_kernel,
                              hmc_state,
                              transform=lambda state: constrain_fn(state.z))
    return hmc_states
Example 21
    def _single_chain_mcmc(self,
                           rng_key,
                           init_state,
                           init_params,
                           args,
                           kwargs,
                           collect_fields=('z', )):
        if init_state is None:
            init_state = self.sampler.init(rng_key,
                                           self.num_warmup,
                                           init_params,
                                           model_args=args,
                                           model_kwargs=kwargs)
        if self.constrain_fn is None:
            self.constrain_fn = self.sampler.constrain_fn(args, kwargs)
        diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None  # noqa: E731
        init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
        lower_idx = self._collection_params["lower"]
        upper_idx = self._collection_params["upper"]

        collect_vals = fori_collect(
            lower_idx,
            upper_idx,
            self._get_cached_fn(),
            init_val,
            transform=_collect_fn(collect_fields),
            progbar=self.progress_bar,
            return_last_val=True,
            collection_size=self._collection_params["collection_size"],
            progbar_desc=functools.partial(get_progbar_desc_str, lower_idx),
            diagnostics_fn=diagnostics)
        states, last_val = collect_vals
        # Get first argument of type `HMCState`
        last_state = last_val[0]
        if len(collect_fields) == 1:
            states = (states, )
        states = dict(zip(collect_fields, states))
        # Apply constraints if number of samples is non-zero
        site_values = tree_flatten(states['z'])[0]
        if len(site_values) > 0 and site_values[0].size > 0:
            states['z'] = lax.map(self.constrain_fn, states['z'])
        return states, last_state
Example 22
def benchmark_hmc(args, features, labels):
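    # NOTE: `step_size` and `init_params` are assumed to be module-level constants
    # in the original benchmark script (see the TODO below); this excerpt does not
    # define them.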
    trajectory_length = step_size * args.num_steps
    _, potential_fn, _ = initialize_model(random.PRNGKey(1), model, features, labels)
    init_kernel, sample_kernel = hmc(potential_fn, algo=args.algo)
    t0 = time.time()
    # TODO: Use init_params from `initialize_model` instead of fixed params.
    hmc_state, _, _ = init_kernel(init_params, num_warmup=0, step_size=step_size,
                                  trajectory_length=trajectory_length,
                                  adapt_step_size=False, run_warmup=False)
    t1 = time.time()
    print("time for hmc_init: ", t1 - t0)

    def transform(state):
        return {'coefs': state.z['coefs'], 'num_steps': state.num_steps}

    hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state, transform=transform)
    num_leapfrogs = np.sum(hmc_states['num_steps'])
    print('number of leapfrog steps: ', num_leapfrogs)
    print('avg. time for each step: ', (time.time() - t1) / num_leapfrogs)
Example 23
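This test is presumably driven by pytest parametrization over the sampler and the JAX mapping primitive; a sketch of the decorators the excerpt omits:

import jax
import pytest
from jax import pmap, vmap

@pytest.mark.parametrize('algo', ['HMC', 'NUTS'])
@pytest.mark.parametrize('map_fn', [vmap, pmap])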
def test_functional_map(algo, map_fn):
    if map_fn is pmap and jax.device_count() == 1:
        pytest.skip("pmap test requires device_count greater than 1.")

    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std)**2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = jnp.array([0.0, -1.0])
    rng_keys = random.split(random.PRNGKey(0), 2)

    init_kernel_map = map_fn(
        lambda init_param, rng_key: init_kernel(init_param,
                                                trajectory_length=9,
                                                num_warmup=num_warmup,
                                                rng_key=rng_key))
    init_states = init_kernel_map(init_params, rng_keys)

    fori_collect_map = map_fn(lambda hmc_state: fori_collect(
        0,
        num_samples,
        sample_kernel,
        hmc_state,
        transform=lambda x: x.z,
        progbar=False,
    ))
    chain_samples = fori_collect_map(init_states)

    assert_allclose(jnp.mean(chain_samples, axis=1),
                    jnp.repeat(true_mean, 2),
                    rtol=0.06)
    assert_allclose(jnp.std(chain_samples, axis=1),
                    jnp.repeat(true_std, 2),
                    rtol=0.06)
Example 24
def test_map(algo, map_fn):
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std)**2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = np.array([0., -1.])
    rngs = random.split(random.PRNGKey(0), 2)

    init_kernel_map = map_fn(
        lambda init_param, rng: init_kernel(init_param,
                                            trajectory_length=9,
                                            num_warmup=warmup_steps,
                                            progbar=False,
                                            rng=rng))
    init_states = init_kernel_map(init_params, rngs)

    fori_collect_map = map_fn(
        lambda hmc_state: fori_collect(0,
                                       num_samples,
                                       sample_kernel,
                                       hmc_state,
                                       transform=lambda x: x.z,
                                       progbar=False))
    chain_samples = fori_collect_map(init_states)

    assert_allclose(np.mean(chain_samples, axis=1),
                    np.repeat(true_mean, 2),
                    rtol=0.05)
    assert_allclose(np.std(chain_samples, axis=1),
                    np.repeat(true_std, 2),
                    rtol=0.05)
Example 25
def mcmc(num_warmup, num_samples, init_params, sampler='hmc',
         constrain_fn=None, print_summary=True, **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints a
    diagnostic summary, and returns a collection of samples
    from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that
        lie within the support of the sample sites.
    :param print_summary: Whether to print diagnostics summary for
        each sample site. Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

         - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
           :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
           that all arguments must be provided as keywords.

    :return: collection of samples from the posterior.

    .. testsetup::

       import jax
       from jax import random
       import jax.numpy as np
       import numpyro.distributions as dist
       from numpyro.handlers import sample
       from numpyro.hmc_util import initialize_model
       from numpyro.mcmc import hmc
       from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = sample('coefs', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = sample('intercept', dist.Normal(0., 10.))
        ...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model,
        ...                                                            data, labels)
        >>> num_warmup, num_samples = 1000, 1000
        >>> samples = mcmc(num_warmup, num_samples, init_params,
        ...                potential_fn=potential_fn,
        ...                constrain_fn=constrain_fn)  # doctest: +SKIP
        warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
        sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]


                           mean         sd       5.5%      94.5%      n_eff       Rhat
            coefs[0]       0.96       0.07       0.85       1.07     455.35       1.01
            coefs[1]       2.05       0.09       1.91       2.20     332.00       1.01
            coefs[2]       3.18       0.13       2.96       3.37     320.27       1.00
           intercept      -0.03       0.02      -0.06       0.00     402.53       1.00
    """
    if sampler == 'hmc':
        if constrain_fn is None:
            constrain_fn = identity
        potential_fn = sampler_kwargs.pop('potential_fn')
        kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
        algo = sampler_kwargs.pop('algo', 'NUTS')
        progbar = sampler_kwargs.pop('progbar', True)

        init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
        hmc_state = init_kernel(init_params, num_warmup, progbar=progbar, **sampler_kwargs)
        samples = fori_collect(num_samples, sample_kernel, hmc_state,
                               transform=lambda x: constrain_fn(x.z),
                               progbar=progbar,
                               diagnostics_fn=get_diagnostics_str,
                               progbar_desc='sample')
        if print_summary:
            summary(samples)
        return samples
    else:
        raise ValueError('sampler: {} not recognized'.format(sampler))
Example 26
def mcmc(num_warmup, num_samples, init_params, num_chains=1, sampler='hmc',
         constrain_fn=None, print_summary=True, **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints a
    diagnostic summary, and returns a collection of samples
    from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that
        lie within the support of the sample sites.
    :param print_summary: Whether to print diagnostics summary for
        each sample site. Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

         - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
           :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
           that all arguments must be provided as keywords.

    :return: collection of samples from the posterior.

    .. testsetup::

       import jax
       from jax import random
       import jax.numpy as np
       import numpyro.distributions as dist
       from numpyro.handlers import sample
       from numpyro.hmc_util import initialize_model
       from numpyro.mcmc import hmc
       from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = sample('coefs', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = sample('intercept', dist.Normal(0., 10.))
        ...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model,
        ...                                                            data, labels)
        >>> num_warmup, num_samples = 1000, 1000
        >>> samples = mcmc(num_warmup, num_samples, init_params,
        ...                potential_fn=potential_fn,
        ...                constrain_fn=constrain_fn)  # doctest: +SKIP
        warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
        sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]


                           mean         sd       5.5%      94.5%      n_eff       Rhat
            coefs[0]       0.96       0.07       0.85       1.07     455.35       1.01
            coefs[1]       2.05       0.09       1.91       2.20     332.00       1.01
            coefs[2]       3.18       0.13       2.96       3.37     320.27       1.00
           intercept      -0.03       0.02      -0.06       0.00     402.53       1.00
    """
    sequential_chain = False
    if xla_bridge.device_count() < num_chains:
        sequential_chain = True
        warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
                      ' Chains will be drawn sequentially. If you are running `mcmc` on CPU,'
                      ' consider disabling XLA intra-op parallelism by setting the environment'
                      ' flag "XLA_FLAGS=--xla_force_host_platform_device_count={}".'
                      .format(num_chains, xla_bridge.device_count(), num_chains))
    progbar = sampler_kwargs.pop('progbar', True)
    if num_chains > 1:
        progbar = False

    if sampler == 'hmc':
        if constrain_fn is None:
            constrain_fn = identity
        potential_fn = sampler_kwargs.pop('potential_fn')
        kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
        algo = sampler_kwargs.pop('algo', 'NUTS')
        if num_chains > 1:
            rngs = sampler_kwargs.pop('rng', vmap(PRNGKey)(np.arange(num_chains)))
        else:
            rng = sampler_kwargs.pop('rng', PRNGKey(0))

        init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
        if progbar:
            hmc_state = init_kernel(init_params, num_warmup, progbar=progbar, rng=rng,
                                    **sampler_kwargs)
            samples_flat = fori_collect(0, num_samples, sample_kernel, hmc_state,
                                        transform=lambda x: constrain_fn(x.z),
                                        progbar=progbar,
                                        diagnostics_fn=get_diagnostics_str,
                                        progbar_desc='sample')
            samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
        else:
            def single_chain_mcmc(rng, init_params):
                hmc_state = init_kernel(init_params, num_warmup, run_warmup=False, rng=rng,
                                        **sampler_kwargs)
                samples = fori_collect(num_warmup, num_warmup + num_samples, sample_kernel, hmc_state,
                                       transform=lambda x: constrain_fn(x.z),
                                       progbar=progbar)
                return samples

            if num_chains == 1:
                samples_flat = single_chain_mcmc(rng, init_params)
                samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
            else:
                if sequential_chain:
                    samples = lax.map(lambda args: single_chain_mcmc(*args), (rngs, init_params))
                else:
                    samples = pmap(single_chain_mcmc)(rngs, init_params)
                samples_flat = tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), samples)

        if print_summary:
            summary(samples)
        return samples_flat
    else:
        raise ValueError('sampler: {} not recognized'.format(sampler))
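A sketch of calling this wrapper with multiple chains (hypothetical usage: it assumes `init_params`, `potential_fn` and `constrain_fn` from `initialize_model`, stacks the initial parameters along a leading chain axis, and needs enough devices, or the XLA host-device flag mentioned in the warning above):

from jax.tree_util import tree_map

init_params_per_chain = tree_map(lambda x: np.stack([x, x]), init_params)
samples = mcmc(num_warmup, num_samples, init_params_per_chain, num_chains=2,
               potential_fn=potential_fn, constrain_fn=constrain_fn)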