def test_dynamic_supports():
    """Check that TransformReparam over AffineTransform(0, alpha) is equivalent
    to sampling Uniform(0, 1) and scaling by alpha explicitly.

    Both models get an AutoDiagonalNormal guide initialized from the same RNG
    key; optimizer state, params, medians, and ELBO losses must all agree.
    """
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        # `loc` has an alpha-dependent support, expressed via a
        # TransformedDistribution and reparameterized to its base sample.
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)
                ),
            )
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        # Same model written without reparam: sample the base uniform and
        # apply the scaling by hand.
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        loc = numpyro.sample("loc", dist.Uniform(0, 1)) * alpha
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)

    guide = AutoDiagonalNormal(actual_model)
    svi = SVI(actual_model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data)
    actual_opt_params = adam.get_params(svi_state.optim_state)
    actual_params = svi.get_params(svi_state)
    actual_values = guide.median(actual_params)
    actual_loss = svi.evaluate(svi_state, data)

    guide = AutoDiagonalNormal(expected_model)
    svi = SVI(expected_model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data)
    expected_opt_params = adam.get_params(svi_state.optim_state)
    expected_params = svi.get_params(svi_state)
    expected_values = guide.median(expected_params)
    expected_loss = svi.evaluate(svi_state, data)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)

    # test latent values; the reparameterized model exposes the base draw as
    # "loc_base", which corresponds to the hand-scaled model's "loc".
    assert_allclose(actual_values["alpha"], expected_values["alpha"])
    assert_allclose(actual_values["loc_base"], expected_values["loc"])
    assert_allclose(actual_loss, expected_loss)
def test_uniform_normal():
    """SVI with an AutoDiagonalNormal guide recovers the true location."""
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    state = svi.init(rng_key_init, data)

    def step(_, current_state):
        next_state, _ = svi.update(current_state, data)
        return next_state

    state = fori_loop(0, 1000, step, state)
    params = svi.get_params(state)
    assert_allclose(guide.median(params)['loc'], true_coef, rtol=0.05)

    # test .quantile method
    quantiles = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
def test_beta_bernoulli(elbo):
    """Beta-Bernoulli: the fitted guide mean matches the conjugate posterior."""
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optax.adam(0.05)
    svi = SVI(model, guide, adam, elbo)
    state = svi.init(random.PRNGKey(1), data)
    # The raw optimizer value behind alpha_q starts at 0 (init 1.0 under the
    # positive constraint — presumably log(1.0); confirm against optim API).
    assert_allclose(svi.optim.get_params(state.optim_state)["alpha_q"], 0.0)

    def step(_, s):
        s, _ = svi.update(s, data)
        return s

    state = fori_loop(0, 2000, step, state)
    params = svi.get_params(state)
    posterior_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
    assert_allclose(posterior_mean, 0.8, atol=0.05, rtol=0.05)
def test_beta_bernoulli(auto_class):
    """Autoguide on a two-column Beta-Bernoulli model recovers both means."""
    data = np.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(np.ones(2), np.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    state = svi.init(random.PRNGKey(1), data)

    def step(_, s):
        s, _ = svi.update(s, data)
        return s

    state = fori_loop(0, 3000, step, state)
    params = svi.get_params(state)
    # Conjugate Beta(1, 1) posterior mean: (successes + 1) / (trials + 2).
    true_coefs = (np.sum(data, axis=0) + 1) / (data.shape[0] + 2)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000, ))
    assert_allclose(np.mean(posterior_samples['beta'], 0), true_coefs, atol=0.04)
def test_uniform_normal():
    """Fit a reparameterized alpha-dependent-support model with SVI and check
    that the guide's median for `loc` recovers the true coefficient."""
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        # `loc` has an alpha-dependent support, so it is expressed as a
        # transformed Uniform(0, 1) and reparameterized to its base sample.
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)
                ),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    median = guide.median(params)
    assert_allclose(median["loc"], true_coef, rtol=0.05)
    # test .quantile method
    median = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(median["loc"][1], true_coef, rtol=0.1)
