def train_model(rng, rng_suite, model, guide, data, batch_size, num_data,
                dp_scale, num_epochs, clipping_threshold=1.):
    """ Trains a given model using DPSVI and the globally defined parameters and data. """
    optimizer = Adam(1e-3)

    svi = DPSVI(
        model, guide, optimizer, Trace_ELBO(),
        num_obs_total=num_data,
        clipping_threshold=clipping_threshold,
        dp_scale=dp_scale,
        rng_suite=rng_suite
    )

    return _train_model(rng, rng_suite, svi, data, batch_size, num_data, num_epochs)
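# A hypothetical usage sketch (not part of this file): `model`, `guide` and
# `data` are assumed to be defined by the surrounding benchmark script, and
# d3p's `d3p.random` module is assumed as the secure rng suite; all concrete
# values below are illustrative only.
#
#     import d3p.random as rng_suite
#
#     rng = rng_suite.PRNGKey(0)
#     result = train_model(
#         rng, rng_suite, model, guide, data,
#         batch_size=128, num_data=len(data),
#         dp_scale=1.0, num_epochs=10, clipping_threshold=1.0,
#     )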
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)

    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions
    )
    train_init, train_fetch = subsample_batchify_data(
        train_data, batch_size=args.batch_size
    )
    test_init, test_fetch = split_batchify_data(test_data, batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(
        model, guide, optimizer, ELBO(),
        dp_scale=0.01, clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val
            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)

            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch

            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10)) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(
                svi_state, test_batchifier_state, num_test_batches, test_rng
            )

            print(
                "Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)"
                .format(i, test_loss, test_acc, train_loss, t_end - t_start)
            )

    # parameters for logistic regression may be scaled arbitrarily. normalize
    # w (and scale intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true, jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data

    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # for evaluating accuracy with the true parameters, we scale them to the
    # same scale as the found posterior
    # (gives better results than normalized parameters,
    #  probably due to numerical instabilities)
    acc_true = estimate_accuracy_fixed_params(
        X_test, y_test, w_true, intercept_true, rng_acc_true, 10
    )
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print(
        "avg accuracy on test set: with true parameters: {} ; with found posterior: {}\n"
        .format(acc_true, acc_post)
    )
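# `normalize` is used above but not defined in this excerpt; it is assumed to
# be provided elsewhere in the example script. A minimal sketch consistent
# with its usage here (unit-norm scaling of the weight vector) could be:
def normalize(x):
    return x / jnp.linalg.norm(x)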
def main(args):
    N = args.num_samples
    k = args.num_components
    d = args.dimensions

    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)

    X_train, X_test, latent_vals = create_toy_data(toy_data_rng, N, d)
    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): fix the parameters in the models
    def fix_params(model_fn, k):
        def fixed_params_fn(obs, **kwargs):
            return model_fn(k, obs, **kwargs)
        return fixed_params_fn

    model_fixed = fix_params(model, k)
    guide_fixed = fix_params(guide, k)

    svi = DPSVI(
        model_fixed, guide_fixed, optimizer, ELBO(),
        dp_scale=0.01, clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(fetch_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()
        t_end = time.time()

        if i % 100 == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state, num_test_batches)
            print("Epoch {}: loss = {} (on training set = {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start
            ))

    params = svi.get_params(svi_state)
    print(params)
    posterior_modes = params['mus_loc']
    posterior_pis = dist.Dirichlet(jnp.exp(params['alpha_log'])).mean
    print("MAP estimate of mixture weights: {}".format(posterior_pis))
    print("MAP estimate of mixture modes : {}".format(posterior_modes))

    acc = compute_assignment_accuracy(
        X_test, latent_vals[1], latent_vals[2], posterior_modes, posterior_pis
    )
    print("assignment accuracy: {}".format(acc))
def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)

    X_train, X_test, mu_true = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions
    )
    train_init, train_fetch = subsample_batchify_data(
        (X_train,), batch_size=args.batch_size
    )
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(
        model, guide, optimizer, ELBO(),
        dp_scale=args.sigma, clipping_threshold=args.clip_threshold,
        d=args.dimensions, num_obs_total=args.num_samples
    )

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.delta, args.clip_threshold, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()
        t_end = time.time()

        if i % (args.num_epochs // 10) == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state, num_test_batches)
            print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])

    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
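# `analytical_solution` and `ml_estimate` are defined elsewhere in the example.
# As a hedged sketch only: if the model placed a N(0, I) prior on `mu` with
# unit observation noise (the actual example may use different
# hyperparameters), the exact conjugate posterior would be
# mu | X ~ N(sum(X) / (n + 1), 1 / (n + 1) * I), i.e. roughly:
def analytical_solution_sketch(X):
    n = X.shape[0]
    mu_loc = jnp.sum(X, axis=0) / (n + 1)   # posterior mean
    mu_std = jnp.sqrt(1.0 / (n + 1))        # posterior standard deviation
    return mu_loc, mu_std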
def main(args):
    # loading data
    (train_init, train_fetch_plain), num_samples = load_dataset(
        MNIST, batch_size=args.batch_size, split='train',
        batchifier=subsample_batchify_data
    )
    (test_init, test_fetch_plain), _ = load_dataset(
        MNIST, batch_size=args.batch_size, split='test',
        batchifier=split_batchify_data
    )

    def binarize_fetch(fetch_fn):
        @jit
        def fetch_binarized(batch_nr, batchifier_state, binarize_rng):
            batch = fetch_fn(batch_nr, batchifier_state)
            return binarize(binarize_rng, batch[0]), batch[1]
        return fetch_binarized

    train_fetch = binarize_fetch(train_fetch_plain)
    test_fetch = binarize_fetch(test_fetch_plain)

    # setting up optimizer
    optimizer = optimizers.Adam(args.learning_rate)

    # the minibatch environment in our model scales individual records'
    # contributions to the loss up by num_samples. This can cause numerical
    # instabilities, so we scale down the loss by 1/num_samples here.
    sc_model = scale(model, scale=1 / num_samples)
    sc_guide = scale(guide, scale=1 / num_samples)

    if args.no_dp:
        svi = SVI(sc_model, sc_guide, optimizer, ELBO(),
                  num_obs_total=num_samples, z_dim=args.z_dim,
                  hidden_dim=args.hidden_dim)
    else:
        q = args.batch_size / num_samples
        target_eps = args.epsilon
        dp_scale, act_eps, _ = approximate_sigma(
            target_eps=target_eps, delta=1 / num_samples, q=q,
            num_iter=int(1 / q) * args.num_epochs, force_smaller=True
        )
        print(f"using noise scale {dp_scale} for epsilon of {act_eps} (targeted: {target_eps})")

        svi = DPSVI(sc_model, sc_guide, optimizer, ELBO(),
                    dp_scale=dp_scale, clipping_threshold=10.,
                    num_obs_total=num_samples, z_dim=args.z_dim,
                    hidden_dim=args.hidden_dim)

    # preparing random number generators and initializing svi
    rng = PRNGKey(0)
    rng, binarize_rng, svi_init_rng, batchifier_rng = random.split(rng, 4)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    sample_batch = train_fetch(0, batchifier_state, binarize_rng)[0]
    svi_state = svi.init(svi_init_rng, sample_batch)

    # functions for training tasks
    @jit
    def epoch_train(svi_state, batchifier_state, num_batches, rng):
        """Trains one epoch

        :param svi_state: current state of the optimizer
        :param rng: rng key
        :return: overall training loss over the epoch
        """
        def body_fn(i, val):
            svi_state, loss = val
            binarize_rng = random.fold_in(rng, i)
            batch = train_fetch(i, batchifier_state, binarize_rng)[0]
            svi_state, batch_loss = svi.update(svi_state, batch)
            loss += batch_loss / num_batches
            return svi_state, loss

        svi_state, loss = lax.fori_loop(0, num_batches, body_fn, (svi_state, 0.))
        return svi_state, loss

    @jit
    def eval_test(svi_state, batchifier_state, num_batches, rng):
        """Evaluates current model state on test data.

        :param svi_state: current state of the optimizer
        :param rng: rng key
        :return: loss over the test split
        """
        def body_fn(i, loss_sum):
            binarize_rng = random.fold_in(rng, i)
            batch = test_fetch(i, batchifier_state, binarize_rng)[0]
            batch_loss = svi.evaluate(svi_state, batch)
            loss_sum += batch_loss / num_batches
            return loss_sum

        return lax.fori_loop(0, num_batches, body_fn, 0.)

    def reconstruct_img(epoch, num_epochs, batchifier_state, svi_state, rng):
        """Reconstructs an image for the given epoch

        Obtains a sample from the testing data set and passes it through the
        VAE. Stores the reconstructions as image files
        'epoch_{epoch}_recons_single.png' and 'epoch_{epoch}_recons_avg.png'
        and the original input as 'epoch_{epoch}_original.png' in the results
        folder.

        :param epoch: Number of the current epoch
        :param num_epochs: Number of total epochs
        :param batchifier_state: Current batchifier state for the test data
        :param svi_state: Current state of the SVI
        :param rng: rng key
        """
        assert num_epochs > 0
        img = test_fetch_plain(0, batchifier_state)[0][0]
        plt.imsave(
            os.path.join(
                RESULTS_DIR,
                "epoch_{:0{}d}_original.png".format(epoch, int(jnp.log10(num_epochs)) + 1)
            ),
            img, cmap='gray'
        )

        rng, rng_binarize = random.split(rng, 2)
        test_sample = binarize(rng_binarize, img)
        test_sample = jnp.reshape(test_sample, (1, *jnp.shape(test_sample)))
        params = svi.get_params(svi_state)

        samples = sample_multi_posterior_predictive(
            rng, 10, model,
            (1, args.z_dim, args.hidden_dim, np.prod(test_sample.shape[1:])),
            guide, (test_sample, args.z_dim, args.hidden_dim), params
        )

        img_loc = samples['obs'][0].reshape([28, 28])
        avg_img_loc = jnp.mean(samples['obs'], axis=0).reshape([28, 28])

        plt.imsave(
            os.path.join(
                RESULTS_DIR,
                "epoch_{:0{}d}_recons_single.png".format(epoch, int(jnp.log10(num_epochs)) + 1)
            ),
            img_loc, cmap='gray'
        )
        plt.imsave(
            os.path.join(
                RESULTS_DIR,
                "epoch_{:0{}d}_recons_avg.png".format(epoch, int(jnp.log10(num_epochs)) + 1)
            ),
            avg_img_loc, cmap='gray'
        )

    # main training loop
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng, train_rng = random.split(rng, 3)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches, train_rng
        )

        rng, test_fetch_rng, test_rng, recons_rng = random.split(rng, 4)
        num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
        test_loss = eval_test(
            svi_state, test_batchifier_state, num_test_batches, test_rng
        )

        reconstruct_img(i, args.num_epochs, test_batchifier_state, svi_state, recons_rng)
        print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
            i, test_loss, train_loss, time.time() - t_start))
class DPSVITest(unittest.TestCase):

    def setUp(self):
        self.rng = jax.random.PRNGKey(9782346)
        self.batch_size = 10
        self.num_obs_total = 100
        self.px_grads = ((
            jnp.zeros((self.batch_size, 10000)),
            jnp.zeros((self.batch_size, 10000))
        ))
        self.px_grads_list, self.tree_def = jax.tree_flatten(self.px_grads)
        self.px_loss = jnp.arange(self.batch_size, dtype=jnp.float32)
        self.dp_scale = 1.
        self.clipping_threshold = 2.
        self.svi = DPSVI(
            None, None, None, None, self.clipping_threshold, self.dp_scale,
            num_obs_total=self.num_obs_total
        )

    def test_px_gradient_aggregation(self):
        np.random.seed(0)
        px_grads_list, _ = jax.tree_flatten((
            np.random.normal(1, 1, size=(self.batch_size, 10000)),
            np.random.normal(1, 1, size=(self.batch_size, 10000))
        ))

        expected_grads_list = [jnp.mean(px_grads, axis=0) for px_grads in px_grads_list]
        expected_loss = jnp.mean(self.px_loss)

        loss, grads_list = self.svi._combine_gradients(px_grads_list, self.px_loss)

        self.assertTrue(
            jnp.allclose(expected_loss, loss),
            f"expected loss {expected_loss} but was {loss}"
        )
        self.assertTrue(
            jnp.allclose(expected_grads_list, grads_list),
            f"expected gradients {expected_grads_list} but was {grads_list}"
        )

    def test_dp_noise_perturbation(self):
        svi_state = DPSVIState(None, self.rng, .3)
        grads_list = [jnp.mean(px_grads, axis=0) for px_grads in self.px_grads_list]

        new_svi_state, grads = self.svi._perturb_and_reassemble_gradients(
            svi_state, grads_list, self.batch_size, self.tree_def
        )

        self.assertIs(svi_state.optim_state, new_svi_state.optim_state)
        self.assertFalse(jnp.allclose(svi_state.rng_key, new_svi_state.rng_key))
        self.assertEqual(self.tree_def, jax.tree_structure(grads))

        expected_std = (
            self.dp_scale * self.clipping_threshold / self.batch_size
        ) * svi_state.observation_scale
        for site, px_site in zip(jax.tree_leaves(grads), jax.tree_leaves(self.px_grads_list)):
            self.assertEqual(px_site.shape[1:], site.shape)
            self.assertTrue(
                jnp.allclose(expected_std, jnp.std(site), atol=1e-2),
                f"expected stdev {expected_std} but was {jnp.std(site)}"
            )
            self.assertTrue(jnp.allclose(0., jnp.mean(site), atol=1e-2))

    def test_dp_noise_perturbation_not_deterministic_over_calls(self):
        """ verifies that different randomness is used in subsequent calls """
        svi_state = DPSVIState(None, self.rng, .3)
        grads_list = [jnp.mean(px_grads, axis=0) for px_grads in self.px_grads_list]

        new_svi_state, first_grads = self.svi._perturb_and_reassemble_gradients(
            svi_state, grads_list, self.batch_size, self.tree_def
        )
        _, second_grads = self.svi._perturb_and_reassemble_gradients(
            new_svi_state, grads_list, self.batch_size, self.tree_def
        )

        some_gradient_noise_is_equal = reduce(
            lambda are_equal, acc: are_equal or acc,
            jax.tree_leaves(
                jax.tree_multimap(
                    lambda x, y: jnp.allclose(x, y), first_grads, second_grads
                )
            )
        )
        self.assertFalse(some_gradient_noise_is_equal)

    def test_dp_noise_perturbation_not_deterministic_over_sites(self):
        """ verifies that different randomness is used for different gradient sites """
        svi_state = DPSVIState(None, self.rng, .3)
        grads_list = [jnp.mean(px_grads, axis=0) for px_grads in self.px_grads_list]

        _, grads = self.svi._perturb_and_reassemble_gradients(
            svi_state, grads_list, self.batch_size, self.tree_def
        )

        noise_sites = jax.tree_leaves(grads)
        self.assertFalse(jnp.allclose(noise_sites[0], noise_sites[1]))
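# Background for the expected_std used in test_dp_noise_perturbation (a
# sketch of the reasoning, not a statement about d3p internals beyond what
# the test itself exercises): with per-example gradients clipped to norm
# `clipping_threshold` and averaged over a batch of size `batch_size`, the
# Gaussian mechanism adds noise with per-coordinate standard deviation
# dp_scale * clipping_threshold / batch_size to the averaged gradient, which
# is further multiplied by the observation scale stored in the DPSVIState.
# With the values from setUp:
#     (dp_scale * clipping_threshold / batch_size) * observation_scale
#   = (1. * 2. / 10.) * 0.3 = 0.06
# which is the standard deviation the test checks against.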