def test_logistic_regression(auto_class, Elbo):
    """Check that an autoguide recovers the coefficients of a logistic regression.

    Simulates ``N`` Bernoulli observations whose logits are a linear function
    of standard-normal features with known coefficients ``[1, 2, ..., dim]``.
    It fits ``auto_class(model, init_loc_fn=init_strategy)`` with SVI under
    ``Elbo()`` for 2000 steps, then compares the guide's median, 0.5 quantile
    and posterior-sample mean against reference values.

    Args:
        auto_class: autoguide class under test (presumably supplied by pytest
            parametrization -- TODO confirm).
        Elbo: ELBO loss class; instantiated with no arguments.

    NOTE(review): this module defines ``test_logistic_regression`` more than
    once. Only the last definition stays bound to the name, so pytest collects
    only that one. Confirm which variant is meant to run.
    """
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        # `.to_event()` makes the coefficient vector a single event of size `dim`.
        coefs = numpyro.sample(
            "coefs", dist.Normal(0, 1).expand([dim]).to_event()
        )
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        with numpyro.plate("N", len(data)):
            return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # Smoke test that the analytic KL is used. Starting from the same state,
    # the mean-field loss must differ from the Trace_ELBO loss by more than 0.5.
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        # fori_loop carries only the SVI state; the per-step loss is discarded.
        svi_state, _ = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    # Median/quantile checks are skipped for these guides (presumably because
    # the summaries are unsupported for them -- TODO confirm).
    if auto_class not in (AutoDAIS, AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median["coefs"], true_coefs, rtol=0.1)
        # test .quantile method (not exercised for AutoDelta)
        if auto_class is not AutoDelta:
            quantiles = guide.quantiles(params, [0.2, 0.5])
            # Index 1 selects the 0.5 quantile requested above.
            assert_allclose(quantiles["coefs"][1], true_coefs, rtol=0.1)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    # Reference posterior means for this fixed dataset (close to, not equal to,
    # true_coefs).
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(
        jnp.mean(posterior_samples["coefs"], 0), expected_coefs, rtol=0.1
    )
def test_logistic_regression(auto_class, Elbo):
    """Check that an autoguide recovers the coefficients of a logistic regression.

    Simulates ``N`` Bernoulli observations whose logits are a linear function
    of standard-normal features with known coefficients ``[1, 2, ..., dim]``.
    It fits ``auto_class(model, init_loc_fn=init_strategy)`` with SVI under
    ``Elbo()`` for 2000 steps, then checks that the median, 0.5 quantile and
    posterior-sample mean are within 10% of the true coefficients.

    Args:
        auto_class: autoguide class under test (presumably supplied by pytest
            parametrization -- TODO confirm).
        Elbo: ELBO loss class; instantiated with no arguments.

    NOTE(review): this module defines ``test_logistic_regression`` more than
    once. A later ``def`` rebinds the name, so pytest collects only the last
    definition. Confirm which variant is meant to run.
    """
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        # `coefs` is declared as one event of shape (dim,) rather than an
        # un-plated batch of `dim` scalars. The N observations are declared
        # conditionally independent via a plate.
        coefs = numpyro.sample(
            "coefs", dist.Normal(0, 1).expand([dim]).to_event(1)
        )
        logits = jnp.sum(coefs * data, axis=-1)
        with numpyro.plate("N", len(data)):
            return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, adam, Elbo())
    svi_state = svi.init(rng_key_init, data, labels)

    # Smoke test that the analytic KL is used. Starting from the same state,
    # the mean-field loss must differ from the Trace_ELBO loss by more than 0.5.
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def body_fn(i, val):
        # fori_loop carries only the SVI state; the per-step loss is discarded.
        svi_state, _ = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median["coefs"], true_coefs, rtol=0.1)
        # test .quantile method; index 1 selects the requested 0.5 quantile
        quantiles = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(quantiles["coefs"][1], true_coefs, rtol=0.1)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    assert_allclose(jnp.mean(posterior_samples["coefs"], 0), true_coefs, rtol=0.1)
def test_svi_discrete_latent():
    """Check ELBO discrete-latent support flags and the warning for unsupported losses.

    ``RenyiELBO``, ``Trace_ELBO`` and ``TraceMeanField_ELBO`` must report
    ``can_infer_discrete`` as falsy. Each must also emit a ``UserWarning``
    naming the loss when ``svi.run`` is called on a model with a Bernoulli
    latent site. ``TraceGraph_ELBO`` must report ``can_infer_discrete`` as
    truthy.
    """
    continuous_only_elbos = [RenyiELBO(), Trace_ELBO(), TraceMeanField_ELBO()]
    discrete_capable_elbos = [TraceGraph_ELBO()]
    assert not any(elbo.can_infer_discrete for elbo in continuous_only_elbos)
    assert all(elbo.can_infer_discrete for elbo in discrete_capable_elbos)

    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    for elbo in continuous_only_elbos:
        svi = SVI(model, guide, optim.Adam(1), elbo)
        elbo_name = type(elbo).__name__
        # pytest treats `match` as a regex searched in the warning text. This
        # message contains no regex metacharacters, so it matches literally.
        w_msg = (
            f"Currently, SVI with {elbo_name} loss does not support models "
            "with discrete latent variables"
        )
        with pytest.warns(UserWarning, match=w_msg):
            svi.run(random.PRNGKey(0), 10)
def run_inference(docs, args):
    """Fit the module-level ``model``/``guide`` to ``docs`` with SVI.

    Uses an Adam optimizer at ``args.learning_rate`` and a
    ``TraceMeanField_ELBO`` loss, and runs ``args.num_steps`` steps from
    ``PRNGKey(0)``.

    Args:
        docs: document data. Its column count (``docs.shape[1]``) is passed to
            the model as ``vocab_size``; presumably a (num_docs, vocab_size)
            array -- TODO confirm.
        args: namespace providing ``num_topics``, ``hidden``, ``dropout_rate``,
            ``batch_size``, ``learning_rate``, ``num_steps``,
            ``disable_progbar`` and ``nn_framework``.

    Returns:
        Whatever ``SVI.run`` returns for this configuration.
    """
    docs = device_put(docs)

    # Model/guide hyperparameters, forwarded positionally to `svi.run`.
    hyperparams = {
        "vocab_size": docs.shape[1],
        "num_topics": args.num_topics,
        "hidden": args.hidden,
        "dropout_rate": args.dropout_rate,
        "batch_size": args.batch_size,
    }

    svi = SVI(
        model,
        guide,
        numpyro.optim.Adam(args.learning_rate),
        loss=TraceMeanField_ELBO(),
    )
    return svi.run(
        random.PRNGKey(0),
        args.num_steps,
        docs,
        hyperparams,
        is_training=True,
        progress_bar=not args.disable_progbar,
        nn_framework=args.nn_framework,
    )