def test_predictive_with_guide():
    """Predictive driven by a trained guide reproduces the posterior mean."""
    data = jnp.array([1] * 8 + [0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        with numpyro.plate("plate", 10):
            numpyro.deterministic("beta_sq", f**2)
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO())
    state = svi.init(random.PRNGKey(1), data)

    def step(_, s):
        s, _ = svi.update(s, data)
        return s

    state = lax.fori_loop(0, 1000, step, state)
    params = svi.get_params(state)
    sampler = Predictive(model, guide=guide, params=params, num_samples=1000)
    predictive = sampler(random.PRNGKey(2), data=None)
    assert predictive["beta_sq"].shape == (1000, )
    obs_pred = predictive["obs"].astype(np.float32)
    assert_allclose(jnp.mean(obs_pred), 0.8, atol=0.05)
def test_beta_bernoulli(auto_class):
    """Autoguide recovers the posterior means of a two-column Beta-Bernoulli
    model; Predictive works both from posterior samples and from guide+params."""
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    # Conjugate Beta(1, 1) posterior mean: (successes + 1) / (trials + 2).
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    assert_allclose(jnp.mean(posterior_samples["beta"], 0), true_coefs, atol=0.05)

    # Predictive can be instantiated from posterior samples...
    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, 2)

    # ... or from the guide + params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, 2)
def main(args):
    """Fit a simple conjugate model with SVI and check the learned location.

    Relies on module-level `model` and `guide` definitions (not visible here).
    """
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, ELBO(num_particles=100))
    svi_state = svi.init(PRNGKey(0), data)

    # Training loop
    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, args.num_steps, body_fn, svi_state)

    # Report the final values of the variational parameters
    # in the guide after training.
    params = svi.get_params(svi_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1
def run_inference(model, inputs, method=None):
    """Run Bayesian inference on `model`.

    Args:
        model: numpyro model callable.
        inputs: dict of keyword arguments forwarded to the model.
        method: None selects NUTS sampling; any other value selects SVI with
            an AutoDiagonalNormal guide.

    Returns:
        dict of posterior samples keyed by site name.
    """
    if method is None:
        # NUTS
        num_samples = 5000
        logger.info('NUTS sampling')
        kernel = NUTS(model)
        mcmc = MCMC(kernel, num_warmup=300, num_samples=num_samples)
        rng_key = random.PRNGKey(0)
        mcmc.run(rng_key, **inputs, extra_fields=('potential_energy', ))
        logger.info(r'MCMC summary for: {}'.format(model.__name__))
        mcmc.print_summary(exclude_deterministic=False)
        samples = mcmc.get_samples()
    else:  # SVI
        logger.info('Guide generation...')
        rng_key = random.PRNGKey(0)
        guide = AutoDiagonalNormal(model=model)
        logger.info('Optimizer generation...')
        optim = Adam(0.05)
        logger.info('SVI generation...')
        svi = SVI(model, guide, optim, AutoContinuousELBO(), **inputs)
        init_state = svi.init(rng_key)
        logger.info('Scan...')
        # 2000 update steps; the zeros array only fixes the scan length.
        state, loss = lax.scan(lambda x, i: svi.update(x),
                               init_state, np.zeros(2000))
        params = svi.get_params(state)
        samples = guide.sample_posterior(random.PRNGKey(1), params, (1000, ))
        logger.info(r'SVI summary for: {}'.format(model.__name__))
        numpyro.diagnostics.print_summary(samples, prob=0.90,
                                          group_by_chain=False)
    return samples
def test_logistic_regression(auto_class):
    """Fit a logistic regression with an autoguide and check that median,
    quantiles, and posterior samples recover the true coefficients."""
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs',
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data, labels)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    # Flow-based guides are skipped below — presumably they do not implement
    # .median / .quantiles; confirm against the autoguide API.
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantile method
        median = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(median['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0), true_coefs,
                    rtol=0.1)
def test_jitted_update_fn():
    """svi.update yields the same params whether or not it is jit-compiled."""
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optim.Adam(0.05)
    svi = SVI(model, guide, adam, Trace_ELBO())
    state = svi.init(random.PRNGKey(1), data)

    # One eager step vs. one jitted step from the same state.
    expected = svi.get_params(svi.update(state, data)[0])
    actual = svi.get_params(jit(svi.update)(state, data=data)[0])
    check_close(actual, expected, atol=1e-5)
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
    """Rebuild the AutoIAFNormal flow by hand from the trained parameters and
    verify it matches the guide's `sample_posterior` / `get_transform`."""
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs",
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample("offset", dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    # Latent dimension is dim coefs + 1 offset.
    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    # Manually compose the same flow: IAF transforms with reversing
    # permutations between them, finished by the latent unpacking.
    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(
                jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1,
            [dim + 1, dim + 1],
            permutation=jnp.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity,
        )
        # Bind the trained parameters of the i-th autoregressive network.
        arn = partial(arn_apply, params["auto_arn__{}$params".format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)
    transform = transforms.ComposeTransform(flows)

    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample["coefs"], expected_sample["coefs"])
    # `offset` is constrained to (-1, 1): the manual sample must be pushed
    # through biject_to before comparison.
    assert_allclose(
        actual_sample["offset"],
        transforms.biject_to(constraints.interval(-1, 1))(
            expected_sample["offset"]),
    )
    check_eq(actual_output, expected_output)
def test_neutra_reparam_unobserved_model():
    """Smoke test: a NeuTra-reparameterized model still runs with data=None."""
    model = dirichlet_categorical
    data = jnp.ones(10, dtype=jnp.int32)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
    params = svi.get_params(svi.init(random.PRNGKey(0), data))
    reparam_model = NeuTraReparam(guide, params).reparam(model)
    with handlers.seed(rng_seed=0):
        reparam_model(data=None)
def test_dynamic_supports():
    """A latent with alpha-dependent support Uniform(0, alpha) must produce the
    same SVI state, loss, and (rescaled) medians as the equivalent model that
    samples Uniform(0, 1) and multiplies by alpha."""
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def actual_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, 1)) * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)

    guide = AutoDiagonalNormal(actual_model)
    svi = SVI(actual_model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(rng_key_init, data)
    actual_opt_params = adam.get_params(svi_state.optim_state)
    actual_params = svi.get_params(svi_state)
    actual_values = guide.median(actual_params)
    actual_loss = svi.evaluate(svi_state, data)

    guide = AutoDiagonalNormal(expected_model)
    svi = SVI(expected_model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(rng_key_init, data)
    expected_opt_params = adam.get_params(svi_state.optim_state)
    expected_params = svi.get_params(svi_state)
    expected_values = guide.median(expected_params)
    expected_loss = svi.evaluate(svi_state, data)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)
    # test latent values; the actual model's `loc` lives on the (0, alpha)
    # scale, hence the multiplication on the expected side.
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc'],
                    expected_values['alpha'] * expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
def svi(model, guide, num_steps, lr, rng_key, X, Y):
    """Helper function for doing SVI inference."""
    engine = SVI(model, guide, optim.Adam(lr), ELBO(num_particles=1), X=X, Y=Y)
    initial_state = engine.init(rng_key)
    print('Optimizing...')

    def step(state, _):
        return engine.update(state)

    final_state, losses = lax.scan(step, initial_state, np.zeros(num_steps))
    return losses, engine.get_params(final_state)
def test_logistic_regression(auto_class, Elbo):
    """Fit logistic regression with each autoguide/ELBO combination; also smoke
    test that the analytic-KL TraceMeanField_ELBO differs from Trace_ELBO."""
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs",
                               dist.Normal(0, 1).expand([dim]).to_event())
        logits = numpyro.deterministic("logits",
                                       jnp.sum(coefs * data, axis=-1))
        with numpyro.plate("N", len(data)):
            return numpyro.sample("obs", dist.Bernoulli(logits=logits),
                                  obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # smoke test if analytic KL is used
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()  # restore before training
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    # .median is skipped for these guide families — presumably unsupported;
    # confirm against the autoguide API.
    if auto_class not in (AutoDAIS, AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median["coefs"], true_coefs, rtol=0.1)
        # test .quantile method
        if auto_class is not AutoDelta:
            median = guide.quantiles(params, [0.2, 0.5])
            assert_allclose(median["coefs"][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000, ))
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(posterior_samples["coefs"], 0), expected_coefs,
                    rtol=0.1)
