def test_laplace_approximation_warning():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    init_state = svi.init(random.PRNGKey(0))
    svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state)
    params = svi.get_params(svi_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)

def test_laplace_approximation_custom_hessian():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10))
        mu = a + b * x
        with numpyro.plate("N", len(x)):
            numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    x = random.normal(random.PRNGKey(0), (100,))
    y = 1 + 2 * x
    guide = AutoLaplaceApproximation(
        model, hessian_fn=lambda f, x: jacobian(jacobian(f))(x)
    )
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    svi_result = svi.run(random.PRNGKey(0), 10000, progress_bar=False)
    guide.get_transform(svi_result.params)

def test_logistic_regression(auto_class, Elbo):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # smoke test if analytic KL is used
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantiles method
        if auto_class is not AutoDelta:
            quantiles = guide.quantiles(params, [0.2, 0.5])
            assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)

def test_reparam_log_joint(model, kwargs):
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-10), Trace_ELBO(), **kwargs)
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    reparam_model = neutra.reparam(model)
    _, pe_fn, _, _ = initialize_model(random.PRNGKey(1), model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(
        random.PRNGKey(2), reparam_model, model_kwargs=kwargs
    )
    latent_x = list(init_params[0].values())[0]
    pe_transformed = pe_fn_neutra(init_params[0])
    latent_y = neutra.transform(latent_x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn(guide._unpack_latent(latent_y))
    assert_allclose(pe_transformed, pe - log_det_jacobian, rtol=2e-7)

def test_iaf():
    # test the substitute logic for the exposed methods `sample_posterior`
    # and `get_transform`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1,))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(dim + 1, [dim + 1, dim + 1],
                                               permutation=jnp.arange(dim + 1),
                                               skip_connections=guide._skip_connections,
                                               nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(actual_sample['offset'],
                    transforms.biject_to(constraints.interval(-1, 1))(
                        expected_sample['offset']))
    check_eq(actual_output, expected_output)

def test_collapse_beta_bernoulli():
    data = 0.

    def model():
        c = numpyro.sample("c", dist.Gamma(1, 1))
        # handlers.collapse marginalizes out the conjugate Beta-Bernoulli pair
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c, 2))
            numpyro.sample("obs", dist.Bernoulli(probs), obs=data)

    def guide():
        a = numpyro.param("a", 1., constraint=constraints.positive)
        b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(a, b))

    svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)

def test_improper():
    y = random.normal(random.PRNGKey(0), (100,))

    def model(y):
        lambda1 = numpyro.sample(
            'lambda1', dist.ImproperUniform(dist.constraints.real, (), ()))
        lambda2 = numpyro.sample(
            'lambda2', dist.ImproperUniform(dist.constraints.real, (), ()))
        sigma = numpyro.sample(
            'sigma', dist.ImproperUniform(dist.constraints.positive, (), ()))
        mu = numpyro.deterministic('mu', lambda1 + lambda2)
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), Trace_ELBO(), y=y)
    svi_state = svi.init(random.PRNGKey(2))
    lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(10000))

def run_svi_inference(model, guide, rng_key, X, Y, optimizer, n_epochs=1_000):
    # initialize svi
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    # initialize state
    init_state = svi.init(rng_key, X, Y.squeeze())
    # Run the optimizer for n_epochs iterations; scanning over xs=None with
    # length=n_epochs avoids materializing an index array.
    state, losses = jax.lax.scan(
        lambda state, i: svi.update(state, X, Y.squeeze()), init_state, None,
        length=n_epochs
    )
    # Extract the surrogate posterior parameters.
    params = svi.get_params(state)
    return params

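# A minimal usage sketch for `run_svi_inference`. The toy linear-regression
# model, the AutoNormal guide, and the optimizer choice below are illustrative
# assumptions made for this example, not part of the original helper.
from numpyro.infer.autoguide import AutoNormal


def _toy_model(X, Y=None):
    w = numpyro.sample("w", dist.Normal(0.0, 1.0))
    numpyro.sample("Y", dist.Normal(w * X, 0.1), obs=Y)


X_toy = jnp.linspace(-1.0, 1.0, 50)
Y_toy = 0.5 * X_toy + 0.05 * random.normal(random.PRNGKey(0), (50,))
toy_params = run_svi_inference(
    _toy_model, AutoNormal(_toy_model), random.PRNGKey(1), X_toy, Y_toy,
    optim.Adam(1e-2)
)
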
def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T
    N = len(data)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)).to_event())
        with numpyro.plate("N", N):
            numpyro.sample("obs", dist.Bernoulli(f).to_event(1), obs=data)

    adam = optim.Adam(0.01)
    if auto_class == AutoDAIS:
        guide = auto_class(model, init_loc_fn=init_strategy, base_dist="cholesky")
    else:
        guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    posterior_mean = jnp.mean(posterior_samples["beta"], 0)
    assert_allclose(posterior_mean, true_coefs, atol=0.05)

    if auto_class not in [AutoDAIS, AutoDelta, AutoIAFNormal, AutoBNAFNormal]:
        quantiles = guide.quantiles(params, [0.2, 0.5, 0.8])
        assert quantiles["beta"].shape == (3, 2)

    # Predictive can be instantiated from posterior samples...
    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)

    # ... or from the guide + params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)

def test_collapse_beta_binomial_plate():
    data = np.array([0., 1., 5., 5.])

    def model():
        c = numpyro.sample("c", dist.Gamma(1, 1))
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c, 2))
            with numpyro.plate("plate", len(data)):
                numpyro.sample("obs", dist.Binomial(10, probs), obs=data)

    def guide():
        a = numpyro.param("a", 1., constraint=constraints.positive)
        b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(a, b))

    svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)

def test_autoguide(deterministic):
    GLOBAL["count"] = 0
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(),
              deterministic=deterministic)
    svi_state = svi.init(random.PRNGKey(0))
    svi_state = lax.fori_loop(0, 100, lambda i, val: svi.update(val)[0], svi_state)
    params = svi.get_params(svi_state)
    guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(100,))

    if deterministic:
        assert GLOBAL["count"] == 5
    else:
        assert GLOBAL["count"] == 4

def test_run(progress_bar):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", lambda key: random.normal(key),
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", lambda key: random.exponential(key),
                               constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
    params, losses = svi.run(random.PRNGKey(1), 1000, data,
                             progress_bar=progress_bar)
    assert losses.shape == (1000,)
    assert_allclose(params['alpha_q'] / (params['alpha_q'] + params['beta_q']),
                    0.8, atol=0.05, rtol=0.05)

def test_auto_guide(auto_class, init_loc_fn, num_particles):
    latent_dim = 3

    def model(obs):
        a = numpyro.sample("a", Normal(0, 1))
        return numpyro.sample("obs", Bernoulli(logits=a), obs=obs)

    obs = Bernoulli(0.5).sample(random.PRNGKey(0), (10, latent_dim))

    rng_key = random.PRNGKey(0)
    guide_key, stein_key = random.split(rng_key)

    inner_guide = auto_class(model, init_loc_fn=init_loc_fn())

    with handlers.seed(rng_seed=guide_key), handlers.trace() as inner_guide_tr:
        inner_guide(obs)

    steinvi = SteinVI(
        model,
        auto_class(model, init_loc_fn=init_loc_fn()),
        Adam(1.0),
        Trace_ELBO(),
        RBFKernel(),
        num_particles=num_particles,
    )
    state = steinvi.init(stein_key, obs)
    init_params = steinvi.get_params(state)

    for name, site in inner_guide_tr.items():
        if site.get("type") == "param":
            assert name in init_params
            inner_param = site
            init_value = init_params[name]
            expected_shape = (num_particles, *np.shape(inner_param["value"]))
            assert init_value.shape == expected_shape
            if "auto_loc" in name or name == "b":
                # np.all replaces the deprecated np.alltrue
                assert np.all(init_value != np.zeros(expected_shape))
                assert np.unique(init_value).shape == init_value.reshape(-1).shape
            elif "scale" in name:
                assert_array_approx_equal(init_value, np.full(expected_shape, 0.1))
            else:
                assert_array_approx_equal(init_value, np.full(expected_shape, 0.0))

def test_param():
    # this tests the validity of a model whose param sites
    # contain composed transformed constraints
    rng_keys = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    x_init = random.normal(rng_keys[2])

    def model():
        a = numpyro.param("a", a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param("b", b_init, constraint=constraints.positive)
        numpyro.sample("x", dist.Normal(a, b))

    # this class is used to force the init value of `x` to x_init
    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {"_auto_latent": x_init[None]})(*args, **kwargs)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init)

    params = svi.get_params(svi_state)
    assert_allclose(params["a"], a_init, rtol=1e-6)
    assert_allclose(params["b"], b_init, rtol=1e-6)
    assert_allclose(params["auto_loc"], guide._init_latent, rtol=1e-6)
    assert_allclose(params["auto_scale"], jnp.ones(1) * guide._init_scale,
                    rtol=1e-6)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(guide._init_latent, guide._init_scale).log_prob(
        x_init) - dist.Normal(a_init, b_init).log_prob(x_init)
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)

def test_get_params(kernel, auto_guide, init_loc_fn, problem):
    _, data, model = problem()
    guide, optim, elbo = (
        auto_guide(model, init_loc_fn=init_loc_fn),
        Adam(1e-1),
        Trace_ELBO(),
    )

    stein = SteinVI(model, guide, optim, elbo, kernel)
    stein_params = stein.get_params(stein.init(random.PRNGKey(0), *data))

    svi = SVI(model, guide, optim, elbo)
    svi_params = svi.get_params(svi.init(random.PRNGKey(0), *data))
    assert svi_params.keys() == stein_params.keys()

    # each Stein parameter stacks the SVI parameter along a leading
    # particle dimension
    for name, svi_param in svi_params.items():
        assert (stein_params[name].shape
                == np.repeat(svi_param[None, ...], stein.num_particles, axis=0).shape)

def test_subsample_guide(auto_class):
    # The model is adapted from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        drift = numpyro.sample("drift", dist.LogNormal(-1, 0.5))
        with handlers.substitute(data={"data": subsample}):
            plate = numpyro.plate("data", full_size, subsample_size=len(subsample))
        assert plate.size == 50

        def transition_fn(z_prev, y_curr):
            with plate:
                z_curr = numpyro.sample("state", dist.Normal(z_prev, drift))
                y_curr = numpyro.sample("obs", dist.Bernoulli(logits=z_curr),
                                        obs=y_curr)
            return z_curr, y_curr

        _, result = scan(transition_fn, jnp.zeros(len(subsample)), batch,
                         length=num_time_steps)
        return result

    def create_plates(batch, subsample, full_size):
        with handlers.substitute(data={"data": subsample}):
            return numpyro.plate("data", full_size,
                                 subsample_size=subsample.shape[0])

    guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    with handlers.seed(rng_seed=0):
        data = model(None, jnp.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    svi = SVI(model, guide, optim.Adam(0.02), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), data[:, :batch_size],
                         jnp.arange(batch_size), full_size=full_size)
    update_fn = jit(svi.update, static_argnums=(3,))
    for epoch in range(2):
        beg = 0
        while beg < full_size:
            end = min(full_size, beg + batch_size)
            subsample = jnp.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi_state, loss = update_fn(svi_state, batch, subsample, full_size)

def train_model_no_dp(rng, model, guide, data, batch_size, num_data,
                      num_epochs, silent=False, **kwargs):
    """trains a given model using SVI (no DP!) and the globally defined
    parameters and data"""
    optimizer = Adam(1e-3)
    svi = SVI(model, guide, optimizer, Trace_ELBO(), num_obs_total=num_data)

    import d3p.random.debug
    return _train_model(
        d3p.random.convert_to_jax_rng_key(rng), d3p.random.debug,
        svi, data, batch_size, num_data, num_epochs, silent
    )

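# Hypothetical invocation sketch for `train_model_no_dp`; it assumes d3p is
# installed, that `model`, `guide`, and `data` are defined elsewhere in this
# file, and that d3p.random.PRNGKey creates the (non-JAX) rng key the helper
# converts via convert_to_jax_rng_key.
import d3p.random

dp_rng = d3p.random.PRNGKey(0)
_ = train_model_no_dp(dp_rng, model, guide, data,
                      batch_size=64, num_data=len(data), num_epochs=10)
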
def test_param():
    # this tests the validity of model/guide sites whose param
    # constraints contain composed transforms
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        a = numpyro.param("a", a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param("b", b_init, constraint=constraints.positive)
        numpyro.sample("x", dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param("c", c_init,
                          constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param("d", d_init, constraint=constraints.unit_interval)
        numpyro.sample("y", dist.Normal(c, d), obs=obs)

    adam = optim.Adam(0.01)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    params = svi.get_params(svi_state)
    assert_allclose(params["a"], a_init)
    assert_allclose(params["b"], b_init)
    assert_allclose(params["c"], c_init)
    assert_allclose(params["d"], d_init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(c_init, d_init).log_prob(obs) - dist.Normal(
        a_init, b_init).log_prob(obs)
    # not exact because of the transform / inverse-transform round trip
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)

def test_jitted_update_fn():
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optim.Adam(0.05)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)
    expected = svi.get_params(svi.update(svi_state, data)[0])

    actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0])
    check_close(actual, expected, atol=1e-5)

def test_calc_particle_info(num_params, num_particles):
    seed = random.PRNGKey(nrandom.randint(0, 10_000))
    sizes = Poisson(5).sample(seed, (100, nrandom.randint(0, 10))) + 1

    uparam = tuple(np.empty(tuple(size)) for size in sizes)
    uparams = {string.ascii_lowercase[i]: uparam for i in range(num_params)}

    par_param_size = sum(map(lambda size: size.prod(), sizes)) // num_particles
    expected_start_end = zip(
        par_param_size * np.arange(num_params),
        par_param_size * np.arange(1, num_params + 1),
    )
    expected_pinfo = dict(zip(string.ascii_lowercase[:num_params],
                              expected_start_end))

    stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel())
    pinfo = stein._calc_particle_info(uparams, num_particles)

    for k in pinfo.keys():
        assert pinfo[k] == expected_pinfo[k], f"Failed for seed {seed}"

def test_svi_discrete_latent():
    cont_inf_only_cls = [RenyiELBO(), Trace_ELBO(), TraceMeanField_ELBO()]
    mixed_inf_cls = [TraceGraph_ELBO()]

    assert not any([c.can_infer_discrete for c in cont_inf_only_cls])
    assert all([c.can_infer_discrete for c in mixed_inf_cls])

    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    for elbo in cont_inf_only_cls:
        svi = SVI(model, guide, optim.Adam(1), elbo)
        s_name = type(elbo).__name__
        w_msg = f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
        with pytest.warns(UserWarning, match=w_msg):
            svi.run(random.PRNGKey(0), 10)

def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(),
        [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    x = 2.
    guide = substitute(guide, data={'x': x})
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    # with the guide fixed at x, the single-sample ELBO loss reduces to
    # log q(x) - log p(x)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)

def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for the second-order Taylor expansion
    # used to estimate the likelihood (taylor_proxy)
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    params, losses = svi.run(svi_key, args.num_svi_steps, data, obs,
                             args.subsample_size)
    ref_params = {'theta': params['theta_auto_loc']}

    # the taylor proxy estimates the log likelihood (ll) by
    # taylor_expansion(ll, theta_curr) +
    #   sum_{i in subsample} ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr)
    # around ref_params
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()

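# A hypothetical driver for `run_hmcecs`. The Namespace fields simply mirror
# the attributes accessed above; `model`, `data`, and `obs` are assumed to be
# defined as in the surrounding example, and the step counts are placeholders.
from argparse import Namespace

hmcecs_args = Namespace(num_svi_steps=5000, subsample_size=100, num_blocks=10,
                        num_warmup=500, num_samples=500)
losses, samples = run_hmcecs(random.PRNGKey(0), hmcecs_args, data, obs,
                             NUTS(model))
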
def test_autoguide_deterministic(auto_class):
    def model(y=None):
        n = y.size if y is not None else 1
        mu = numpyro.sample("mu", dist.Normal(0, 5))
        sigma = numpyro.param("sigma", 1, constraint=constraints.positive)
        with numpyro.plate("N", len(y)):
            y = numpyro.sample("y", dist.Normal(mu, sigma).expand((n,)), obs=y)
        numpyro.deterministic("z", (y - mu) / sigma)

    mu, sigma = 2, 3
    y = mu + sigma * random.normal(random.PRNGKey(0), shape=(300,))
    y_train = y[:200]
    y_test = y[200:]

    guide = auto_class(model)
    optimiser = numpyro.optim.Adam(step_size=0.01)
    svi = SVI(model, guide, optimiser, Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), num_steps=500, y=y_train)
    params = svi_result.params
    posterior_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                               sample_shape=(1000,))

    predictive = Predictive(model, posterior_samples, params=params)
    predictive_samples = predictive(random.PRNGKey(0), y_test)

    assert predictive_samples["y"].shape == (1000, 100)
    assert predictive_samples["z"].shape == (1000, 100)
    assert_allclose(
        (predictive_samples["y"] - posterior_samples["mu"][..., None])
        / params["sigma"],
        predictive_samples["z"],
        atol=0.05,
    )

def test_neals_funnel_smoke():
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, Adam(1e-10), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), dim)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, dim)
        return svi_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    neutra = NeuTraReparam(guide, params)
    model = neutra.reparam(neals_funnel)
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    samples = mcmc.get_samples()
    transformed_samples = neutra.transform_sample(samples['auto_shared_latent'])
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples

def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['beta'], 0), true_coefs, atol=0.05)

@pytest.mark.parametrize("optim_class, args, kwargs", optimizers) def test_optim_multi_params(optim_class, args, kwargs): params = { "x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([-1, -1.0, -1.0]) } opt = optax_to_numpyro(optim_class(*args, **kwargs)) opt_state = opt.init(params) for i in range(2000): opt_state = step(opt_state, opt) for _, param in opt.get_params(opt_state).items(): assert jnp.allclose(param, jnp.zeros(3)) @pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)]) def test_beta_bernoulli(elbo): data = jnp.array([1.0] * 8 + [0.0] * 2) def model(data): f = numpyro.sample("beta", dist.Beta(1.0, 1.0)) numpyro.sample("obs", dist.Bernoulli(f), obs=data) def guide(data): alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive) beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive) numpyro.sample("beta", dist.Beta(alpha_q, beta_q)) adam = optax.adam(0.05)
def main(args):
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, Trace_ELBO(),
              hidden_dim=args.hidden_dim, z_dim=args.z_dim)
    rng_key = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size,
                                           split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size,
                                         split='test')
    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))

    @jit
    def eval_test(svi_state, rng_key):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)),
                   img, cmap='gray')
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](params['encoder$params'],
                                      test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)),
                   img_loc, cmap='gray')

    for i in range(args.num_epochs):
        rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(
            rng_key, 4)
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train)
        rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test)
        reconstruct_img(i, rng_key_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss,
                                                       time.time() - t_start))

def test_cond():
    def model():
        def true_fun(_):
            x = numpyro.sample("x", dist.Normal(4.0))
            numpyro.deterministic("z", x - 4.0)

        def false_fun(_):
            x = numpyro.sample("x", dist.Normal(0.0))
            numpyro.deterministic("z", x)

        cluster = numpyro.sample("cluster", dist.Normal())
        cond(cluster > 0, true_fun, false_fun, None)

    def guide():
        m1 = numpyro.param("m1", 2.0)
        s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
        m2 = numpyro.param("m2", 2.0)
        s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)

        def true_fun(_):
            numpyro.sample("x", dist.Normal(m1, s1))

        def false_fun(_):
            numpyro.sample("x", dist.Normal(m2, s2))

        cluster = numpyro.sample("cluster", dist.Normal())
        cond(cluster > 0, true_fun, false_fun, None)

    svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
    svi_result = svi.run(random.PRNGKey(0), num_steps=2500)
    params = svi_result.params

    predictive = Predictive(
        model,
        guide=guide,
        params=params,
        num_samples=1000,
        return_sites=["cluster", "x", "z"],
    )
    result = predictive(random.PRNGKey(0))

    assert result["cluster"].shape == (1000,)
    assert result["x"].shape == (1000,)
    assert result["z"].shape == (1000,)

    mcmc = MCMC(
        NUTS(model),
        num_warmup=500,
        num_samples=2500,
        num_chains=4,
        chain_method="sequential",
    )
    mcmc.run(random.PRNGKey(0))

    x = mcmc.get_samples()["x"]
    assert x.shape == (10_000,)
    assert_allclose(
        [x[x > 2.0].mean(), x[x > 2.0].std(), x[x < 2.0].mean(), x[x < 2.0].std()],
        [4.01, 0.965, -0.01, 0.965],
        atol=0.1,
    )
    # the marginal of x is an equal mixture of N(4, 1) and N(0, 1), so
    # E[x] = 2 and Var[x] = 1 + (4 - 2)**2 / 2 + (0 - 2)**2 / 2 = 5
    assert_allclose([x.mean(), x.std()], [2.0, jnp.sqrt(5.0)], atol=0.5)

def benchmark_hmc(args, features, labels):
    rng_key = random.PRNGKey(1)
    start = time.time()
    # a MAP estimate taken from the following source:
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs": jnp.array([
            +2.03420663e00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01,
            -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01,
            +9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01,
            -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01,
            -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01,
            -6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02,
            +1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02,
            -1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03,
            +6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01,
            +2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04,
            +3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02,
            +1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02,
            -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01,
            -1.59496680e-01, -1.88516974e-01, -1.20889175e00,
        ])
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100, we update 10 indices at each MCMC step,
        # so it takes 50,000 MCMC steps to iterate over the whole dataset
        kernel = HMCECS(inner_kernel, num_blocks=100,
                        proxy=HMCECS.taylor_proxy(ref_params))
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples,
        # and running on GPU is much faster than on CPU
        kernel = SA(model, adapt_state_size=1000,
                    init_strategy=init_to_value(values=ref_params))
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt the mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(inner_kernel, num_blocks=100,
                        proxy=HMCECS.taylor_proxy(neutra_ref_params))
    else:
        raise ValueError(
            "Invalid algorithm; must be one of 'HMC', 'NUTS', 'HMCECS', 'SA', "
            "or 'FlowHMCECS'.")
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels, subsample_size,
             extra_fields=("accept_prob",))
    print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)