def test_dynamic_supports():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.TransformedDistribution(
                dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, 1)) * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)

    guide = AutoDiagonalNormal(actual_model)
    svi = SVI(actual_model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data)
    actual_opt_params = adam.get_params(svi_state.optim_state)
    actual_params = svi.get_params(svi_state)
    actual_values = guide.median(actual_params)
    actual_loss = svi.evaluate(svi_state, data)

    guide = AutoDiagonalNormal(expected_model)
    svi = SVI(expected_model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data)
    expected_opt_params = adam.get_params(svi_state.optim_state)
    expected_params = svi.get_params(svi_state)
    expected_values = guide.median(expected_params)
    expected_loss = svi.evaluate(svi_state, data)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)

    # test latent values
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc_base'], expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
def test_subsample_guide(auto_class):
    # The model adapted from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        drift = numpyro.sample("drift", dist.LogNormal(-1, 0.5))
        with handlers.substitute(data={"data": subsample}):
            plate = numpyro.plate("data", full_size, subsample_size=len(subsample))
        assert plate.size == 50

        def transition_fn(z_prev, y_curr):
            with plate:
                z_curr = numpyro.sample("state", dist.Normal(z_prev, drift))
                y_curr = numpyro.sample("obs", dist.Bernoulli(logits=z_curr), obs=y_curr)
            return z_curr, y_curr

        _, result = scan(
            transition_fn, jnp.zeros(len(subsample)), batch, length=num_time_steps
        )
        return result

    def create_plates(batch, subsample, full_size):
        with handlers.substitute(data={"data": subsample}):
            return numpyro.plate("data", full_size, subsample_size=subsample.shape[0])

    guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    with handlers.seed(rng_seed=0):
        data = model(None, jnp.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    svi = SVI(model, guide, optim.Adam(0.02), Trace_ELBO())
    svi_state = svi.init(
        random.PRNGKey(0), data[:, :batch_size], jnp.arange(batch_size), full_size=full_size
    )
    update_fn = jit(svi.update, static_argnums=(3,))
    for epoch in range(2):
        beg = 0
        while beg < full_size:
            end = min(full_size, beg + batch_size)
            subsample = jnp.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi_state, loss = update_fn(svi_state, batch, subsample, full_size)
def test_tracegraph_normal_normal():
    # normal-normal; known covariance
    lam0 = jnp.array([0.1, 0.1])  # precision of prior
    loc0 = jnp.array([0.0, 0.5])  # prior mean
    # known precision of observation noise
    lam = jnp.array([6.0, 4.0])
    data = []
    data.append(jnp.array([-0.1, 0.3]))
    data.append(jnp.array([0.0, 0.4]))
    data.append(jnp.array([0.2, 0.5]))
    data.append(jnp.array([0.1, 0.7]))
    n_data = len(data)
    sum_data = data[0] + data[1] + data[2] + data[3]
    analytic_lam_n = lam0 + n_data * lam
    analytic_log_sig_n = -0.5 * jnp.log(analytic_lam_n)
    analytic_loc_n = sum_data * (lam / analytic_lam_n) + loc0 * (lam0 / analytic_lam_n)

    class FakeNormal(dist.Normal):
        reparametrized_params = []

    def model():
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample(
                "loc_latent", FakeNormal(loc0, jnp.power(lam0, -0.5))
            )
            for i, x in enumerate(data):
                numpyro.sample(
                    "obs_{}".format(i),
                    dist.Normal(loc_latent, jnp.power(lam, -0.5)),
                    obs=x,
                )
        return loc_latent

    def guide():
        loc_q = numpyro.param("loc_q", analytic_loc_n + jnp.array([0.334, 0.334]))
        log_sig_q = numpyro.param(
            "log_sig_q", analytic_log_sig_n + jnp.array([-0.29, -0.29])
        )
        sig_q = jnp.exp(log_sig_q)
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample("loc_latent", FakeNormal(loc_q, sig_q))
        return loc_latent

    adam = optim.Adam(step_size=0.0015, b1=0.97, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 5000)

    loc_error = jnp.sum(jnp.power(analytic_loc_n - svi_result.params["loc_q"], 2.0))
    log_sig_error = jnp.sum(
        jnp.power(analytic_log_sig_n - svi_result.params["log_sig_q"], 2.0)
    )
    assert_allclose(loc_error, 0, atol=0.05)
    assert_allclose(log_sig_error, 0, atol=0.05)
def test_neutra_reparam_unobserved_model():
    model = dirichlet_categorical
    data = jnp.ones(10, dtype=jnp.int32)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), data)
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    reparam_model = neutra.reparam(model)
    with handlers.seed(rng_seed=0):
        reparam_model(data=None)
def test_plate_inconsistent(size, dim):
    def model():
        with numpyro.plate("a", 10, dim=-1):
            numpyro.sample("x", dist.Normal(0, 1))
        with numpyro.plate("a", size, dim=dim):
            numpyro.sample("y", dist.Normal(0, 1))

    guide = AutoDelta(model)
    svi = SVI(model, guide, numpyro.optim.Adam(step_size=0.1), Trace_ELBO())
    with pytest.raises(AssertionError, match="has inconsistent dim or size"):
        svi.run(random.PRNGKey(0), 10)
def test_subsample_model_with_deterministic():
    def model():
        x = numpyro.sample("x", dist.Normal(0, 1))
        numpyro.deterministic("x2", x * 2)
        with numpyro.plate("N", 10, subsample_size=5):
            numpyro.sample("obs", dist.Normal(x, 1), obs=jnp.ones(5))

    guide = AutoNormal(model)
    svi = SVI(model, guide, optim.Adam(1.0), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 10)
    samples = guide.sample_posterior(random.PRNGKey(1), svi_result.params)
    assert "x2" in samples
def test_svi_discrete_latent():
    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO())
    with pytest.warns(UserWarning, match="SVI does not support models with discrete"):
        svi.run(random.PRNGKey(0), 10)
def run_svi(model, guide_family, args, X, Y):
    if guide_family == "AutoDelta":
        guide = autoguide.AutoDelta(model)
    elif guide_family == "AutoDiagonalNormal":
        guide = autoguide.AutoDiagonalNormal(model)
    else:
        # Fail loudly instead of hitting a NameError on `guide` below.
        raise ValueError("Unknown guide family: {}".format(guide_family))

    optimizer = numpyro.optim.Adam(0.001)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    svi_results = svi.run(PRNGKey(1), args.maxiter, X=X, Y=Y)
    params = svi_results.params
    return params, guide
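# Usage sketch for `run_svi` above. This is an illustration, not part of the
# original code: `my_model`, `X`, and `Y` are hypothetical, and the `args`
# namespace is assumed to carry only the `maxiter` attribute that `run_svi` reads.
def example_run_svi(my_model, X, Y):
    import argparse

    my_args = argparse.Namespace(maxiter=1000)
    params, guide = run_svi(my_model, "AutoDiagonalNormal", my_args, X, Y)
    # Draw approximate posterior samples from the fitted guide.
    return guide.sample_posterior(PRNGKey(2), params, sample_shape=(500,))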
def test_stable_run(stable_run):
    def model():
        var = numpyro.sample("var", dist.Exponential(1))
        numpyro.sample("obs", dist.Normal(0, jnp.sqrt(var)), obs=0.0)

    def guide():
        loc = numpyro.param("loc", 0.0)
        numpyro.sample("var", dist.Normal(loc, 10))

    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_run)
    assert jnp.isfinite(svi_result.params["loc"]) == stable_run
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample("offset", dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1,))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1,
            [dim + 1, dim + 1],
            permutation=jnp.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity,
        )
        arn = partial(arn_apply, params["auto_arn__{}$params".format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample)
    )
    expected_output = transform(x)
    assert_allclose(actual_sample["coefs"], expected_sample["coefs"])
    assert_allclose(
        actual_sample["offset"],
        transforms.biject_to(constraints.interval(-1, 1))(expected_sample["offset"]),
    )
    check_eq(actual_output, expected_output)
def svi(model, guide, num_steps, lr, rng_key, X, Y):
    """
    Helper function for doing SVI inference.
    """
    svi = SVI(model, guide, optim.Adam(lr), ELBO(num_particles=1), X=X, Y=Y)
    svi_state = svi.init(rng_key)
    print('Optimizing...')
    state, losses = lax.scan(lambda x, i: svi.update(x), svi_state, np.zeros(num_steps))
    return losses, svi.get_params(state)
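# Usage sketch for the `svi` helper above. Illustrative only: `my_model` and
# `my_guide` are hypothetical, and `random` is assumed to be `jax.random` as
# elsewhere in this file. Note the helper keeps the legacy API in which data
# (`X`, `Y`) is bound to the SVI object at construction time.
def example_svi_usage(my_model, my_guide, X, Y):
    losses, params = svi(my_model, my_guide, num_steps=1000, lr=0.01,
                         rng_key=random.PRNGKey(0), X=X, Y=Y)
    # `losses` traces the ELBO objective over steps; `params` are the fitted
    # guide parameters.
    return losses, params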
def test_logistic_regression(auto_class, Elbo):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs", dist.Normal(0, 1).expand([dim]).to_event())
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        with numpyro.plate("N", len(data)):
            return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # smoke test if analytic KL is used
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoDAIS, AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median["coefs"], true_coefs, rtol=0.1)
        # test .quantile method
        if auto_class is not AutoDelta:
            median = guide.quantiles(params, [0.2, 0.5])
            assert_allclose(median["coefs"][1], true_coefs, rtol=0.1)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(posterior_samples["coefs"], 0), expected_coefs, rtol=0.1)
def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T
    N = len(data)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)).to_event())
        with numpyro.plate("N", N):
            numpyro.sample("obs", dist.Bernoulli(f).to_event(1), obs=data)

    adam = optim.Adam(0.01)
    if auto_class == AutoDAIS:
        guide = auto_class(model, init_loc_fn=init_strategy, base_dist="cholesky")
    else:
        guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    posterior_mean = jnp.mean(posterior_samples["beta"], 0)
    assert_allclose(posterior_mean, true_coefs, atol=0.05)

    if auto_class not in [AutoDAIS, AutoDelta, AutoIAFNormal, AutoBNAFNormal]:
        quantiles = guide.quantiles(params, [0.2, 0.5, 0.8])
        assert quantiles["beta"].shape == (3, 2)

    # Predictive can be instantiated from posterior samples...
    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)

    # ... or from the guide + params
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
    predictive_samples = predictive(random.PRNGKey(1), None)
    assert predictive_samples["obs"].shape == (1000, N, 2)
def test_autodais_subsampling_error():
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1, 1))
        with numpyro.plate("plate", 20, 10, dim=-1):
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = AutoDAIS(model)
    svi = SVI(model, guide, adam, Trace_ELBO())

    with pytest.raises(NotImplementedError, match=".*data subsampling.*"):
        svi.init(random.PRNGKey(1), data)
def test_laplace_approximation_custom_hessian():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10))
        mu = a + b * x
        numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    x = random.normal(random.PRNGKey(0), (100,))
    y = 1 + 2 * x
    guide = AutoLaplaceApproximation(
        model, hessian_fn=lambda f, x: jacobian(jacobian(f))(x)
    )
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    svi_result = svi.run(random.PRNGKey(0), 10000, progress_bar=False)
    guide.get_transform(svi_result.params)
def init_svi(self, X: DeviceArray, *, lr: float, **kwargs):
    """Initialize the SVI state

    Args:
        X: input data
        lr: learning rate
        kwargs: other keyword arguments for optimizer
    """
    self.optim = self.optim_builder(lr, **kwargs)
    self.svi = SVI(self.model, self.guide, self.optim, self.loss)
    svi_state = self.svi.init(self.rng_key, X)
    if self.svi_state is None:
        self.svi_state = svi_state
    return self
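# Usage sketch for `init_svi` (illustrative; the enclosing class is not shown
# here, so this assumes a wrapper class that defines `model`, `guide`,
# `optim_builder`, `loss`, `rng_key`, and an initially-None `svi_state`
# attribute; `MyVariationalEstimator` and `X_train` are hypothetical names):
#
#     estimator = MyVariationalEstimator()
#     estimator.init_svi(X_train, lr=0.005)  # builds the optimizer, SVI object, and state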
def test_improper():
    y = random.normal(random.PRNGKey(0), (100,))

    def model(y):
        lambda1 = numpyro.sample('lambda1', dist.ImproperUniform(dist.constraints.real, (), ()))
        lambda2 = numpyro.sample('lambda2', dist.ImproperUniform(dist.constraints.real, (), ()))
        sigma = numpyro.sample('sigma', dist.ImproperUniform(dist.constraints.positive, (), ()))
        mu = numpyro.deterministic('mu', lambda1 + lambda2)
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), Trace_ELBO(), y=y)
    svi_state = svi.init(random.PRNGKey(2))
    lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(10000))
def test_module():
    x = random.normal(random.PRNGKey(0), (100, 10))
    y = random.normal(random.PRNGKey(1), (100,))

    def model(x, y):
        nn = numpyro.module("nn", Dense(1), (10,))
        mu = nn(x).squeeze(-1)
        sigma = numpyro.sample("sigma", dist.HalfNormal(1))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), Trace_ELBO(), x=x, y=y)
    svi_state = svi.init(random.PRNGKey(2))
    lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(1000))
def test_reparam_log_joint(model, kwargs):
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-10), Trace_ELBO(), **kwargs)
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    reparam_model = neutra.reparam(model)
    _, pe_fn, _, _ = initialize_model(random.PRNGKey(1), model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(
        random.PRNGKey(2), reparam_model, model_kwargs=kwargs
    )
    latent_x = list(init_params[0].values())[0]
    pe_transformed = pe_fn_neutra(init_params[0])
    latent_y = neutra.transform(latent_x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn(guide._unpack_latent(latent_y))
    assert_allclose(pe_transformed, pe - log_det_jacobian)
def test_collapse_beta_binomial():
    total_count = 10
    data = 3.

    def model1():
        c1 = numpyro.param("c1", 0.5, constraint=dist.constraints.positive)
        c0 = numpyro.param("c0", 1.5, constraint=dist.constraints.positive)
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c1, c0))
            numpyro.sample("obs", dist.Binomial(total_count, probs), obs=data)

    def model2():
        c1 = numpyro.param("c1", 0.5, constraint=dist.constraints.positive)
        c0 = numpyro.param("c0", 1.5, constraint=dist.constraints.positive)
        numpyro.sample("obs", dist.BetaBinomial(c1, c0, total_count), obs=data)

    trace1 = handlers.trace(model1).get_trace()
    trace2 = handlers.trace(model2).get_trace()
    assert "probs" in trace1
    assert "obs" not in trace1
    assert "probs" not in trace2
    assert "obs" in trace2

    svi1 = SVI(model1, lambda: None, numpyro.optim.Adam(1), Trace_ELBO())
    svi2 = SVI(model2, lambda: None, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state1 = svi1.init(random.PRNGKey(0))
    svi_state2 = svi2.init(random.PRNGKey(0))
    params1 = svi1.get_params(svi_state1)
    params2 = svi2.get_params(svi_state2)
    assert_allclose(params1["c1"], params2["c1"])
    assert_allclose(params1["c0"], params2["c0"])

    params1 = svi1.get_params(svi1.update(svi_state1)[0])
    params2 = svi2.get_params(svi2.update(svi_state2)[0])
    assert_allclose(params1["c1"], params2["c1"])
    assert_allclose(params1["c0"], params2["c0"])
def test_laplace_approximation_warning():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    init_state = svi.init(random.PRNGKey(0))
    svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state)
    params = svi.get_params(svi_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
def test_logistic_regression(auto_class, Elbo):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # smoke test if analytic KL is used
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantile method
        median = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(median['coefs'][1], true_coefs, rtol=0.1)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)
def run_svi_inference(model, guide, rng_key, X, Y, optimizer, n_epochs=1_000):
    # initialize svi
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    # initialize state
    init_state = svi.init(rng_key, X, Y.squeeze())

    # run the optimizer for `n_epochs` iterations
    # (scan expects an array of inputs, not a bare integer)
    state, losses = jax.lax.scan(
        lambda state, i: svi.update(state, X, Y.squeeze()),
        init_state,
        jnp.arange(n_epochs),
    )

    # extract the surrogate posterior parameters
    params = svi.get_params(state)
    return params
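# Usage sketch for `run_svi_inference` above. Illustrative only: `my_model` and
# `my_guide` are hypothetical, and any NumPyro optimizer can be passed for
# `optimizer` (Adam is just one choice).
def example_run_svi_inference(my_model, my_guide, X, Y):
    params = run_svi_inference(
        my_model, my_guide, random.PRNGKey(0), X, Y,
        optimizer=optim.Adam(0.01), n_epochs=2_000,
    )
    return params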
def test_collapse_beta_bernoulli():
    data = 0.

    def model():
        c = numpyro.sample("c", dist.Gamma(1, 1))
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c, 2))
            numpyro.sample("obs", dist.Bernoulli(probs), obs=data)

    def guide():
        a = numpyro.param("a", 1., constraint=constraints.positive)
        b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(a, b))

    svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)
def main(_args):
    data = generate_data()
    init_rng_key = PRNGKey(1273)
    # nuts = NUTS(gmm)
    # mcmc = MCMC(nuts, 100, 1000)
    # mcmc.print_summary()
    seeded_gmm = seed(gmm, init_rng_key)
    model_trace = trace(seeded_gmm).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(model_trace)
    enum_gmm = enum(config_enumerate(gmm), -max_plate_nesting - 1)
    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.))
    svi_state = svi.init(init_rng_key, data)
    upd_fun = jax.jit(svi.update)

    with tqdm.trange(100_000) as pbar:
        for i in pbar:
            svi_state, loss = upd_fun(svi_state, data)
            pbar.set_description(f"SVI {loss}", True)
def test_pickle_autoguide(guide_class):
    x = np.random.poisson(1.0, size=(100,))

    guide = guide_class(poisson_regression)
    optim = numpyro.optim.Adam(1e-2)
    svi = SVI(poisson_regression, guide, optim, numpyro.infer.Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(1), 3, x, len(x))
    pickled_guide = pickle.loads(pickle.dumps(guide))
    predictive = Predictive(
        poisson_regression,
        guide=pickled_guide,
        params=svi_result.params,
        num_samples=1,
        return_sites=["param", "x"],
    )
    samples = predictive(random.PRNGKey(1), None, 1)
    assert set(samples.keys()) == {"param", "x"}
def test_autoguide(deterministic):
    GLOBAL["count"] = 0
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), deterministic=deterministic)
    svi_state = svi.init(random.PRNGKey(0))
    svi_state = lax.fori_loop(0, 100, lambda i, val: svi.update(val)[0], svi_state)
    params = svi.get_params(svi_state)
    guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(100,))

    if deterministic:
        assert GLOBAL["count"] == 5
    else:
        assert GLOBAL["count"] == 4
def test_collapse_beta_binomial_plate():
    data = np.array([0., 1., 5., 5.])

    def model():
        c = numpyro.sample("c", dist.Gamma(1, 1))
        with handlers.collapse():
            probs = numpyro.sample("probs", dist.Beta(c, 2))
            with numpyro.plate("plate", len(data)):
                numpyro.sample("obs", dist.Binomial(10, probs), obs=data)

    def guide():
        a = numpyro.param("a", 1., constraint=constraints.positive)
        b = numpyro.param("b", 1., constraint=constraints.positive)
        numpyro.sample("c", dist.Gamma(a, b))

    svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    svi.update(svi_state)
def test_run(progress_bar):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", lambda key: random.normal(key),
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", lambda key: random.exponential(key),
                               constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
    params, losses = svi.run(random.PRNGKey(1), 1000, data, progress_bar=progress_bar)
    assert losses.shape == (1000,)
    assert_allclose(params['alpha_q'] / (params['alpha_q'] + params['beta_q']),
                    0.8, atol=0.05, rtol=0.05)
def fit_advi(model, num_iter, learning_rate=0.01, seed=0):
    """Automatic Differentiation Variational Inference using a Normal variational
    distribution with a diagonal covariance matrix.
    """
    rng_key = random.PRNGKey(seed)
    adam = Adam(learning_rate)
    # Automatically create a variational distribution (aka "guide" in Pyro's terminology)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(rng_key)

    # Run optimization
    last_state, losses = lax.scan(
        lambda state, i: svi.update(state), svi_state, np.zeros(num_iter)
    )
    results = ADVIResults(svi=svi, guide=guide, state=last_state, losses=losses)
    return results
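# Usage sketch for `fit_advi` above. Illustrative only: `my_model` is a
# hypothetical continuous-latent NumPyro model, and this targets the same
# legacy API used by `fit_advi` (where `AutoContinuousELBO` exists).
def example_fit_advi(my_model):
    results = fit_advi(my_model, num_iter=5000, learning_rate=0.005, seed=42)
    params = results.svi.get_params(results.state)
    # Draw approximate posterior samples from the fitted diagonal-normal guide.
    return results.guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )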