def test_beta_bernoulli(auto_class):
    """Two-column Beta-Bernoulli model inside a plate: check sample_posterior,
    quantiles (where supported), and Predictive in both construction modes."""
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T
    N = len(data)

    def model(data):
        f = numpyro.sample("beta",
                           dist.Beta(jnp.ones(2), jnp.ones(2)).to_event())
        with numpyro.plate("N", N):
            numpyro.sample("obs", dist.Bernoulli(f).to_event(1), obs=data)

    adam = optim.Adam(0.01)
    if auto_class == AutoDAIS:
        # AutoDAIS additionally exercises the "cholesky" base distribution.
        guide = auto_class(model, init_loc_fn=init_strategy,
                           base_dist="cholesky")
    else:
        guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    # Conjugate Beta(1, 1) posterior mean: (successes + 1) / (trials + 2).
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000, ))
    posterior_mean = jnp.mean(posterior_samples["beta"], 0)
    assert_allclose(posterior_mean, true_coefs, atol=0.05)

    # Quantiles are only checked for guide families that support them.
    if auto_class not in [AutoDAIS, AutoDelta, AutoIAFNormal, AutoBNAFNormal]:
        quantiles = guide.quantiles(params, [0.2, 0.5, 0.8])
        assert quantiles["beta"].shape == (3, 2)

    # Predictive can be instantiated from posterior samples...
    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)

    # ... or from the guide + params
    predictive = Predictive(model, guide=guide, params=params,
                            num_samples=1000)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)
