def test_reuse_mcmc_pe_gen():
    y1 = onp.random.normal(3, 0.1, (100,))
    y2 = onp.random.normal(-3, 0.1, (100,))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample('sigma', dist.HalfCauchy(3.))
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y_obs)

    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(0), model, y1, dynamic_args=True)
    init_kernel, sample_kernel = hmc(potential_fn_gen=potential_fn)
    init_state = init_kernel(init_params, num_warmup=300, model_args=(y1,))

    @jit
    def _sample(state_and_args):
        hmc_state, model_args = state_and_args
        return sample_kernel(hmc_state, (model_args,)), model_args

    samples = fori_collect(0, 500, _sample, (init_state, y1),
                           transform=lambda state: constrain_fn(y1)(state[0].z))
    assert_allclose(samples['mu'].mean(), 3., atol=0.1)

    # Re-run on new data, reusing the jitted `_sample` kernel - this should be much faster.
    init_state = init_kernel(init_params, num_warmup=300, model_args=(y2,))
    samples = fori_collect(0, 500, _sample, (init_state, y2),
                           transform=lambda state: constrain_fn(y2)(state[0].z))
    assert_allclose(samples['mu'].mean(), -3., atol=0.1)
def test_functional_beta_bernoulli_x64(algo):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, model_args=(data,))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params, trajectory_length=1.,
                            num_warmup=warmup_steps)
    samples = fori_collect(0, num_samples, sample_kernel, hmc_state,
                           transform=lambda x: constrain_fn(x.z))
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
def run(
    self,
    rng_key,
    num_steps,
    *args,
    progress_bar=True,
    init_state=None,
    collect_fn=lambda val: val[1],  # TODO: refactor
    **kwargs,
):
    def bodyfn(_i, info):
        body_state = info[0]
        return (*self.update(body_state, *info[2:], **kwargs), *info[2:])

    if init_state is None:
        state = self.init(rng_key, *args, **kwargs)
    else:
        state = init_state
    loss = self.evaluate(state, *args, **kwargs)
    auxiliaries, last_res = fori_collect(
        0,
        num_steps,
        lambda info: bodyfn(0, info),
        (state, loss, *args),
        progbar=progress_bar,
        transform=collect_fn,
        return_last_val=True,
    )
    state = last_res[0]
    return SteinVIRunResult(self.get_params(state), state, auxiliaries)
def test_dirichlet_categorical(algo, dense_mass):
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = sample('p_latent', dist.Dirichlet(concentration))
        sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params, trajectory_length=1.,
                            num_warmup=warmup_steps, progbar=False,
                            dense_mass=dense_mass)
    hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
                              transform=lambda x: constrain_fn(x.z),
                              progbar=False)
    assert_allclose(np.mean(hmc_states['p_latent'], 0), true_probs, atol=0.02)

    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states['p_latent'].dtype == np.float64
def single_chain_mcmc(rng, init_params):
    hmc_state = init_kernel(init_params, num_warmup, run_warmup=False,
                            rng=rng, **sampler_kwargs)
    samples = fori_collect(num_warmup, num_warmup + num_samples,
                           sample_kernel, hmc_state,
                           transform=lambda x: constrain_fn(x.z),
                           progbar=progbar)
    return samples
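This helper closes over `init_kernel`, `sample_kernel`, and the other names from the enclosing `mcmc` wrapper (its full definition appears at the end of this section). A condensed sketch of how that wrapper maps it over chains; `num_chains` and `sequential_chain` come from the enclosing scope:

```python
# From the enclosing `mcmc` wrapper: one PRNG key per chain, then either
# sequential (lax.map) or one-device-per-chain (pmap) execution.
rngs = vmap(PRNGKey)(np.arange(num_chains))
if sequential_chain:
    samples = lax.map(lambda args: single_chain_mcmc(*args), (rngs, init_params))
else:
    samples = pmap(single_chain_mcmc)(rngs, init_params)
```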
def test_binomial_stable(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs', dist.Binomial(data['n'], logits=logits),
                           obs=data['x'])
        else:
            numpyro.sample('obs', dist.Binomial(data['n'], probs=p),
                           obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn)
    hmc_state = init_kernel(init_params, num_warmup=warmup_steps)
    samples = fori_collect(0, num_samples, sample_kernel, hmc_state,
                           transform=lambda x: constrain_fn(x.z))
    assert_allclose(np.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p'].dtype == np.float64
def _single_chain_mcmc(self, init, collect_fields=('z',), collect_warmup=False,
                       args=(), kwargs={}):
    rng_key, init_params = init
    init_state, constrain_fn = self.sampler.init(rng_key, self.num_warmup,
                                                 init_params, model_args=args,
                                                 model_kwargs=kwargs)
    if self.constrain_fn is None:
        constrain_fn = identity if constrain_fn is None else constrain_fn
    else:
        constrain_fn = self.constrain_fn
    collect_fn = attrgetter(*collect_fields)
    lower = 0 if collect_warmup else self.num_warmup
    states = fori_collect(
        lower, self.num_warmup + self.num_samples, self.sampler.sample,
        init_state,
        transform=collect_fn,
        progbar=self.progress_bar,
        progbar_desc=functools.partial(get_progbar_desc_str, self.num_warmup),
        diagnostics_fn=get_diagnostics_str if rng_key.ndim == 1 else None)
    if len(collect_fields) == 1:
        states = (states,)
    states = dict(zip(collect_fields, states))
    states['z'] = vmap(constrain_fn)(states['z']) \
        if len(tree_flatten(states)[0]) > 0 else states['z']
    return states
def main(args):
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup,
                            rng_key=sample_rng_key)
    hmc_states = fori_collect(
        args.num_warmup, args.num_warmup + args.num_samples, sample_kernel,
        hmc_state,
        transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
        progbar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    legend.legendHandles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns',
           title='Volatility of S&P500 over time')
    plt.savefig("stochastic_volatility_plot.pdf")
def test_change_point():
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / np.mean(data)
        lambda1 = sample('lambda1', dist.Exponential(alpha))
        lambda2 = sample('lambda2', dist.Exponential(alpha))
        tau = sample('tau', dist.Uniform(0, 1))
        lambda12 = np.where(np.arange(len(data)) < tau * len(data),
                            lambda1, lambda2)
        sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11, 19, 29, 6,
        19, 12, 22, 12, 18, 72, 32, 9, 7, 13, 19, 23, 27, 20, 6, 17, 13, 10,
        14, 6, 16, 15, 7, 2, 15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11,
        18, 20, 12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37, 5, 14,
        13, 22,
    ])
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(4), model, count_data)
    init_kernel, sample_kernel = hmc(potential_fn)
    hmc_state = init_kernel(init_params, num_warmup=warmup_steps)
    samples = fori_collect(num_samples, sample_kernel, hmc_state,
                           transform=lambda x: constrain_fn(x.z))
    tau_posterior = (samples['tau'] * len(count_data)).astype('int')
    tau_values, counts = onp.unique(tau_posterior, return_counts=True)
    mode_ind = np.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['lambda1'].dtype == np.float64
        assert samples['lambda2'].dtype == np.float64
        assert samples['tau'].dtype == np.float64
def test_beta_bernoulli(algo):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = sample('p_latent', dist.Beta(alpha, beta))
        sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), size=(1000, 2))
    init_params, potential_fn, constrain_fn = initialize_model(
        random.PRNGKey(2), model, data)
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params, trajectory_length=1.,
                            num_warmup=warmup_steps, progbar=False)
    hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
                              transform=lambda x: constrain_fn(x.z),
                              progbar=False)
    assert_allclose(np.mean(hmc_states['p_latent'], 0), true_probs, atol=0.05)
def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
    rng_key, init_state, init_params = init
    if init_state is None:
        init_state = self.sampler.init(
            rng_key,
            self.num_warmup,
            init_params,
            model_args=args,
            model_kwargs=kwargs,
        )
    sample_fn, postprocess_fn = self._get_cached_fns()
    diagnostics = (
        lambda x: self.sampler.get_diagnostics_str(x[0])
        if rng_key.ndim == 1
        else ""
    )  # noqa: E731
    init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
    lower_idx = self._collection_params["lower"]
    upper_idx = self._collection_params["upper"]
    phase = self._collection_params["phase"]
    collection_size = self._collection_params["collection_size"]
    collection_size = (
        collection_size
        if collection_size is None
        else collection_size // self.thinning
    )
    collect_vals = fori_collect(
        lower_idx,
        upper_idx,
        sample_fn,
        init_val,
        transform=_collect_fn(collect_fields),
        progbar=self.progress_bar,
        return_last_val=True,
        thinning=self.thinning,
        collection_size=collection_size,
        progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
        diagnostics_fn=diagnostics,
        num_chains=self.num_chains if self.chain_method == "parallel" else 1,
    )
    states, last_val = collect_vals
    # Get first argument of type `HMCState`
    last_state = last_val[0]
    if len(collect_fields) == 1:
        states = (states,)
    states = dict(zip(collect_fields, states))
    # Apply constraints if number of samples is non-zero
    site_values = tree_flatten(states[self._sample_field])[0]
    # XXX: lax.map still works if some arrays have 0 size
    # so we only need to filter out the case site_value.shape[0] == 0
    # (which happens when lower_idx == upper_idx)
    if len(site_values) > 0 and jnp.shape(site_values[0])[0] > 0:
        if self._jit_model_args:
            states[self._sample_field] = postprocess_fn(
                states[self._sample_field], args, kwargs
            )
        else:
            states[self._sample_field] = postprocess_fn(states[self._sample_field])
    return states, last_state
def test_fori_collect():
    def f(x):
        return {"i": x["i"] + x["j"], "j": x["i"] - x["j"]}

    a = {"i": jnp.array([0.0]), "j": jnp.array([1.0])}
    expected_tree = {"i": jnp.array([[0.0], [2.0]])}
    actual_tree = fori_collect(1, 3, f, a, transform=lambda a: {"i": a["i"]})
    check_eq(actual_tree, expected_tree)
def run_inference(dept, male, applications, admit, rng, args):
    init_params, potential_fn, constrain_fn = initialize_model(
        rng, glmm, dept, male, applications, admit)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    hmc_state = init_kernel(init_params, args.num_warmup_steps)
    hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: constrain_fn(hmc_state.z))
    return hmc_states
def test_fori_collect():
    def f(x):
        return {'i': x['i'] + x['j'], 'j': x['i'] - x['j']}

    a = {'i': jnp.array([0.]), 'j': jnp.array([1.])}
    expected_tree = {'i': jnp.array([[0.], [2.]])}
    actual_tree = fori_collect(1, 3, f, a, transform=lambda a: {'i': a['i']})
    check_eq(actual_tree, expected_tree)
def test_fori_collect_thinning():
    def f(x):
        return x + 1.0

    actual2 = fori_collect(0, 9, f, jnp.array([-1]), thinning=2)
    expected2 = jnp.array([[2], [4], [6], [8]])
    check_eq(actual2, expected2)

    actual3 = fori_collect(0, 9, f, jnp.array([-1]), thinning=3)
    expected3 = jnp.array([[2], [5], [8]])
    check_eq(actual3, expected3)

    actual4 = fori_collect(0, 9, f, jnp.array([-1]), thinning=4)
    expected4 = jnp.array([[4], [8]])
    check_eq(actual4, expected4)

    actual5 = fori_collect(12, 37, f, jnp.array([-1]), thinning=5)
    expected5 = jnp.array([[16], [21], [26], [31], [36]])
    check_eq(actual5, expected5)
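The expected arrays above pin down the thinning rule: `fori_collect` retains `(upper - lower) // thinning` iterations, spaced `thinning` apart and aligned so that the final iteration is always kept. A minimal sketch reproducing those indices (the helper `kept_steps` is ours, for illustration only, not part of the numpyro API):

```python
def kept_steps(lower, upper, thinning):
    # number of retained draws, spaced `thinning` apart, ending at `upper`
    n = (upper - lower) // thinning
    return [upper - thinning * j for j in reversed(range(n))]

# With f(x) = x + 1 starting at -1, the value after step k is k - 1,
# matching the expected arrays in the test above:
assert kept_steps(0, 9, 3) == [3, 6, 9]               # values 2, 5, 8
assert kept_steps(12, 37, 5) == [17, 22, 27, 32, 37]  # values 16, 21, 26, 31, 36
```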
def mcmc(num_warmup, num_samples, init_params, sampler='hmc',
         constrain_fn=None, print_summary=True, **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints diagnostic
    summary and returns a collection of samples from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that
        lie within the support of the sample sites.
    :param print_summary: Whether to print diagnostics summary for each
        sample site. Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

        - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
          :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
          that all arguments must be provided as keywords.

    :return: collection of samples from the posterior.
    """
    if sampler == 'hmc':
        if constrain_fn is None:
            constrain_fn = identity
        potential_fn = sampler_kwargs.pop('potential_fn')
        kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
        algo = sampler_kwargs.pop('algo', 'NUTS')
        progbar = sampler_kwargs.pop('progbar', True)

        init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
        hmc_state = init_kernel(init_params, num_warmup, progbar=progbar,
                                **sampler_kwargs)
        samples = fori_collect(num_samples, sample_kernel, hmc_state,
                               transform=lambda x: constrain_fn(x.z),
                               progbar=progbar,
                               diagnostics_fn=get_diagnostics_str,
                               progbar_desc='sample')
        if print_summary:
            summary(samples)
        return samples
    else:
        raise ValueError('sampler: {} not recognized'.format(sampler))
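A usage sketch for this wrapper, adapted from the doctest in the later revision of the same function below; it assumes `model`, `data`, and `labels` are defined as in that doctest:

```python
# `initialize_model` supplies the potential_fn/constrain_fn pair that
# the `mcmc` wrapper forwards to the functional `hmc` API.
init_params, potential_fn, constrain_fn = initialize_model(
    random.PRNGKey(0), model, data, labels)
samples = mcmc(1000, 1000, init_params,
               potential_fn=potential_fn,
               constrain_fn=constrain_fn)
```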
def main(args):
    jax_config.update('jax_platform_name', args.device)
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng, sample_rng = random.split(random.PRNGKey(args.rng))
    init_params, potential_fn, constrain_fn = initialize_model(
        init_rng, model, returns)
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    hmc_state = init_kernel(init_params, args.num_warmup, rng=sample_rng)
    hmc_states = fori_collect(0, args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: constrain_fn(hmc_state.z))
    print_results(hmc_states, dates)
def test_construct():
    import bellini
    from bellini import Quantity, Species, Substance, Story
    import pint

    ureg = pint.UnitRegistry()
    s = Story()
    water = Species(name='water')

    s.one_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="first_normal",
    )
    s.another_water_quantity = bellini.distributions.Normal(
        loc=Quantity(3.0, ureg.mole),
        scale=Quantity(0.01, ureg.mole),
        name="second_normal",
    )
    s.combined_water = s.one_water_quantity + s.another_water_quantity
    # s.combined_water.observed = True
    s.combined_water.name = "combined_water"
    s.combined_water_with_noise = bellini.distributions.Normal(
        loc=s.combined_water,
        scale=Quantity(0.01, ureg.mole),
        name="combined_with_noise")
    s.combined_water_with_noise.observed = True

    from bellini.api._numpyro import graph_to_numpyro_model
    model = graph_to_numpyro_model(s.g)

    from numpyro.infer.util import initialize_model
    import jax

    model_info = initialize_model(
        jax.random.PRNGKey(2666),
        model,
    )

    from numpyro.infer.hmc import hmc
    from numpyro.util import fori_collect

    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info,
                            trajectory_length=10,
                            num_warmup=300)
    samples = fori_collect(
        0, 500, sample_kernel, hmc_state,
        transform=lambda state: model_info.postprocess_fn(state.z))
    print(samples)
def test_fori_collect_return_last(progbar):
    def f(x):
        x['i'] = x['i'] + 1
        return x

    tree, init_state = fori_collect(2, 4, f, {'i': 0},
                                    transform=lambda a: {'i': a['i']},
                                    return_last_val=True,
                                    progbar=progbar)
    expected_tree = {'i': jnp.array([3, 4])}
    expected_last_state = {'i': jnp.array(4)}
    check_eq(init_state, expected_last_state)
    check_eq(tree, expected_tree)
def test_unnormalized_normal(algo):
    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = np.array(0.)
    hmc_state = init_kernel(init_params,
                            trajectory_length=9,
                            num_warmup=warmup_steps)
    hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
                              transform=lambda x: x.z)
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)
def run_inference(transition_prior, emission_prior, supervised_categories,
                  supervised_words, unsupervised_words, rng, args):
    init_params, potential_fn, constrain_fn = initialize_model(
        rng, semi_supervised_hmm, transition_prior, emission_prior,
        supervised_categories, supervised_words, unsupervised_words,
    )
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    hmc_state = init_kernel(init_params, args.num_warmup)
    hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
                              transform=lambda state: constrain_fn(state.z))
    return hmc_states
def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs,
                       collect_fields=('z',)):
    if init_state is None:
        init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
                                       model_args=args, model_kwargs=kwargs)
    if self.constrain_fn is None:
        self.constrain_fn = self.sampler.constrain_fn(args, kwargs)
    diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None  # noqa: E731
    init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
    lower_idx = self._collection_params["lower"]
    upper_idx = self._collection_params["upper"]

    collect_vals = fori_collect(
        lower_idx,
        upper_idx,
        self._get_cached_fn(),
        init_val,
        transform=_collect_fn(collect_fields),
        progbar=self.progress_bar,
        return_last_val=True,
        collection_size=self._collection_params["collection_size"],
        progbar_desc=functools.partial(get_progbar_desc_str, lower_idx),
        diagnostics_fn=diagnostics)
    states, last_val = collect_vals
    # Get first argument of type `HMCState`
    last_state = last_val[0]
    if len(collect_fields) == 1:
        states = (states,)
    states = dict(zip(collect_fields, states))
    # Apply constraints if number of samples is non-zero
    site_values = tree_flatten(states['z'])[0]
    if len(site_values) > 0 and site_values[0].size > 0:
        states['z'] = lax.map(self.constrain_fn, states['z'])
    return states, last_state
def benchmark_hmc(args, features, labels):
    # `step_size` and `init_params` are free names here; in the original
    # benchmark script they are defined at module level.
    trajectory_length = step_size * args.num_steps
    _, potential_fn, _ = initialize_model(random.PRNGKey(1), model,
                                          features, labels)
    init_kernel, sample_kernel = hmc(potential_fn, algo=args.algo)
    t0 = time.time()
    # TODO: Use init_params from `initialize_model` instead of fixed params.
    hmc_state, _, _ = init_kernel(init_params, num_warmup=0,
                                  step_size=step_size,
                                  trajectory_length=trajectory_length,
                                  adapt_step_size=False, run_warmup=False)
    t1 = time.time()
    print("time for hmc_init: ", t1 - t0)

    def transform(state):
        return {'coefs': state.z['coefs'], 'num_steps': state.num_steps}

    hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
                              transform=transform)
    num_leapfrogs = np.sum(hmc_states['num_steps'])
    print('number of leapfrog steps: ', num_leapfrogs)
    print('avg. time for each step: ', (time.time() - t1) / num_leapfrogs)
def test_functional_map(algo, map_fn):
    if map_fn is pmap and jax.device_count() == 1:
        pytest.skip("pmap test requires device_count greater than 1.")

    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = jnp.array([0.0, -1.0])
    rng_keys = random.split(random.PRNGKey(0), 2)

    init_kernel_map = map_fn(
        lambda init_param, rng_key: init_kernel(
            init_param, trajectory_length=9, num_warmup=num_warmup,
            rng_key=rng_key))
    init_states = init_kernel_map(init_params, rng_keys)

    fori_collect_map = map_fn(
        lambda hmc_state: fori_collect(
            0,
            num_samples,
            sample_kernel,
            hmc_state,
            transform=lambda x: x.z,
            progbar=False,
        ))
    chain_samples = fori_collect_map(init_states)

    assert_allclose(jnp.mean(chain_samples, axis=1), jnp.repeat(true_mean, 2),
                    rtol=0.06)
    assert_allclose(jnp.std(chain_samples, axis=1), jnp.repeat(true_std, 2),
                    rtol=0.06)
def test_map(algo, map_fn):
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = np.array([0., -1.])
    rngs = random.split(random.PRNGKey(0), 2)

    init_kernel_map = map_fn(
        lambda init_param, rng: init_kernel(init_param, trajectory_length=9,
                                            num_warmup=warmup_steps,
                                            progbar=False, rng=rng))
    init_states = init_kernel_map(init_params, rngs)

    fori_collect_map = map_fn(
        lambda hmc_state: fori_collect(0, num_samples, sample_kernel, hmc_state,
                                       transform=lambda x: x.z, progbar=False))
    chain_samples = fori_collect_map(init_states)

    assert_allclose(np.mean(chain_samples, axis=1), np.repeat(true_mean, 2),
                    rtol=0.05)
    assert_allclose(np.std(chain_samples, axis=1), np.repeat(true_std, 2),
                    rtol=0.05)
def mcmc(num_warmup, num_samples, init_params, sampler='hmc',
         constrain_fn=None, print_summary=True, **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints diagnostic
    summary and returns a collection of samples from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that
        lie within the support of the sample sites.
    :param print_summary: Whether to print diagnostics summary for each
        sample site. Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

        - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
          :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
          that all arguments must be provided as keywords.

    :return: collection of samples from the posterior.

    .. testsetup::

        import jax
        from jax import random
        import jax.numpy as np
        import numpyro.distributions as dist
        from numpyro.handlers import sample
        from numpyro.hmc_util import initialize_model
        from numpyro.mcmc import hmc
        from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = sample('intercept', dist.Normal(0., 10.))
        ...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model,
        ...                                                            data, labels)
        >>> num_warmup, num_samples = 1000, 1000
        >>> samples = mcmc(num_warmup, num_samples, init_params,
        ...                potential_fn=potential_fn,
        ...                constrain_fn=constrain_fn)  # doctest: +SKIP
        warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
        sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]

                        mean         sd       5.5%      94.5%      n_eff       Rhat
          coefs[0]      0.96       0.07       0.85       1.07     455.35       1.01
          coefs[1]      2.05       0.09       1.91       2.20     332.00       1.01
          coefs[2]      3.18       0.13       2.96       3.37     320.27       1.00
         intercept     -0.03       0.02      -0.06       0.00     402.53       1.00
    """
    if sampler == 'hmc':
        if constrain_fn is None:
            constrain_fn = identity
        potential_fn = sampler_kwargs.pop('potential_fn')
        kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
        algo = sampler_kwargs.pop('algo', 'NUTS')
        progbar = sampler_kwargs.pop('progbar', True)

        init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
        hmc_state = init_kernel(init_params, num_warmup, progbar=progbar,
                                **sampler_kwargs)
        samples = fori_collect(num_samples, sample_kernel, hmc_state,
                               transform=lambda x: constrain_fn(x.z),
                               progbar=progbar,
                               diagnostics_fn=get_diagnostics_str,
                               progbar_desc='sample')
        if print_summary:
            summary(samples)
        return samples
    else:
        raise ValueError('sampler: {} not recognized'.format(sampler))
def mcmc(num_warmup, num_samples, init_params, num_chains=1, sampler='hmc',
         constrain_fn=None, print_summary=True, **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints diagnostic
    summary and returns a collection of samples from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param num_chains: Number of chains to draw; chains run in parallel via
        `pmap` when enough devices are available, otherwise sequentially.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that
        lie within the support of the sample sites.
    :param print_summary: Whether to print diagnostics summary for each
        sample site. Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

        - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
          :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
          that all arguments must be provided as keywords.

    :return: collection of samples from the posterior.

    .. testsetup::

        import jax
        from jax import random
        import jax.numpy as np
        import numpyro.distributions as dist
        from numpyro.handlers import sample
        from numpyro.hmc_util import initialize_model
        from numpyro.mcmc import hmc
        from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = sample('intercept', dist.Normal(0., 10.))
        ...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model,
        ...                                                            data, labels)
        >>> num_warmup, num_samples = 1000, 1000
        >>> samples = mcmc(num_warmup, num_samples, init_params,
        ...                potential_fn=potential_fn,
        ...                constrain_fn=constrain_fn)  # doctest: +SKIP
        warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
        sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]

                        mean         sd       5.5%      94.5%      n_eff       Rhat
          coefs[0]      0.96       0.07       0.85       1.07     455.35       1.01
          coefs[1]      2.05       0.09       1.91       2.20     332.00       1.01
          coefs[2]      3.18       0.13       2.96       3.37     320.27       1.00
         intercept     -0.03       0.02      -0.06       0.00     402.53       1.00
    """
    sequential_chain = False
    if xla_bridge.device_count() < num_chains:
        sequential_chain = True
        warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
                      ' Chains will be drawn sequentially. If you are running `mcmc` on CPU,'
                      ' consider disabling XLA intra-op parallelism by setting the environment'
                      ' flag "XLA_FLAGS=--xla_force_host_platform_device_count={}".'
                      .format(num_chains, xla_bridge.device_count(), num_chains))
    progbar = sampler_kwargs.pop('progbar', True)
    if num_chains > 1:
        progbar = False

    if sampler == 'hmc':
        if constrain_fn is None:
            constrain_fn = identity
        potential_fn = sampler_kwargs.pop('potential_fn')
        kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
        algo = sampler_kwargs.pop('algo', 'NUTS')
        if num_chains > 1:
            rngs = sampler_kwargs.pop('rng', vmap(PRNGKey)(np.arange(num_chains)))
        else:
            rng = sampler_kwargs.pop('rng', PRNGKey(0))

        init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
        if progbar:
            hmc_state = init_kernel(init_params, num_warmup, progbar=progbar,
                                    rng=rng, **sampler_kwargs)
            samples_flat = fori_collect(0, num_samples, sample_kernel, hmc_state,
                                        transform=lambda x: constrain_fn(x.z),
                                        progbar=progbar,
                                        diagnostics_fn=get_diagnostics_str,
                                        progbar_desc='sample')
            samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
        else:
            def single_chain_mcmc(rng, init_params):
                hmc_state = init_kernel(init_params, num_warmup, run_warmup=False,
                                        rng=rng, **sampler_kwargs)
                samples = fori_collect(num_warmup, num_warmup + num_samples,
                                       sample_kernel, hmc_state,
                                       transform=lambda x: constrain_fn(x.z),
                                       progbar=progbar)
                return samples

            if num_chains == 1:
                samples_flat = single_chain_mcmc(rng, init_params)
                samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
            else:
                if sequential_chain:
                    samples = lax.map(lambda args: single_chain_mcmc(*args),
                                      (rngs, init_params))
                else:
                    samples = pmap(single_chain_mcmc)(rngs, init_params)
                samples_flat = tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]),
                                        samples)

        if print_summary:
            summary(samples)
        return samples_flat
    else:
        raise ValueError('sampler: {} not recognized'.format(sampler))