def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key):
    if potential_fn_gen:
        nonlocal vv_update
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

    num_steps = _get_num_steps(step_size, trajectory_len)
    vv_state_new = fori_loop(0, num_steps,
                             lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                             vv_state)
    energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
    energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
    delta_energy = energy_new - energy_old
    delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
    accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
    diverging = delta_energy > max_delta_energy
    transition = random.bernoulli(rng_key, accept_prob)
    vv_state, energy = cond(transition,
                            (vv_state_new, energy_new), lambda args: args,
                            (vv_state, energy_old), lambda args: args)
    return vv_state, energy, num_steps, accept_prob, diverging

def _hmc_next(step_size, inverse_mass_matrix, vv_state,
              model_args, model_kwargs, rng_key, trajectory_length):
    if potential_fn_gen:
        nonlocal vv_update, forward_mode_ad
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)

    # no need to spend too many steps if the state z has 0 size (i.e. z is empty)
    if len(inverse_mass_matrix) == 0:
        num_steps = 1
    else:
        num_steps = _get_num_steps(step_size, trajectory_length)
        # makes sure trajectory length is constant, rather than step_size * num_steps
        step_size = trajectory_length / num_steps

    vv_state_new = fori_loop(0, num_steps,
                             lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                             vv_state)
    energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
    energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
    delta_energy = energy_new - energy_old
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
    diverging = delta_energy > max_delta_energy
    transition = random.bernoulli(rng_key, accept_prob)
    vv_state, energy = cond(transition,
                            (vv_state_new, energy_new), identity,
                            (vv_state, energy_old), identity)
    return vv_state, energy, num_steps, accept_prob, diverging

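Every snippet in this listing drives its iteration through the same fori_loop calling convention: fori_loop(lower, upper, body_fun, init_val), where body_fun takes the iteration index and the carried value and returns the new carry. The sketch below is a minimal, self-contained illustration of that pattern using jax.lax.fori_loop, which shares this signature; running_mean and body_fn are illustrative names, not part of the code above.

import jax.numpy as jnp
from jax import lax, random


def running_mean(samples):
    # carry is (count, mean); each iteration folds in samples[i]
    def body_fn(i, val):
        count, mean = val
        count = count + 1
        mean = mean + (samples[i] - mean) / count
        return count, mean

    _, mean = lax.fori_loop(0, samples.shape[0], body_fn, (0, jnp.zeros(())))
    return mean


x = random.normal(random.PRNGKey(0), (1000,)) + 3.0
print(running_mean(x))  # ~3.0
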
def get_final_state(model, step_size, num_steps, q_i, p_i):
    vv_init, vv_update = velocity_verlet(model.potential_fn, model.kinetic_fn)
    vv_state = vv_init(q_i, p_i)
    q_f, p_f, _, _ = fori_loop(0, num_steps,
                               lambda i, val: vv_update(step_size, args.m_inv, val),
                               vv_state)
    return (q_f, p_f)

def test_beta_bernoulli(elbo):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optim.Adam(0.05)
    svi = SVI(model, guide, adam, elbo)
    svi_state = svi.init(random.PRNGKey(1), data)
    assert_allclose(adam.get_params(svi_state.optim_state)['alpha_q'], 0.)

    def body_fn(i, val):
        svi_state, _ = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    assert_allclose(params['alpha_q'] / (params['alpha_q'] + params['beta_q']), 0.8,
                    atol=0.05, rtol=0.05)

def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe):
    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_key, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=partial(potential_fn, z_hmc=hmc_sites),
            idx=idx, support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                     (z_new, pe_new), identity,
                     (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, pe)
    _, gibbs_sites, pe = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites, pe

def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, Trace_ELBO(num_particles=100))
    svi_state = svi.init(PRNGKey(0), data)

    # Training loop
    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, args.num_steps, body_fn, svi_state)

    # Report the final values of the variational parameters
    # in the guide after training.
    params = svi.get_params(svi_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert jnp.abs(params["guide_loc"] - 3.0) < 0.1

def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.TransformedDistribution(
                dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    median = guide.median(params)
    assert_allclose(median['loc'], true_coef, rtol=0.05)
    # test .quantile method
    median = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(median['loc'][1], true_coef, rtol=0.1)

def run(self, rng_key, num_steps, *args, return_last=True, progbar=True, **kwargs):
    def bodyfn(i, info):
        svgd_state, losses = info
        svgd_state, loss = self.update(svgd_state, *args, **kwargs)
        losses = ops.index_update(losses, i, loss)
        return svgd_state, losses

    svgd_state = self.init(rng_key, *args, **kwargs)
    losses = np.empty((num_steps,))
    if not progbar:
        svgd_state, losses = fori_loop(0, num_steps, bodyfn, (svgd_state, losses))
    else:
        with tqdm.trange(num_steps) as t:
            for i in t:
                svgd_state, losses = jax.jit(bodyfn)(i, (svgd_state, losses))
                t.set_description('SVGD {:.5}'.format(losses[i]), refresh=False)
                t.update()

    loss_res = losses[-1] if return_last else losses
    return svgd_state, loss_res

def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    assert_allclose(jnp.mean(posterior_samples["beta"], 0), true_coefs, atol=0.05)

    # Predictive can be instantiated from posterior samples...
    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, 2)

    # ... or from the guide + params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, 2)

def test_beta_bernoulli(auto_class):
    data = np.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(np.ones(2), np.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model)
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(random.PRNGKey(1), model_args=(data,), guide_args=(data,))

    def body_fn(i, val):
        svi_state, loss = svi.update(val, model_args=(data,), guide_args=(data,))
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    true_coefs = (np.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(np.mean(posterior_samples['beta'], 0), true_coefs, atol=0.04)

def test_logistic_regression(auto_class):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data, labels)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantile method
        median = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(median['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)

def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
    # idx: current index of `z_discrete_flat` to update
    # support_size: support size of z_discrete at the index idx
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    # Here we loop over the support of z_flat[idx] to get z_new
    # XXX: we can't vmap potential_fn over all proposals and sample from the conditional
    # categorical distribution because support_size is a traced value, i.e. its value
    # might change across different discrete variables;
    # so here we will loop over all proposals and use an online scheme to sample from
    # the conditional categorical distribution
    body_fn = partial(
        _discrete_gibbs_proposal_body_fn,
        z_discrete_flat,
        unravel_fn,
        pe,
        potential_fn,
        idx,
    )
    init_val = (rng_key, z_discrete, pe, jnp.array(0.0))
    rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, body_fn, init_val)
    log_accept_ratio = jnp.array(0.0)
    return rng_key, z_new, pe_new, log_accept_ratio

def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_init = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(rng_init, model_args=(data,), guide_args=(data,))

    def body_fn(i, val):
        svi_state, loss = svi.update(val, model_args=(data,), guide_args=(data,))
        return svi_state

    svi_state = fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    median = guide.median(params)
    assert_allclose(median['loc'], true_coef, rtol=0.05)
    # test .quantile method
    median = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(median['loc'][1], true_coef, rtol=0.1)

def _discrete_modified_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size,
                                      stay_prob=0.0):
    assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    body_fn = partial(
        _discrete_gibbs_proposal_body_fn,
        z_discrete_flat,
        unravel_fn,
        pe,
        potential_fn,
        idx,
    )
    # like gibbs_step but here, weight of the current value is 0
    init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
    rng_key, z_new, pe_new, log_weight_sum = fori_loop(0, support_size - 1, body_fn, init_val)
    rng_key, rng_stay = random.split(rng_key)
    z_new, pe_new = cond(
        random.bernoulli(rng_stay, stay_prob),
        (z_discrete, pe), identity,
        (z_new, pe_new), identity,
    )
    # here we calculate the MH correction: (1 - P(z)) / (1 - P(z_new))
    # where 1 - P(z) ~ weight_sum
    # and 1 - P(z_new) ~ 1 + weight_sum - z_new_weight
    log_accept_ratio = log_weight_sum - jnp.log(
        jnp.exp(log_weight_sum) - jnp.expm1(pe - pe_new))
    return rng_key, z_new, pe_new, log_accept_ratio

def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    rng, rng_init = random.split(PRNGKey(0))
    opt_state, _ = svi_init(rng_init, model_args=(data,))

    # Training loop
    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_, model_args=(data,))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, args.num_steps, body_fn, (opt_state, rng))

    # Report the final values of the variational parameters
    # in the guide after training.
    params = get_params(opt_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1

def get_cov(x):
    wc_init, wc_update, wc_final = welford_covariance(diagonal=diagonal)
    wc_state = wc_init(3)
    wc_state = fori_loop(0, 2000, lambda i, val: wc_update(x[i], val), wc_state)
    cov, cov_inv_sqrt = wc_final(wc_state, regularize=regularize)
    return cov, cov_inv_sqrt

def test_mnist_data_load():
    def mean_pixels(i, mean_pix):
        batch, _ = fetch(i, idx)
        return mean_pix + jnp.sum(batch) / batch.size

    init, fetch = load_dataset(MNIST, batch_size=128, split='train')
    num_batches, idx = init()
    assert fori_loop(0, num_batches, mean_pixels, jnp.float32(0.)) / num_batches < 0.15

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # convert to unconstrained values
    z_hmc = {k: biject_to(prototype_trace[k]["fn"].support).inv(v)
             for k, v in hmc_sites.items()
             if k in prototype_trace and prototype_trace[k]["type"] == "sample"}
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    wrapped_model = _wrap_model(model)
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum
        wrapped_model = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(wrapped_model, model_args, model_kwargs_, z_hmc, enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_key, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn, idx=idx, support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                     (z_new, pe_new), identity,
                     (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites

def test_logistic_regression(auto_class, Elbo):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs", dist.Normal(0, 1).expand([dim]).to_event())
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        with numpyro.plate("N", len(data)):
            return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # smoke test if analytic KL is used
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoDAIS, AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median["coefs"], true_coefs, rtol=0.1)
        # test .quantile method
        if auto_class is not AutoDelta:
            median = guide.quantiles(params, [0.2, 0.5])
            assert_allclose(median["coefs"][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(posterior_samples["coefs"], 0), expected_coefs, rtol=0.1)

def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T
    N = len(data)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)).to_event())
        with numpyro.plate("N", N):
            numpyro.sample("obs", dist.Bernoulli(f).to_event(1), obs=data)

    adam = optim.Adam(0.01)
    if auto_class == AutoDAIS:
        guide = auto_class(model, init_loc_fn=init_strategy, base_dist="cholesky")
    else:
        guide = auto_class(model, init_loc_fn=init_strategy)

    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    posterior_mean = jnp.mean(posterior_samples["beta"], 0)
    assert_allclose(posterior_mean, true_coefs, atol=0.05)

    if auto_class not in [AutoDAIS, AutoDelta, AutoIAFNormal, AutoBNAFNormal]:
        quantiles = guide.quantiles(params, [0.2, 0.5, 0.8])
        assert quantiles["beta"].shape == (3, 2)

    # Predictive can be instantiated from posterior samples...
    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)

    # ... or from the guide + params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    z_hmc = hmc_sites
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum
        wrapped_model_ = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)
    else:
        wrapped_model_ = wrapped_model

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(wrapped_model_, model_args, model_kwargs_, z_hmc, enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_key, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn, idx=idx, support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ exponential(1)
        z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                     (z_new, pe_new), identity,
                     (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites

def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
    num_steps = _get_num_steps(step_size, trajectory_len)
    vv_state_new = fori_loop(0, num_steps,
                             lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                             vv_state)
    energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
    energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
    delta_energy = energy_new - energy_old
    delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
    accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
    transition = random.bernoulli(rng, accept_prob)
    vv_state = cond(transition,
                    vv_state_new, lambda state: state,
                    vv_state, lambda state: state)
    return vv_state, num_steps, accept_prob

def init_kernel(init_samples,
                num_warmup_steps,
                step_size=1.0,
                num_steps=None,
                adapt_step_size=True,
                adapt_mass_matrix=True,
                diag_mass=True,
                target_accept_prob=0.8,
                run_warmup=True,
                rng=PRNGKey(0)):
    step_size = float(step_size)
    nonlocal trajectory_length, momentum_generator, wa_update

    if num_steps is None:
        trajectory_length = 2 * math.pi
    else:
        trajectory_length = num_steps * step_size

    z = init_samples
    z_flat, unravel_fn = ravel_pytree(z)
    momentum_generator = partial(_sample_momentum, unravel_fn)

    find_reasonable_ss = partial(find_reasonable_step_size,
                                 potential_fn, kinetic_fn, momentum_generator)

    wa_init, wa_update = warmup_adapter(num_warmup_steps,
                                        find_reasonable_step_size=find_reasonable_ss,
                                        adapt_step_size=adapt_step_size,
                                        adapt_mass_matrix=adapt_mass_matrix,
                                        diag_mass=diag_mass,
                                        target_accept_prob=target_accept_prob)

    rng_hmc, rng_wa = random.split(rng)
    wa_state = wa_init(z, rng_wa, mass_matrix_size=np.size(z_flat))
    r = momentum_generator(wa_state.inverse_mass_matrix, rng)
    vv_state = vv_init(z, r)
    hmc_state = HMCState(vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0.,
                         wa_state.step_size, wa_state.inverse_mass_matrix, rng_hmc)

    if run_warmup:
        hmc_state, _ = fori_loop(0, num_warmup_steps, warmup_update, (hmc_state, wa_state))
        return hmc_state
    else:
        return hmc_state, wa_state, warmup_update

def test_laplace_approximation_warning():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    init_state = svi.init(random.PRNGKey(0))
    svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state)
    params = svi.get_params(svi_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)

def _inverse(self, y):
    """
    :param numpy.ndarray y: the output of the transform to be inverted
    """
    # NOTE: Inversion is an expensive operation that scales in the dimension of the input
    def _update_x(i, x):
        mean, log_scale = self.arn(x)
        inverse_scale = jnp.exp(-_clamp_preserve_gradients(
            log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip))
        x = (y - mean) * inverse_scale
        return x

    x = fori_loop(0, y.shape[-1], _update_x, jnp.zeros(y.shape))
    return x

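The loop in _inverse runs for y.shape[-1] iterations because an autoregressive transform can only be inverted one coordinate per pass: x[i] depends on y[i] and on the already-recovered x[:i]. Below is a minimal sketch of that fixed-point argument with a toy strictly-autoregressive function standing in for self.arn; toy_arn, forward, and inverse are illustrative names, not NumPyro APIs.

import jax.numpy as jnp
from jax import lax


def toy_arn(x):
    # mean[i] depends only on x[:i], so the affine transform below is strictly autoregressive
    mean = jnp.concatenate([jnp.zeros(1), jnp.cumsum(x)[:-1]])
    log_scale = jnp.zeros_like(x)
    return mean, log_scale


def forward(x):
    mean, log_scale = toy_arn(x)
    return mean + jnp.exp(log_scale) * x


def inverse(y):
    # after k passes the first k coordinates of x are exact, so d passes suffice
    def update_x(i, x):
        mean, log_scale = toy_arn(x)
        return (y - mean) * jnp.exp(-log_scale)

    return lax.fori_loop(0, y.shape[-1], update_x, jnp.zeros_like(y))


x = jnp.array([0.5, -1.0, 2.0])
assert jnp.allclose(inverse(forward(x)), x, atol=1e-5)
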
def test_logistic_regression(auto_class):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, constrain_fn = svi_init(rng_init,
                                       model_args=(data, labels),
                                       guide_args=(data, labels))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_,
                                            model_args=(data, labels),
                                            guide_args=(data, labels))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 1000, body_fn, (opt_state, rng_train))

    if auto_class is not AutoIAFNormal:
        median = guide.median(opt_state)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantile method
        median = guide.quantiles(opt_state, [0.2, 0.5])
        assert_allclose(median['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), opt_state, sample_shape=(1000,))
    assert_allclose(np.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)

def test_dynamic_constraints():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        # NB: model's constraints will play no effect
        loc = param('loc', 0., constraint=constraints.interval(0, 0.5))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    def guide():
        alpha = param('alpha', 0.5, constraint=constraints.unit_interval)
        param('loc', 0, constraint=constraints.interval(0, alpha))

    opt_init, opt_update, get_params = optimizers.adam(0.05)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    rng_init, rng_train = random.split(random.PRNGKey(1))
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_, model_args=(data,))
        return opt_state_, rng_

    opt_state, rng = fori_loop(0, 300, body_fn, (opt_state, rng_train))
    params = get_param(opt_state, model, guide, get_params, constrain_fn, rng, guide_args=())
    assert_allclose(params['loc'], true_coef, atol=0.05)

def sample(self, state, model_args, model_kwargs):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    num_discretes = self._support_sizes_flat.shape[0]

    def potential_fn(z_gibbs, z_hmc):
        return self.inner_kernel._potential_fn_gen(
            *model_args, _gibbs_sites=z_gibbs, **model_kwargs)(z_hmc)

    def update_discrete(idx, rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum):
        # Algo 1, line 19: get a new discrete proposal
        rng_key, z_discrete_new, pe_new, log_accept_ratio = self._discrete_proposal_fn(
            rng_key,
            z_discrete,
            hmc_state.potential_energy,
            partial(potential_fn, z_hmc=hmc_state.z),
            idx,
            self._support_sizes_flat[idx],
        )
        # Algo 1, line 20: depending on reject or refract, we will update
        # the discrete variable and its corresponding kinetic energy. In case of
        # refract, we will need to update the potential energy and its grad w.r.t. hmc_state.z
        ke_discrete_i_new = ke_discrete[idx] + log_accept_ratio
        grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
        z_discrete, pe, ke_discrete_i, z_grad = lax.cond(
            ke_discrete_i_new > 0,
            (z_discrete_new, pe_new, ke_discrete_i_new),
            lambda vals: vals + (grad_(partial(potential_fn, vals[0]))(hmc_state.z),),
            (z_discrete, hmc_state.potential_energy, ke_discrete[idx], hmc_state.z_grad),
            identity,
        )

        delta_pe_sum = delta_pe_sum + pe - hmc_state.potential_energy
        ke_discrete = ops.index_update(ke_discrete, idx, ke_discrete_i)
        hmc_state = hmc_state._replace(potential_energy=pe, z_grad=z_grad)
        return rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum

    def update_continuous(hmc_state, z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        hmc_state_new = self.inner_kernel.sample(hmc_state, model_args, model_kwargs_)

        # each time a sub-trajectory is performed, we need to reset i and adapt_state
        # (we will only update them at the end of HMCGibbs step)
        # For `num_steps`, we will record its cumulative sum for diagnostics
        hmc_state = hmc_state_new._replace(
            i=hmc_state.i,
            adapt_state=hmc_state.adapt_state,
            num_steps=hmc_state.num_steps + hmc_state_new.num_steps,
        )
        return hmc_state

    def body_fn(i, vals):
        rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum, arrival_times = vals
        idx = jnp.argmin(arrival_times)
        # NB: length of each sub-trajectory is scaled from the current min(arrival_times)
        # (see the note at total_time below)
        trajectory_length = arrival_times[idx] * time_unit
        arrival_times = arrival_times - arrival_times[idx]
        arrival_times = ops.index_update(arrival_times, idx, 1.0)

        # this is a trick, so that in a sub-trajectory of HMC, we always accept the new proposal
        pe = jnp.inf
        hmc_state = hmc_state._replace(trajectory_length=trajectory_length, potential_energy=pe)
        # Algo 1, line 7: perform a sub-trajectory
        hmc_state = update_continuous(hmc_state, z_discrete)
        # Algo 1, line 8: perform a discrete update
        rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum = update_discrete(
            idx, rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum)
        return rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum, arrival_times

    z_discrete = {k: v for k, v in state.z.items() if k not in state.hmc_state.z}
    rng_key, rng_ke, rng_time, rng_r, rng_accept = random.split(state.rng_key, 5)
    # Algo 1, line 2: sample discrete kinetic energy
    ke_discrete = random.exponential(rng_ke, (num_discretes,))
    # Algo 1, line 4 and 5: sample the initial amount of time that each discrete site visits
    # the point 0/1. The logic in GetStepSizesNSteps(...) is more complicated but does
    # the same job: the sub-trajectory length eta_t * M_t is the lag between two arrival time.
    arrival_times = random.uniform(rng_time, (num_discretes,))
    # compute the amount of time to make `num_discrete_updates` discrete updates
    total_time = (self._num_discrete_updates - 1) // num_discretes + jnp.sort(arrival_times)[
        (self._num_discrete_updates - 1) % num_discretes]
    # NB: total_time can be different from the HMC trajectory length, so we need to scale
    # the time unit so that total_time * time_unit = hmc_trajectory_length
    time_unit = state.hmc_state.trajectory_length / total_time

    # Algo 1, line 2: sample hmc momentum
    r = momentum_generator(state.hmc_state.r,
                           state.hmc_state.adapt_state.mass_matrix_sqrt, rng_r)
    hmc_state = state.hmc_state._replace(r=r, num_steps=0)
    hmc_ke = euclidean_kinetic_energy(hmc_state.adapt_state.inverse_mass_matrix, r)
    # Algo 1, line 10: compute the initial energy
    energy_old = hmc_ke + hmc_state.potential_energy

    # Algo 1, line 3: set initial values
    delta_pe_sum = 0.0
    init_val = (rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum, arrival_times)
    # Algo 1, line 6-9: perform the update loop
    rng_key, hmc_state_new, z_discrete_new, _, delta_pe_sum, _ = fori_loop(
        0, self._num_discrete_updates, body_fn, init_val)
    # Algo 1, line 10: compute the proposal energy
    hmc_ke = euclidean_kinetic_energy(hmc_state.adapt_state.inverse_mass_matrix,
                                      hmc_state_new.r)
    energy_new = hmc_ke + hmc_state_new.potential_energy
    # Algo 1, line 11: perform MH correction
    delta_energy = energy_new - energy_old - delta_pe_sum
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)

    # record the correct new num_steps
    hmc_state = hmc_state._replace(num_steps=hmc_state_new.num_steps)
    # reset the trajectory length
    hmc_state_new = hmc_state_new._replace(trajectory_length=hmc_state.trajectory_length)
    hmc_state, z_discrete = cond(
        random.bernoulli(rng_key, accept_prob),
        (hmc_state_new, z_discrete_new), identity,
        (hmc_state, z_discrete), identity,
    )

    # perform hmc adapting (similar to the implementation in hmc)
    adapt_state = cond(
        hmc_state.i < self._num_warmup,
        (hmc_state.i, accept_prob, (hmc_state.z,), hmc_state.adapt_state),
        lambda args: self._wa_update(*args),
        hmc_state.adapt_state,
        identity,
    )

    itr = hmc_state.i + 1
    n = jnp.where(hmc_state.i < self._num_warmup, itr, itr - self._num_warmup)
    mean_accept_prob_prev = state.hmc_state.mean_accept_prob
    mean_accept_prob = mean_accept_prob_prev + (accept_prob - mean_accept_prob_prev) / n
    hmc_state = hmc_state._replace(
        i=itr,
        accept_prob=accept_prob,
        mean_accept_prob=mean_accept_prob,
        adapt_state=adapt_state,
    )

    z = {**z_discrete, **hmc_state.z}
    return MixedHMCState(z, hmc_state, rng_key, accept_prob)

def init_kernel(init_params,
                num_warmup,
                step_size=1.0,
                adapt_step_size=True,
                adapt_mass_matrix=True,
                dense_mass=False,
                target_accept_prob=0.8,
                trajectory_length=2 * math.pi,
                max_tree_depth=10,
                run_warmup=True,
                progbar=True,
                rng=PRNGKey(0)):
    """
    Initializes the HMC sampler.

    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param int num_warmup: Number of warmup steps; samples generated
        during warmup are discarded.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
        during warm-up phase using Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
        matrix during warm-up phase using Welford scheme.
    :param bool dense_mass: A flag to decide if mass matrix is dense or
        diagonal (default when ``dense_mass=False``)
    :param float target_accept_prob: Target acceptance probability for step size
        adaptation using Dual Averaging. Increasing this value will lead to a
        smaller step size, hence the sampling will be slower but more robust.
        Default to 0.8.
    :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
        value is :math:`2\\pi`.
    :param int max_tree_depth: Max depth of the binary tree created during the
        doubling scheme of NUTS sampler. Defaults to 10.
    :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
        `init_kernel` returns an initial :data:`HMCState` that can be used to
        generate samples using MCMC. Else, returns the arguments and callable
        that does the initial adaptation.
    :param bool progbar: Whether to enable progress bar updates. Defaults to
        ``True``.
    :param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
        step size is done at the beginning of each adaptation window to achieve
        `target_acceptance_prob`.
    :param jax.random.PRNGKey rng: random key to be used as the source of
        randomness.
    """
    step_size = float(step_size)
    nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
    wa_steps = num_warmup
    trajectory_len = float(trajectory_length)
    max_treedepth = max_tree_depth
    z = init_params
    z_flat, unravel_fn = ravel_pytree(z)
    momentum_generator = partial(_sample_momentum, unravel_fn)

    find_reasonable_ss = partial(find_reasonable_step_size,
                                 potential_fn, kinetic_fn, momentum_generator)

    wa_init, wa_update = warmup_adapter(num_warmup,
                                        adapt_step_size=adapt_step_size,
                                        adapt_mass_matrix=adapt_mass_matrix,
                                        dense_mass=dense_mass,
                                        target_accept_prob=target_accept_prob,
                                        find_reasonable_step_size=find_reasonable_ss)

    rng_hmc, rng_wa = random.split(rng)
    wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
    r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
    vv_state = vv_init(z, r)
    hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy,
                         0, 0., 0., wa_state, rng_hmc)

    if run_warmup and num_warmup > 0:
        # JIT if progress bar updates not required
        if not progbar:
            hmc_state = fori_loop(0, num_warmup,
                                  lambda *args: sample_kernel(args[1]), hmc_state)
        else:
            with tqdm.trange(num_warmup, desc='warmup') as t:
                for i in t:
                    hmc_state = sample_kernel(hmc_state)
                    t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)

    return hmc_state