def test_reparam_log_joint(model, kwargs):
    """NeuTra reparam must preserve the log joint up to the flow's
    log-det-Jacobian: pe_neutra(z) == pe(T(z)) - log|det dT/dz|."""
    guide = AutoIAFNormal(model)
    # The optimizer is a placeholder — only svi.init is called, no updates.
    svi = SVI(model, guide, Adam(1e-10), Trace_ELBO(), **kwargs)
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    reparam_model = neutra.reparam(model)
    _, pe_fn, _, _ = initialize_model(random.PRNGKey(1), model,
                                      model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(
        random.PRNGKey(2), reparam_model, model_kwargs=kwargs)
    # The reparameterized model has a single shared latent site.
    latent_x = list(init_params[0].values())[0]
    pe_transformed = pe_fn_neutra(init_params[0])
    latent_y = neutra.transform(latent_x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(latent_x,
                                                             latent_y)
    pe = pe_fn(guide._unpack_latent(latent_y))
    assert_allclose(pe_transformed, pe - log_det_jacobian)
def test_collapse_beta_binomial():
    """handlers.collapse marginalizes Beta out of a Beta-Binomial model; the
    collapsed model must trace and train identically to the explicit
    BetaBinomial formulation."""
    total_count = 10
    data = 3.

    def model1():
        c1 = numpyro.param("c1", 0.5, constraint=dist.constraints.positive)
        c0 = numpyro.param("c0", 1.5, constraint=dist.constraints.positive)
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c1, c0))
            numpyro.sample("obs", dist.Binomial(total_count, probs), obs=data)

    def model2():
        c1 = numpyro.param("c1", 0.5, constraint=dist.constraints.positive)
        c0 = numpyro.param("c0", 1.5, constraint=dist.constraints.positive)
        numpyro.sample("obs", dist.BetaBinomial(c1, c0, total_count), obs=data)

    # The collapsed trace keeps the latent site but drops its obs site,
    # while the explicit model exposes only the obs site.
    trace1 = handlers.trace(model1).get_trace()
    trace2 = handlers.trace(model2).get_trace()
    assert "probs" in trace1
    assert "obs" not in trace1
    assert "probs" not in trace2
    assert "obs" in trace2

    svi1 = SVI(model1, lambda: None, numpyro.optim.Adam(1), Trace_ELBO())
    svi2 = SVI(model2, lambda: None, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state1 = svi1.init(random.PRNGKey(0))
    svi_state2 = svi2.init(random.PRNGKey(0))
    params1 = svi1.get_params(svi_state1)
    params2 = svi2.get_params(svi_state2)
    assert_allclose(params1["c1"], params2["c1"])
    assert_allclose(params1["c0"], params2["c0"])

    # One optimization step must also yield identical parameters.
    params1 = svi1.get_params(svi1.update(svi_state1)[0])
    params2 = svi2.get_params(svi2.update(svi_state2)[0])
    assert_allclose(params1["c1"], params2["c1"])
    assert_allclose(params1["c0"], params2["c0"])
def test_laplace_approximation_warning():
    """AutoLaplaceApproximation warns when the posterior Hessian is bad."""
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    state = svi.init(random.PRNGKey(0))
    state = fori_loop(0, 10000, lambda _, s: svi.update(s)[0], state)
    fitted_params = svi.get_params(state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), fitted_params)
def run_svi_inference(model, guide, rng_key, X, Y, optimizer, n_epochs=1_000):
    """Run SVI for `n_epochs` steps and return the fitted guide parameters.

    Args:
        model: numpyro model callable taking (X, y).
        guide: matching numpyro guide.
        rng_key: JAX PRNG key for initialization.
        X: input features.
        Y: targets; squeezed to drop a trailing singleton dimension.
        optimizer: numpyro/optax optimizer for SVI.
        n_epochs: number of update steps (default 1000).

    Returns:
        dict of fitted guide parameters.
    """
    # Squeeze once instead of on every update step.
    targets = Y.squeeze()
    # initialize svi
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    # initialize state
    init_state = svi.init(rng_key, X, targets)
    # BUG FIX: `lax.scan` cannot iterate over a bare int (`n_epochs`) — the
    # scanned-over `xs` needs a leading axis. Use xs=None with an explicit
    # `length` to run the requested number of update steps.
    state, losses = jax.lax.scan(
        lambda state, _: svi.update(state, X, targets),
        init_state,
        None,
        length=n_epochs,
    )
    # Extract surrogate posterior.
    params = svi.get_params(state)
    return params
