def model(data, mask):
    with numpyro.plate("N", N):
        x = numpyro.sample("x", dist.Normal(0, 1))
        with handlers.mask(mask=mask):
            numpyro.sample("y", dist.Delta(x, log_density=1.0))
            with handlers.scale(scale=2):
                numpyro.sample("obs", dist.Normal(x, 1), obs=data)
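# A quick sketch of evaluating the model above with numpyro's log_density
# utility. The plate size N, the data, and the mask values are invented here;
# masked-out sites contribute zero log probability, and the scale handler
# doubles the contribution of "obs".
import jax.numpy as jnp
from jax import random
from numpyro import handlers
from numpyro.infer.util import log_density

N = 2  # hypothetical plate size, assumed visible to `model`
data = jnp.zeros(N)
mask = jnp.array([True, False])

# seed supplies the PRNG key needed to draw "x"; log_density returns
# (joint log probability, trace)
logp, _ = log_density(handlers.seed(model, random.PRNGKey(0)), (data, mask), {}, {})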
def scaled_loss(rng_key, classic_params, stein_params):
    params = {**classic_params, **stein_params}
    loss_val = self.loss.loss(
        rng_key,
        params,
        handlers.scale(self.model, self.loss_temperature),
        self.guide,
        *args,
        **kwargs,
        **self.static_kwargs,
    )
    return -loss_val
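# scaled_loss is a closure from a method body, so self, args, and kwargs are
# bound by the enclosing scope. The core idea, tempering an entire model by a
# constant via handlers.scale, can be sketched in isolation; every name below
# is made up for illustration.
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer.util import log_density

def toy_model(data):
    x = numpyro.sample("x", dist.Normal(0, 1))
    numpyro.sample("obs", dist.Normal(x, 1), obs=data)

temperature = 0.5  # hypothetical loss temperature
tempered = handlers.scale(toy_model, scale=temperature)

# the tempered joint log density equals temperature * (original log density)
logp, _ = log_density(handlers.seed(tempered, random.PRNGKey(0)), (1.0,), {}, {})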
def minibatch(batch_or_batchsize, num_obs_total=None):
    """Returns a context within which all samples are treated as being a
    minibatch of a larger data set.

    In essence, this marks the (log)likelihood of the sampled examples to be
    scaled to the total loss value over the whole data set.

    :param batch_or_batchsize: An integer indicating the batch size or an
        array indicating the shape of the batch where the length of the first
        axis is interpreted as batch size.
    :param num_obs_total: The total number of examples/observations in the
        full data set. Optional, defaults to the given batch size.
    """
    if is_int_scalar(batch_or_batchsize):
        if not jnp.isscalar(batch_or_batchsize):
            raise TypeError(
                "if a scalar is given for batch_or_batchsize, it "
                "can't be traced through jit. consider using static_argnums "
                "for the jit invocation."
            )
        batch_size = batch_or_batchsize
    elif is_array(batch_or_batchsize):
        batch_size = example_count(batch_or_batchsize)
    else:
        raise TypeError("batch_or_batchsize must be an array or an integer")
    if num_obs_total is None:
        num_obs_total = batch_size
    return scale(scale=num_obs_total / batch_size)
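# A hedged usage sketch for minibatch inside a numpyro model, in the style of
# d3p's examples; the model and its data are invented.
import numpyro
import numpyro.distributions as dist

def model(batch_x, num_obs_total=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    # inside the context, the batch likelihood is scaled by
    # num_obs_total / batch_size, standing in for the full data set
    with minibatch(batch_x, num_obs_total=num_obs_total):
        numpyro.sample("x", dist.Normal(mu, 1.0), obs=batch_x)

# the model is then handed to SVI together with num_obs_total, e.g.
# svi = SVI(model, guide, optimizer, ELBO(), num_obs_total=num_samples)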
def test_scale(use_context_manager):
    def model(data):
        x = numpyro.sample("x", dist.Normal(0, 1))
        with optional(use_context_manager, handlers.scale(scale=10)):
            numpyro.sample("obs", dist.Normal(x, 1), obs=data)

    model = model if use_context_manager else handlers.scale(model, 10.0)
    data = random.normal(random.PRNGKey(0), (3,))
    x = random.normal(random.PRNGKey(1))
    log_joint = log_density(model, (data,), {}, {"x": x})[0]
    log_prob1 = dist.Normal(0, 1).log_prob(x)
    log_prob2 = dist.Normal(x, 1).log_prob(data).sum()
    expected = (
        log_prob1 + 10 * log_prob2
        if use_context_manager
        else 10 * (log_prob1 + log_prob2)
    )
    assert_allclose(log_joint, expected)
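# Note the two expected values: as a context manager, only the "obs" site sits
# inside the scaled region, so only its log probability is multiplied by 10;
# wrapping the whole model with handlers.scale scales every site, giving
# 10 * (log_prob1 + log_prob2).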
def test_subsample_gradient(scale, subsample):
    data = jnp.array([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale

    def model(subsample):
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size) as ind:
                x = data[ind]
                z = numpyro.sample("z", dist.Normal(0, 1))
                numpyro.sample("x", dist.Normal(z, 1), obs=x)

    def guide(subsample):
        scale = numpyro.param("scale", 1.0)
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size):
                loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0)
                numpyro.sample("z", dist.Normal(loc, scale))

    if scale != 1.0:
        model = handlers.scale(model, scale=scale)
        guide = handlers.scale(guide, scale=scale)

    num_particles = 50000
    optimizer = optim.Adam(0.1)
    elbo = Trace_ELBO(num_particles=num_particles)
    svi = SVI(model, guide, optimizer, loss=elbo)
    svi_state = svi.init(random.PRNGKey(0), None)
    params = svi.optim.get_params(svi_state.optim_state)
    normalizer = 2 if subsample else 1
    if subsample_size == 1:
        subsample = jnp.array([0])
        loss1, grads1 = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
        subsample = jnp.array([1])
        loss2, grads2 = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
        grads = tree_multimap(lambda *vals: vals[0] + vals[1], grads1, grads2)
        loss = loss1 + loss2
    else:
        subsample = jnp.array([0, 1])
        loss, grads = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
    actual_loss = loss / normalizer
    expected_loss, _ = value_and_grad(
        lambda x: svi.loss.loss(
            svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, None
        )
    )(params)
    assert_allclose(actual_loss, expected_loss, rtol=precision, atol=precision)
    actual_grads = {name: grad / normalizer for name, grad in grads.items()}
    expected_grads = {
        "loc": scale * jnp.array([0.5, -2.0]),
        "scale": scale * jnp.array([2.0]),
    }
    assert actual_grads.keys() == expected_grads.keys()
    for name in expected_grads:
        assert_allclose(
            actual_grads[name], expected_grads[name], rtol=precision, atol=precision
        )
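# In the subsampled branch the losses and gradients of the two complementary
# single-element minibatches are summed, so dividing by normalizer = 2 makes
# them comparable to the full-data values. The test thus checks both that
# plate subsampling is unbiased and that wrapping model and guide in
# handlers.scale multiplies the gradients by `scale`.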
def model(data):
    x = sample("x", dist.Normal(0, 1))
    with scale(scale=10):
        sample("obs", dist.Normal(x, 1), obs=data)
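# A quick check of the model above, assuming sample comes from numpyro and
# scale from numpyro.handlers; the data and latent values are invented.
import numpyro.distributions as dist
from numpyro import sample
from numpyro.handlers import scale
from numpyro.infer.util import log_density

data, x = 0.5, 0.3
logp, _ = log_density(model, (data,), {}, {"x": x})
expected = dist.Normal(0, 1).log_prob(x) + 10 * dist.Normal(x, 1).log_prob(data)
assert abs(logp - expected) < 1e-6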
def main(args):
    # loading data
    (train_init, train_fetch_plain), num_samples = load_dataset(
        MNIST,
        batch_size=args.batch_size,
        split='train',
        batchifier=subsample_batchify_data,
    )
    (test_init, test_fetch_plain), _ = load_dataset(
        MNIST,
        batch_size=args.batch_size,
        split='test',
        batchifier=split_batchify_data,
    )

    def binarize_fetch(fetch_fn):
        @jit
        def fetch_binarized(batch_nr, batchifier_state, binarize_rng):
            batch = fetch_fn(batch_nr, batchifier_state)
            return binarize(binarize_rng, batch[0]), batch[1]

        return fetch_binarized

    train_fetch = binarize_fetch(train_fetch_plain)
    test_fetch = binarize_fetch(test_fetch_plain)

    # setting up optimizer
    optimizer = optimizers.Adam(args.learning_rate)

    # The minibatch environment in our model scales individual records'
    # contributions to the loss up by num_samples. This can cause numerical
    # instabilities, so we scale the loss down by 1/num_samples here.
    sc_model = scale(model, scale=1 / num_samples)
    sc_guide = scale(guide, scale=1 / num_samples)

    if args.no_dp:
        svi = SVI(sc_model, sc_guide, optimizer, ELBO(),
                  num_obs_total=num_samples, z_dim=args.z_dim,
                  hidden_dim=args.hidden_dim)
    else:
        q = args.batch_size / num_samples
        target_eps = args.epsilon
        dp_scale, act_eps, _ = approximate_sigma(
            target_eps=target_eps,
            delta=1 / num_samples,
            q=q,
            num_iter=int(1 / q) * args.num_epochs,
            force_smaller=True,
        )
        print(f"using noise scale {dp_scale} for epsilon of {act_eps} "
              f"(targeted: {target_eps})")
        svi = DPSVI(sc_model, sc_guide, optimizer, ELBO(),
                    dp_scale=dp_scale, clipping_threshold=10.,
                    num_obs_total=num_samples, z_dim=args.z_dim,
                    hidden_dim=args.hidden_dim)

    # preparing random number generators and initializing svi
    rng = PRNGKey(0)
    rng, binarize_rng, svi_init_rng, batchifier_rng = random.split(rng, 4)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    sample_batch = train_fetch(0, batchifier_state, binarize_rng)[0]
    svi_state = svi.init(svi_init_rng, sample_batch)

    # functions for training tasks
    @jit
    def epoch_train(svi_state, batchifier_state, num_batches, rng):
        """Trains one epoch.

        :param svi_state: current state of SVI
        :param batchifier_state: state of the batchifier for the training split
        :param num_batches: number of batches in the epoch
        :param rng: rng key
        :return: updated SVI state and overall training loss over the epoch
        """
        def body_fn(i, val):
            svi_state, loss = val
            binarize_rng = random.fold_in(rng, i)
            batch = train_fetch(i, batchifier_state, binarize_rng)[0]
            svi_state, batch_loss = svi.update(svi_state, batch)
            loss += batch_loss / num_batches
            return svi_state, loss

        svi_state, loss = lax.fori_loop(0, num_batches, body_fn, (svi_state, 0.))
        return svi_state, loss

    @jit
    def eval_test(svi_state, batchifier_state, num_batches, rng):
        """Evaluates the current model state on test data.

        :param svi_state: current state of SVI
        :param batchifier_state: state of the batchifier for the test split
        :param num_batches: number of batches in the test split
        :param rng: rng key
        :return: loss over the test split
        """
        def body_fn(i, loss_sum):
            binarize_rng = random.fold_in(rng, i)
            batch = test_fetch(i, batchifier_state, binarize_rng)[0]
            batch_loss = svi.evaluate(svi_state, batch)
            loss_sum += batch_loss / num_batches
            return loss_sum

        return lax.fori_loop(0, num_batches, body_fn, 0.)

    def reconstruct_img(epoch, num_epochs, batchifier_state, svi_state, rng):
        """Reconstructs an image for the given epoch.

        Obtains a sample from the test data set and passes it through the VAE.
        Stores single and averaged reconstructions as image files
        'epoch_{epoch}_recons_single.png' and 'epoch_{epoch}_recons_avg.png'
        and the original input as 'epoch_{epoch}_original.png' in the folder
        given by RESULTS_DIR.

        :param epoch: number of the current epoch
        :param num_epochs: total number of epochs
        :param batchifier_state: state of the batchifier for the test split
        :param svi_state: current state of SVI
        :param rng: rng key
        """
        assert num_epochs > 0
        img = test_fetch_plain(0, batchifier_state)[0][0]
        epoch_digits = int(jnp.log10(num_epochs)) + 1
        plt.imsave(
            os.path.join(RESULTS_DIR,
                         "epoch_{:0{}d}_original.png".format(epoch, epoch_digits)),
            img, cmap='gray')
        rng, rng_binarize = random.split(rng, 2)
        test_sample = binarize(rng_binarize, img)
        test_sample = jnp.reshape(test_sample, (1, *jnp.shape(test_sample)))
        params = svi.get_params(svi_state)
        samples = sample_multi_posterior_predictive(
            rng, 10, model,
            (1, args.z_dim, args.hidden_dim, np.prod(test_sample.shape[1:])),
            guide, (test_sample, args.z_dim, args.hidden_dim), params)
        img_loc = samples['obs'][0].reshape([28, 28])
        avg_img_loc = jnp.mean(samples['obs'], axis=0).reshape([28, 28])
        plt.imsave(
            os.path.join(RESULTS_DIR,
                         "epoch_{:0{}d}_recons_single.png".format(epoch, epoch_digits)),
            img_loc, cmap='gray')
        plt.imsave(
            os.path.join(RESULTS_DIR,
                         "epoch_{:0{}d}_recons_avg.png".format(epoch, epoch_digits)),
            avg_img_loc, cmap='gray')

    # main training loop
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng, train_rng = random.split(rng, 3)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches, train_rng)

        rng, test_fetch_rng, test_rng, recons_rng = random.split(rng, 4)
        num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
        test_loss = eval_test(
            svi_state, test_batchifier_state, num_test_batches, test_rng)
        reconstruct_img(i, args.num_epochs, test_batchifier_state, svi_state,
                        recons_rng)

        print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
            i, test_loss, train_loss, time.time() - t_start))
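# A typical entry point for running this example. The argument names below are
# inferred from the attributes main reads off `args`; all default values are
# placeholders, not the example's originals.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="DP VAE example")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--z_dim", type=int, default=50)
    parser.add_argument("--hidden_dim", type=int, default=400)
    parser.add_argument("--epsilon", type=float, default=1.0)
    parser.add_argument("--no_dp", action="store_true")
    main(parser.parse_args())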