def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2,
                      [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['beta'], 0), true_coefs, atol=0.05)
def __init__(self, model: Model, guide: Guide, rng_key: int = 0, *,
             loss: ELBO = ELBO(num_particles=1),
             optim_builder: optim.optimizers.optimizer = optim.Adam):
    """Handles the model and guide for training and prediction.

    Args:
        model: function holding the numpyro model
        guide: function holding the numpyro guide
        rng_key: random key as int
        loss: loss to optimize
        optim_builder: builder for an optimizer
    """
    self.model = model
    self.guide = guide
    self.rng_key = random.PRNGKey(rng_key)  # current random key
    self.loss = loss
    self.optim_builder = optim_builder
    self.svi = None
    self.svi_state = None
    self.optim = None
    self.log_func = print  # overwrite with e.g. logger.info(...)
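# Construction sketch for the handler above. The enclosing class name is not
# visible in this excerpt, so `SVIHandler` (and `model_fn`/`guide_fn`) below
# are hypothetical stand-ins; only the __init__ signature documented above
# is assumed:
#
#     handler = SVIHandler(model_fn, guide_fn, rng_key=0,
#                          loss=ELBO(num_particles=2),
#                          optim_builder=optim.Adam)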
def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    median = guide.median(params)
    assert_allclose(median['loc'], true_coef, rtol=0.05)
    # test .quantiles method
    quantiles = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
def main(args):
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, ELBO(),
              hidden_dim=args.hidden_dim, z_dim=args.z_dim)
    rng_key = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')
    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))

    @jit
    def eval_test(svi_state, rng_key):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)),
                   img, cmap='gray')
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](params['encoder$params'],
                                      test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)),
                   img_loc, cmap='gray')

    for i in range(args.num_epochs):
        # a single 4-way split provides all per-epoch keys
        rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(rng_key, 4)
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test)
        reconstruct_img(i, rng_key_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))
def test_predictive_with_guide():
    data = jnp.array([1] * 8 + [0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        with numpyro.plate("plate", 10):
            numpyro.deterministic("beta_sq", f ** 2)
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.1), ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, _ = svi.update(val, data)
        return svi_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    predictive = Predictive(model, guide=guide, params=params,
                            num_samples=1000)(random.PRNGKey(2), data=None)
    assert predictive["beta_sq"].shape == (1000,)
    obs_pred = predictive["obs"]
    assert_allclose(jnp.mean(obs_pred), 0.8, atol=0.05)
def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(), [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    # set the base value of x_guide to 0.9
    x_base = 0.9
    guide = substitute(guide, base_param_map={'x': x_base})
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    x, _ = x_guide.transform_with_intermediates(x_base)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)
def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, ELBO(num_particles=100))
    svi_state = svi.init(PRNGKey(0), data)

    # Training loop
    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, args.num_steps, body_fn, svi_state)

    # Report the final values of the variational parameters
    # in the guide after training.
    params = svi.get_params(svi_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1
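# Standalone cross-check of the conjugacy claim above: for a Normal(mu0, s0)
# prior on the mean of Normal(loc, s) data, the exact posterior mean is a
# precision-weighted average. The prior/likelihood scales below are
# illustrative assumptions, not values taken from the example's model.
def _conjugate_posterior_mean_sketch():
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0
    mu0, s0, s = 0.0, 10.0, 1.0  # hypothetical prior mean/scale and obs scale
    post_mean = (mu0 / s0 ** 2 + data.sum() / s ** 2) / (1 / s0 ** 2 + len(data) / s ** 2)
    return post_mean  # close to 3.0, consistent with the assertion above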
def test_logistic_regression(auto_class):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data, labels)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantiles method
        quantiles = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)
def test_dynamic_supports():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, 1)) * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)

    guide = AutoDiagonalNormal(actual_model)
    svi = SVI(actual_model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data)
    actual_opt_params = adam.get_params(svi_state.optim_state)
    actual_params = svi.get_params(svi_state)
    actual_values = guide.median(actual_params)
    actual_loss = svi.evaluate(svi_state, data)

    guide = AutoDiagonalNormal(expected_model)
    svi = SVI(expected_model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data)
    expected_opt_params = adam.get_params(svi_state.optim_state)
    expected_params = svi.get_params(svi_state)
    expected_values = guide.median(expected_params)
    expected_loss = svi.evaluate(svi_state, data)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)
    # test latent values
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc_base'], expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
def svi(model, guide, num_steps, lr, rng_key, X, Y):
    """Helper function for doing SVI inference."""
    svi = SVI(model, guide, optim.Adam(lr), ELBO(num_particles=1), X=X, Y=Y)
    svi_state = svi.init(rng_key)
    print('Optimizing...')
    state, loss = lax.scan(lambda x, i: svi.update(x),
                           svi_state, np.zeros(num_steps))
    return loss, svi.get_params(state)
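# Usage sketch for the helper above, with a throwaway model/guide pair. The
# regression model and data below are illustrative assumptions, not part of
# the helper itself:
def _svi_helper_usage_sketch():
    def model(X, Y):
        w = numpyro.sample('w', dist.Normal(0., 1.))
        numpyro.sample('Y', dist.Normal(w * X, 0.1), obs=Y)

    guide = AutoDiagonalNormal(model)
    X = jnp.linspace(-1., 1., 50)
    Y = 0.5 * X
    # returns the per-step losses and the optimized variational parameters
    losses, params = svi(model, guide, num_steps=100, lr=0.01,
                         rng_key=random.PRNGKey(0), X=X, Y=Y)
    return losses, params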
def test_iaf():
    # test the substitute logic for the exposed methods `sample_posterior`
    # and `get_transform`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1,))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(dim + 1, [dim + 1, dim + 1],
                                               permutation=jnp.arange(dim + 1),
                                               skip_connections=guide._skip_connections,
                                               nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(actual_sample['offset'],
                    transforms.biject_to(constraints.interval(-1, 1))(expected_sample['offset']))
    check_eq(actual_output, expected_output)
def test_reparam_log_joint(model, kwargs):
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-10), ELBO(), **kwargs)
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    reparam_model = neutra.reparam(model)
    _, pe_fn, _, _ = initialize_model(random.PRNGKey(1), model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(random.PRNGKey(2), reparam_model,
                                                       model_kwargs=kwargs)
    latent_x = list(init_params[0].values())[0]
    pe_transformed = pe_fn_neutra(init_params[0])
    latent_y = neutra.transform(latent_x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn(guide._unpack_latent(latent_y))
    assert_allclose(pe_transformed, pe - log_det_jacobian)
def test_laplace_approximation_warning():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), ELBO(), x=x, y=y)
    init_state = svi.init(random.PRNGKey(0))
    svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state)
    params = svi.get_params(svi_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
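# Why the warning is expected: the model above has 4 unknowns (a and the
# three entries of b) but only 3 data points, and the observation noise of
# 0.001 scales the likelihood term of the log-posterior Hessian by 1e6,
# swamping the prior. The likelihood Hessian is rank-deficient, so the full
# Hessian is numerically singular. A standalone sketch of the rank argument:
def _laplace_rank_sketch():
    x = random.normal(random.PRNGKey(0), (3,))
    # design matrix of the cubic regression: columns 1, x, x**2, x**3
    design = jnp.stack([jnp.ones_like(x), x, x ** 2, x ** 3], axis=-1)
    # rank is at most 3 < 4 unknowns, so design.T @ design is singular
    return jnp.linalg.matrix_rank(design.T @ design)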
def test_improper():
    y = random.normal(random.PRNGKey(0), (100,))

    def model(y):
        lambda1 = numpyro.sample('lambda1', dist.ImproperUniform(dist.constraints.real, (), ()))
        lambda2 = numpyro.sample('lambda2', dist.ImproperUniform(dist.constraints.real, (), ()))
        sigma = numpyro.sample('sigma', dist.ImproperUniform(dist.constraints.positive, (), ()))
        mu = numpyro.deterministic('mu', lambda1 + lambda2)
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), ELBO(), y=y)
    svi_state = svi.init(random.PRNGKey(2))
    lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(10000))
def test_autoguide(deterministic):
    GLOBAL["count"] = 0
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.1), ELBO(), deterministic=deterministic)
    svi_state = svi.init(random.PRNGKey(0))
    svi_state = lax.fori_loop(0, 100, lambda i, val: svi.update(val)[0], svi_state)
    params = svi.get_params(svi_state)
    guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(100,))

    if deterministic:
        assert GLOBAL["count"] == 5
    else:
        assert GLOBAL["count"] == 4
def test_param():
    # this test checks the validity of model/guide sites having
    # param constraints that contain composed transforms
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    adam = optim.Adam(0.01)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['c'], c_init)
    assert_allclose(params['d'], d_init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(c_init, d_init).log_prob(obs) \
        - dist.Normal(a_init, b_init).log_prob(obs)
    # not exact because of the transform / inverse-transform round trip
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
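# What "composed transforms" refers to above, as a standalone sketch: for a
# greater_than(lower) constraint, biject_to returns a transform that maps an
# unconstrained value through an exponential and then a shift, i.e. a
# composition rather than a single bijection. Exact transform classes may
# differ across numpyro versions; only biject_to's behavior is assumed.
def _composed_transform_sketch():
    t = transforms.biject_to(constraints.greater_than(1.))
    unconstrained = 0.3
    constrained = t(unconstrained)  # > 1 by construction
    roundtrip = t.inv(constrained)  # recovers 0.3 up to float error
    return constrained, roundtrip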
def test_jitted_update_fn():
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optim.Adam(0.05)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)
    expected = svi.get_params(svi.update(svi_state, data)[0])

    actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0])
    check_close(actual, expected, atol=1e-5)
def test_param():
    # this test checks the validity of a model having param sites whose
    # constraints contain composed transforms
    rng_keys = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    x_init = random.normal(rng_keys[2])

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b))

    # this class is used to force the init value of `x` to x_init
    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(super(_AutoGuide, self).__call__,
                              {'_auto_latent': x_init})(*args, **kwargs)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init)

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['auto_loc'], guide._init_latent)
    assert_allclose(params['auto_scale'], jnp.ones(1) * guide._init_scale)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(guide._init_latent, guide._init_scale).log_prob(x_init) \
        - dist.Normal(a_init, b_init).log_prob(x_init)
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(), [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    x = 2.
    guide = substitute(guide, data={'x': x})
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)
def test_neals_funnel_smoke():
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, Adam(1e-10), ELBO())
    svi_state = svi.init(random.PRNGKey(0), dim)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, dim)
        return svi_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    neutra = NeuTraReparam(guide, params)
    model = neutra.reparam(neals_funnel)
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    samples = mcmc.get_samples()
    transformed_samples = neutra.transform_sample(samples['auto_shared_latent'])
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model, guide, optimizer, ELBO(),
                dp_scale=args.sigma, clipping_threshold=args.clip_threshold,
                d=args.dimensions, num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.clip_threshold, args.delta, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if i % (args.num_epochs // 10) == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state, num_test_batches)
            print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])

    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(dual_moon_model,
                           hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), ELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, jnp.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2), params, sample_shape=(args.num_samples,))['x'].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, params)
    neutra_model = neutra.reparam(dual_moon_model)
    nuts_kernel = NUTS(neutra_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(3))
    mcmc.print_summary()
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples['x'].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(jnp.zeros(2), 1.).sample(random.PRNGKey(4), (1000,))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0], guide_base_samples[:, 1], ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples[:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
def main(args):
    # loading data
    (train_init, train_fetch_plain), num_samples = load_dataset(
        MNIST, batch_size=args.batch_size, split='train',
        batchifier=subsample_batchify_data)
    (test_init, test_fetch_plain), _ = load_dataset(
        MNIST, batch_size=args.batch_size, split='test',
        batchifier=split_batchify_data)

    def binarize_fetch(fetch_fn):
        @jit
        def fetch_binarized(batch_nr, batchifier_state, binarize_rng):
            batch = fetch_fn(batch_nr, batchifier_state)
            return binarize(binarize_rng, batch[0]), batch[1]

        return fetch_binarized

    train_fetch = binarize_fetch(train_fetch_plain)
    test_fetch = binarize_fetch(test_fetch_plain)

    # setting up optimizer
    optimizer = optimizers.Adam(args.learning_rate)

    # the minibatch environment in our model scales individual records'
    # contributions to the loss up by num_samples. This can cause numerical
    # instabilities, so we scale the loss back down by 1/num_samples here.
    sc_model = scale(model, scale=1 / num_samples)
    sc_guide = scale(guide, scale=1 / num_samples)

    if args.no_dp:
        svi = SVI(sc_model, sc_guide, optimizer, ELBO(),
                  num_obs_total=num_samples, z_dim=args.z_dim,
                  hidden_dim=args.hidden_dim)
    else:
        q = args.batch_size / num_samples
        target_eps = args.epsilon
        dp_scale, act_eps, _ = approximate_sigma(
            target_eps=target_eps, delta=1 / num_samples, q=q,
            num_iter=int(1 / q) * args.num_epochs, force_smaller=True)
        print(f"using noise scale {dp_scale} for epsilon of {act_eps} (targeted: {target_eps})")

        svi = DPSVI(sc_model, sc_guide, optimizer, ELBO(),
                    dp_scale=dp_scale, clipping_threshold=10.,
                    num_obs_total=num_samples, z_dim=args.z_dim,
                    hidden_dim=args.hidden_dim)

    # preparing random number generators and initializing svi
    rng = PRNGKey(0)
    rng, binarize_rng, svi_init_rng, batchifier_rng = random.split(rng, 4)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    sample_batch = train_fetch(0, batchifier_state, binarize_rng)[0]
    svi_state = svi.init(svi_init_rng, sample_batch)

    # functions for training tasks
    @jit
    def epoch_train(svi_state, batchifier_state, num_batches, rng):
        """Trains one epoch

        :param svi_state: current state of the optimizer
        :param rng: rng key
        :return: overall training loss over the epoch
        """
        def body_fn(i, val):
            svi_state, loss = val
            binarize_rng = random.fold_in(rng, i)
            batch = train_fetch(i, batchifier_state, binarize_rng)[0]
            svi_state, batch_loss = svi.update(svi_state, batch)
            loss += batch_loss / num_batches
            return svi_state, loss

        svi_state, loss = lax.fori_loop(0, num_batches, body_fn, (svi_state, 0.))
        return svi_state, loss

    @jit
    def eval_test(svi_state, batchifier_state, num_batches, rng):
        """Evaluates current model state on test data.

        :param svi_state: current state of the optimizer
        :param rng: rng key
        :return: loss over the test split
        """
        def body_fn(i, loss_sum):
            binarize_rng = random.fold_in(rng, i)
            batch = test_fetch(i, batchifier_state, binarize_rng)[0]
            batch_loss = svi.evaluate(svi_state, batch)
            loss_sum += batch_loss / num_batches
            return loss_sum

        return lax.fori_loop(0, num_batches, body_fn, 0.)

    def reconstruct_img(epoch, num_epochs, batchifier_state, svi_state, rng):
        """Reconstructs an image for the given epoch

        Obtains a sample from the testing data set and passes it through the
        VAE. Stores the result as image file 'epoch_{epoch}_recons.png' and
        the original input as 'epoch_{epoch}_original.png' in folder '.results'.

        :param epoch: Number of the current epoch
        :param num_epochs: Number of total epochs
        :param batchifier_state: current state of the test batchifier
        :param svi_state: current state of the optimizer
        :param rng: rng key
        """
        assert num_epochs > 0
        img = test_fetch_plain(0, batchifier_state)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, "epoch_{:0{}d}_original.png".format(
            epoch, (int(jnp.log10(num_epochs)) + 1))), img, cmap='gray')

        rng, rng_binarize = random.split(rng, 2)
        test_sample = binarize(rng_binarize, img)
        test_sample = jnp.reshape(test_sample, (1, *jnp.shape(test_sample)))
        params = svi.get_params(svi_state)

        samples = sample_multi_posterior_predictive(
            rng, 10, model,
            (1, args.z_dim, args.hidden_dim, np.prod(test_sample.shape[1:])),
            guide, (test_sample, args.z_dim, args.hidden_dim), params)

        img_loc = samples['obs'][0].reshape([28, 28])
        avg_img_loc = jnp.mean(samples['obs'], axis=0).reshape([28, 28])

        plt.imsave(os.path.join(RESULTS_DIR, "epoch_{:0{}d}_recons_single.png".format(
            epoch, (int(jnp.log10(num_epochs)) + 1))), img_loc, cmap='gray')
        plt.imsave(os.path.join(RESULTS_DIR, "epoch_{:0{}d}_recons_avg.png".format(
            epoch, (int(jnp.log10(num_epochs)) + 1))), avg_img_loc, cmap='gray')

    # main training loop
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng, train_rng = random.split(rng, 3)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches, train_rng)

        rng, test_fetch_rng, test_rng, recons_rng = random.split(rng, 4)
        num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
        test_loss = eval_test(svi_state, test_batchifier_state,
                              num_test_batches, test_rng)
        reconstruct_img(i, args.num_epochs, test_batchifier_state, svi_state, recons_rng)
        print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
            i, test_loss, train_loss, time.time() - t_start))
def elbo_loss_fn(x):
    return ELBO().loss(random.PRNGKey(0), {}, model, guide, x)

def renyi_loss_fn(x):
    return RenyiELBO(alpha=alpha, num_particles=10).loss(
        random.PRNGKey(0), {}, model, guide, x)

elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.)
renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.)
assert_allclose(elbo_loss, renyi_loss, rtol=1e-6)
assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


@pytest.mark.parametrize('elbo', [
    ELBO(),
    RenyiELBO(num_particles=10),
])
def test_beta_bernoulli(elbo):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
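# Why the equality assertions above can hold for any alpha: if the guide has
# no latent sample sites, every importance weight equals the constant
# log p(obs), and the Renyi bound collapses to the plain ELBO. A standalone
# sketch of that edge case (model_/guide_ are illustrative stand-ins, not
# the fixtures used by the asserts above):
def _renyi_equals_elbo_sketch():
    def model_(x):
        numpyro.sample("obs", dist.Normal(0., 1.), obs=x)

    def guide_(x):
        pass  # no latent sites: importance weights are constant

    loss_elbo = ELBO().loss(random.PRNGKey(0), {}, model_, guide_, 2.)
    loss_renyi = RenyiELBO(alpha=2., num_particles=10).loss(
        random.PRNGKey(0), {}, model_, guide_, 2.)
    return loss_elbo, loss_renyi  # both equal -log N(2 | 0, 1)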
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)
    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions)
    train_init, train_fetch = subsample_batchify_data(train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data, batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model, guide, optimizer, ELBO(),
                dp_scale=0.01, clipping_threshold=20.,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch
            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val
            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch
            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)
            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch
            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10)) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(svi_state, test_batchifier_state,
                                            num_test_batches, test_rng)
            print("Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)"
                  .format(i, test_loss, test_acc, train_loss, t_end - t_start))

    # parameters for logistic regression may be scaled arbitrarily; normalize
    # w (and scale the intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true, jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data
    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # to evaluate accuracy with the true parameters, we scale them to the same
    # scale as the found posterior (this gives better results than the
    # normalized parameters, probably due to numerical instabilities)
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true,
                                              intercept_true, rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print("avg accuracy on test set: with true parameters: {} ; with found posterior: {}\n"
          .format(acc_true, acc_post))
def main(args):
    N = args.num_samples
    k = args.num_components
    d = args.dimensions

    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)

    X_train, X_test, latent_vals = create_toy_data(toy_data_rng, N, d)
    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): fix the parameters in the models
    def fix_params(model_fn, k):
        def fixed_params_fn(obs, **kwargs):
            return model_fn(k, obs, **kwargs)

        return fixed_params_fn

    model_fixed = fix_params(model, k)
    guide_fixed = fix_params(guide, k)

    svi = DPSVI(model_fixed, guide_fixed, optimizer, ELBO(),
                dp_scale=0.01, clipping_threshold=20.,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(fetch_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        # note: the batchifier state is passed in explicitly so each epoch
        # reads from the freshly initialized batchifier, not a stale closure
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if i % 100 == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state, num_test_batches)
            print("Epoch {}: loss = {} (on training set = {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    print(params)
    posterior_modes = params['mus_loc']
    posterior_pis = dist.Dirichlet(jnp.exp(params['alpha_log'])).mean
    print("MAP estimate of mixture weights: {}".format(posterior_pis))
    print("MAP estimate of mixture modes : {}".format(posterior_modes))

    acc = compute_assignment_accuracy(
        X_test, latent_vals[1], latent_vals[2], posterior_modes, posterior_pis)
    print("assignment accuracy: {}".format(acc))