def test_autoguide(deterministic):
    """Train briefly, then check the call count recorded in GLOBAL after
    sample_posterior — it differs by one depending on `deterministic`."""
    GLOBAL["count"] = 0
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(),
              deterministic=deterministic)
    state = svi.init(random.PRNGKey(0))
    state = lax.fori_loop(0, 100, lambda _, s: svi.update(s)[0], state)
    params = svi.get_params(state)
    guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(100, ))
    expected_count = 5 if deterministic else 4
    assert GLOBAL["count"] == expected_count
def test_get_params(kernel, auto_guide, init_loc_fn, problem):
    """SteinVI and plain SVI expose the same parameter names; Stein params
    carry an extra leading particle dimension."""
    _, data, model = problem()
    guide = auto_guide(model, init_loc_fn=init_loc_fn)
    optim = Adam(1e-1)
    elbo = Trace_ELBO()

    stein = SteinVI(model, guide, optim, elbo, kernel)
    stein_params = stein.get_params(stein.init(random.PRNGKey(0), *data))

    svi = SVI(model, guide, optim, elbo)
    svi_params = svi.get_params(svi.init(random.PRNGKey(0), *data))

    assert svi_params.keys() == stein_params.keys()
    for name, svi_param in svi_params.items():
        tiled = np.repeat(svi_param[None, ...], stein.num_particles, axis=0)
        assert stein_params[name].shape == tiled.shape
def test_param():
    # this test the validity of model having
    # param sites contain composed transformed constraints
    """Param sites with composed transformed constraints must round-trip their
    init values, and the ELBO must match the hand-computed value."""
    rng_keys = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    x_init = random.normal(rng_keys[2])

    def model():
        a = numpyro.param("a", a_init,
                          constraint=constraints.greater_than(a_minval))
        b = numpyro.param("b", b_init, constraint=constraints.positive)
        numpyro.sample("x", dist.Normal(a, b))

    # this class is used to force init value of `x` to x_init
    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {"_auto_latent": x_init[None]})(*args, **kwargs)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init)

    params = svi.get_params(svi_state)
    # Init values must survive the constrain/unconstrain round trip.
    assert_allclose(params["a"], a_init, rtol=1e-6)
    assert_allclose(params["b"], b_init, rtol=1e-6)
    assert_allclose(params["auto_loc"], guide._init_latent, rtol=1e-6)
    assert_allclose(params["auto_scale"],
                    jnp.ones(1) * guide._init_scale, rtol=1e-6)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    # ELBO at init: log q(x_init) under the guide minus log p(x_init | a, b).
    expected_loss = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_init) - dist.Normal(
        a_init, b_init).log_prob(x_init)
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
def test_param():
    # this test the validity of model/guide sites having
    # param constraints contain composed transformed
    """Params in both model and guide with composed transformed constraints
    must keep their init values, and the ELBO must match the closed form."""
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        a = numpyro.param('a', a_init,
                          constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init,
                          constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    adam = optim.Adam(0.01)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['c'], c_init)
    assert_allclose(params['d'], d_init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    # Both sites are observed, so the ELBO reduces to log q(obs) - log p(obs).
    expected_loss = dist.Normal(c_init, d_init).log_prob(obs) - dist.Normal(
        a_init, b_init).log_prob(obs)
    # not so precisely because we do transform / inverse transform stuffs
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
def test_neals_funnel_smoke():
    """End-to-end smoke test: train an IAF guide on Neal's funnel, run NUTS on
    the NeuTra-reparameterized model, and transform the samples back."""
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, Adam(1e-10), Trace_ELBO())
    state = svi.init(random.PRNGKey(0), dim)

    def step(_, s):
        s, _ = svi.update(s, dim)
        return s

    state = lax.fori_loop(0, 1000, step, state)
    params = svi.get_params(state)

    neutra = NeuTraReparam(guide, params)
    mcmc = MCMC(NUTS(neutra.reparam(neals_funnel)),
                num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    latent = mcmc.get_samples()['auto_shared_latent']
    transformed_samples = neutra.transform_sample(latent)
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
class ModelHandler(object):
    """Wraps a numpyro model/guide pair with SVI training, optimizer-state
    (de)serialization, and Predictive sampling."""

    def __init__(self, model: Model, guide: Guide, rng_key: int = 0, *,
                 loss: ELBO = ELBO(num_particles=1),
                 optim_builder: optim.optimizers.optimizer = optim.Adam):
        """Handling the model and guide for training and prediction

        Args:
            model: function holding the numpyro model
            guide: function holding the numpyro guide
            rng_key: random key as int
            loss: loss to optimize
            optim_builder: builder for an optimizer
        """
        self.model = model
        self.guide = guide
        self.rng_key = random.PRNGKey(rng_key)  # current random key
        self.loss = loss
        self.optim_builder = optim_builder
        self.svi = None
        self.svi_state = None
        self.optim = None
        self.log_func = print  # overwrite e.g. logger.info(...)

    def reset_svi(self):
        """Reset the current SVI state"""
        self.svi = None
        self.svi_state = None
        return self

    def init_svi(self, X: DeviceArray, *, lr: float, **kwargs):
        """Initialize the SVI state

        Args:
            X: input data
            lr: learning rate
            kwargs: other keyword arguments for optimizer
        """
        self.optim = self.optim_builder(lr, **kwargs)
        self.svi = SVI(self.model, self.guide, self.optim, self.loss)
        svi_state = self.svi.init(self.rng_key, X)
        # Keep an existing state (e.g. one loaded from disk) if present.
        if self.svi_state is None:
            self.svi_state = svi_state
        return self

    @property
    def optim_state(self) -> OptimizerState:
        """Current optimizer state"""
        assert self.svi_state is not None, "'init_svi' needs to be called first"
        return self.svi_state.optim_state

    @optim_state.setter
    def optim_state(self, state: OptimizerState):
        """Set current optimizer state"""
        self.svi_state = SVIState(state, self.rng_key)

    def dump_optim_state(self, fh: IO):
        """Pickle and dump optimizer state to file handle"""
        # Only element [1] (the packed optimizer payload) is serialized; the
        # step counter is reset to 0 on load.
        pickle.dump(
            optim.optimizers.unpack_optimizer_state(self.optim_state[1]), fh)
        return self

    def load_optim_state(self, fh: IO):
        """Read and unpickle optimizer state from file handle"""
        state = optim.optimizers.pack_optimizer_state(pickle.load(fh))
        iter0 = jnp.array(0)
        self.optim_state = (iter0, state)
        return self

    @property
    def optim_total_steps(self) -> int:
        """Returns the number of performed iterations in total"""
        return int(self.optim_state[0])

    def _fit(self, X: DeviceArray, n_epochs) -> float:
        """Run `n_epochs` jitted SVI updates; return the last per-sample loss."""
        @jit
        def train_epochs(svi_state, n_epochs):
            def train_one_epoch(_, val):
                loss, svi_state = val
                svi_state, loss = self.svi.update(svi_state, X)
                return loss, svi_state

            return lax.fori_loop(0, n_epochs, train_one_epoch, (0., svi_state))

        loss, self.svi_state = train_epochs(self.svi_state, n_epochs)
        return float(loss / X.shape[0])

    def _log(self, n_digits, epoch, loss):
        """Format and emit one epoch/loss line via `self.log_func`."""
        msg = f"epoch: {str(epoch).rjust(n_digits)} loss: {loss: 16.4f}"
        self.log_func(msg)

    def fit(self, X: DeviceArray, *, n_epochs: int, log_freq: int = 0,
            lr: float, **kwargs) -> float:
        """Train but log with a given frequency

        Args:
            X: input data
            n_epochs: total number of epochs
            log_freq: log loss every log_freq number of epochs
            lr: learning rate
            kwargs: parameters of `init_svi`

        Returns:
            final loss of last epoch
        """
        self.init_svi(X, lr=lr, **kwargs)
        if log_freq <= 0:
            self._fit(X, n_epochs)
        else:
            loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]
            curr_epoch = 0
            n_digits = len(str(abs(n_epochs)))
            self._log(n_digits, curr_epoch, loss)
            # Train in log_freq-sized chunks, logging after each chunk.
            for i in range(n_epochs // log_freq):
                curr_epoch += log_freq
                loss = self._fit(X, log_freq)
                self._log(n_digits, curr_epoch, loss)
            rest = n_epochs % log_freq
            if rest > 0:
                curr_epoch += rest
                loss = self._fit(X, rest)
                self._log(n_digits, curr_epoch, loss)
        loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]
        # Carry the advanced RNG key forward for subsequent calls.
        self.rng_key = self.svi_state.rng_key
        return float(loss)

    @property
    def model_params(self) -> Optional[Dict[str, DeviceArray]]:
        """Gets model parameters

        Returns:
            dict of model parameters
        """
        if self.svi is not None:
            return self.svi.get_params(self.svi_state)
        else:
            return None

    def predict(self, X: DeviceArray, **kwargs) -> DeviceArray:
        """Predict the parameters of a model specified by `return_sites`

        Args:
            X: input data
            kwargs: keyword arguments for numpro `Predictive`

        Returns:
            samples for all sample sites
        """
        self.init_svi(X, lr=0.)  # dummy initialization
        predictive = Predictive(self.model, guide=self.guide,
                                params=self.model_params, **kwargs)
        samples = predictive(self.rng_key, X)
        return samples
def main(args):
    """Compare vanilla NUTS with NeuTra-reparameterized NUTS on the dual-moon
    model, training an AutoBNAFNormal guide via SVI, and save diagnostic
    plots to ``neutra.pdf``.

    Args:
        args: parsed CLI namespace; reads num_warmup, num_samples,
            num_chains, hidden_factor and num_iters.
    """
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    # Disable the progress bar when building docs under Sphinx.
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(
        dual_moon_model,
        hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), ELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    # svi.update returns (new_state, loss), which scan treats as (carry, y),
    # so `losses` collects the per-iteration ELBO losses.
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, jnp.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, params)
    neutra_model = neutra.reparam(dual_moon_model)
    nuts_kernel = NUTS(neutra_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(random.PRNGKey(3))
    mcmc.print_summary()
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples['x'].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(jnp.zeros(2), 1.).sample(
        random.PRNGKey(4), (1000, ))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0], guide_base_samples[:, 1],
                    ax=ax3, hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)'
    )

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1],
                n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1],
             'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples[:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
class SVIHandler(Handler):
    """Helper object that abstracts some of numpyro's complexities.

    Inspired by an implementation of Florian Wilhelm.

    :param model: A numpyro model.
    :param guide: A numpyro guide.
    :param loss: Loss function, defaults to Trace_ELBO.
    :param optimizer: Optimizer class, defaults to optim.ClippedAdam.
    :param lr: Learning rate, defaults to 0.001.
    :param lrd: Learning rate decay per step, defaults to 1.0 (no decay).
    :param rng_key: Random seed, defaults to 254.
    :param num_epochs: Number of epochs to train the model, defaults to 30000.
    :param num_samples: Number of posterior samples.
    :param log_func: Logging function, defaults to print.
    :param log_freq: Frequency of logging, defaults to 1000 (<= 0 disables).
    :param to_numpy: Convert the posterior distribution to numpy array(s),
        defaults to True.
    """

    def __init__(
        self,
        model: Model,
        guide: Guide,
        loss: Trace_ELBO = Trace_ELBO(num_particles=1),
        optimizer: optim.optimizers.optimizer = optim.ClippedAdam,
        lr: float = 0.001,
        lrd: float = 1.0,
        rng_key: int = 254,
        num_epochs: int = 30000,
        num_samples: int = 1000,
        log_func=_print_consumer,
        log_freq=1000,
        to_numpy: bool = True,
    ):
        self.model = model
        self.guide = guide
        # lrd**x implements per-step exponential learning-rate decay.
        self.optimizer = optimizer(step_size=lambda x: lr * lrd**x)
        self.rng_key = random.PRNGKey(rng_key)

        self.svi = SVI(self.model, self.guide, self.optimizer, loss=loss)
        self.init_state = None

        self.log_func = log_func
        self.log_freq = log_freq
        self.num_epochs = num_epochs
        self.num_samples = num_samples

        # NOTE: after construction `self.loss` is the *loss history*
        # (filled by `fit`), not the ELBO object passed above.
        self.loss = None
        # Fitted parameters; populated by `fit` or `load_params`.
        self.params = None
        self.to_numpy = to_numpy

    def _log(self, epoch, loss, n_digits=4):
        """Format one epoch/loss line and emit it via `log_func`."""
        msg = f"epoch: {str(epoch).rjust(n_digits)} loss: {loss: 16.4f}"
        self.log_func(msg)

    def _fit(self, *args):
        """Run `num_epochs` SVI updates with `lax.scan`.

        Logs progress through a host callback every `log_freq` steps.
        Returns the final SVI state and the per-epoch loss array.
        """
        def _step(state, i, *args):
            # host_callback.id_tap lets us log from inside the jitted scan;
            # it passes `state` through unchanged.
            state = lax.cond(
                i % self.log_freq == 0,
                lambda _: host_callback.id_tap(
                    self.log_func, (i, self.num_epochs), result=state),
                lambda _: state,
                operand=None,
            )
            return self.svi.update(state, *args)

        return lax.scan(
            lambda state, i: _step(state, i, *args),
            self.init_state,
            jnp.arange(self.num_epochs),
        )

    def _update_state(self, state, loss):
        """Store the final state and append `loss` to the loss history."""
        self.state = state
        # Keeping init_state lets a subsequent `fit` warm-start from here.
        self.init_state = state
        self.loss = loss if self.loss is None else jnp.concatenate(
            [self.loss, loss])

    def fit(self, *args, **kwargs):
        """Train the model and draw `num_samples` posterior samples.

        Positional args are forwarded to the model. Recognized kwargs:
        `num_epochs` (overrides the constructor value) and
        `predictive_kwargs` (forwarded to `Predictive`). Returns `self`.
        """
        self.num_epochs = kwargs.pop("num_epochs", self.num_epochs)
        predictive_kwargs = kwargs.pop("predictive_kwargs", {})

        if self.init_state is None:
            self.init_state = self.svi.init(self.rng_key, *args)

        state, loss = self._fit(*args)
        self._update_state(state, loss)
        self.params = self.svi.get_params(state)

        predictive = Predictive(
            self.model,
            guide=self.guide,
            params=self.params,
            num_samples=self.num_samples,
            **predictive_kwargs,
        )

        self.posterior = Posterior(predictive(self.rng_key, *args),
                                   self.to_numpy)
        return self

    def predict(self, *args, **kwargs):
        """kwargs -> Predictive, args -> predictive"""
        num_samples = kwargs.pop("num_samples", self.num_samples)
        rng_key = kwargs.pop("rng_key", self.rng_key)

        predictive = Predictive(
            self.model,
            guide=self.guide,
            params=self.params,
            num_samples=num_samples,
            **kwargs,
        )

        self.predictive = Posterior(predictive(rng_key, *args), self.to_numpy)

    def dump_params(self, file_name: str):
        """Pickle the fitted parameters to `file_name`.

        Raises AssertionError if the model has not been fitted yet.
        """
        assert self.params is not None, "'fit' needs to be called first"
        # Context manager guarantees the file handle is closed.
        with open(file_name, "wb") as fh:
            pickle.dump(self.params, fh)

    def load_params(self, file_name):
        """Load previously dumped parameters from `file_name`."""
        with open(file_name, "rb") as fh:
            self.params = pickle.load(fh)
def main(args):
    """Compare vanilla NUTS with NeuTra HMC on the dual-moon model using an
    AutoIAFNormal guide trained via SVI, and save diagnostic plots to
    ``neutra.pdf``.

    Args:
        args: parsed CLI namespace; reads num_warmup, num_samples,
            num_hidden and num_iters.
    """
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    adam = optim.Adam(0.01)
    # TODO: it is hard to find good hyperparameters such that IAF guide
    # can learn this model. We will use BNAF instead!
    guide = AutoIAFNormal(dual_moon_model, num_flows=2,
                          hidden_dims=[args.num_hidden, args.num_hidden])
    svi = SVI(dual_moon_model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    # svi.update returns (new_state, loss), which scan treats as (carry, y).
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    # NOTE(review): PRNGKey(0) was already consumed by mcmc.run above;
    # consider a fresh key here to avoid JAX key reuse.
    guide_samples = guide.sample_posterior(
        random.PRNGKey(0), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2),
                                                     dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy,
                                       potential_fn, transform)

    def transformed_constrain_fn(x):
        # Map a sample from the warped (unconstrained) space back to the
        # original (constrained) space.
        return constrain_fn(transform(x))

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    # print_summary expects a leading chain dimension.
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(np.zeros(2), 1.).sample(
        random.PRNGKey(4), (1000, ))
    guide_trans_samples = vmap(transformed_constrain_fn)(
        guide_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using AutoIAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0], guide_base_samples[:, 1],
                    ax=ax3, hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
        title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1],
                n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1],
             'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples[:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1',
            title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()