def test_split_batchify_fetch(self):
    data = np.arange(105) + 100
    init, fetch = split_batchify_data((data,), 10)
    batchifier_state = jax.random.permutation(jax.random.PRNGKey(0), jnp.arange(0, 105))
    num_batches = 10

    counts = np.zeros(105)
    for i in range(num_batches):
        batch = fetch(i, batchifier_state)
        batch = batch[0]
        unq_idxs, unq_counts = np.unique(batch, return_counts=True)
        counts[unq_idxs - 100] = unq_counts
        # ensure each item occurs at most once in the batch
        self.assertTrue(np.alltrue(unq_counts <= 1))
        # ensure batch was plausibly drawn from data
        self.assertTrue(np.alltrue(batch >= 100) and np.alltrue(batch < 205))

    # ensure each item occurs at most once in the epoch
    self.assertTrue(np.alltrue(counts <= 1))
    # ensure that the elements across batches cover an epoch's worth of data
    self.assertEqual(100, np.sum(counts))
def test_split_batchify_fetch_correct_shape(self):
    data = np.random.normal(size=(105, 3))
    init, fetch = split_batchify_data((data,), 10)
    batchifier_state = jax.random.permutation(jax.random.PRNGKey(0), jnp.arange(0, 105))
    batch = fetch(6, batchifier_state)
    batch = batch[0]
    self.assertEqual((10, 3), jnp.shape(batch))
def test_split_batchify_batches_differ(self):
    data = np.arange(105) + 100
    init, fetch = split_batchify_data((data,), 10)
    num_batches, batchifier_state = init(jax.random.PRNGKey(10))
    batch_0 = fetch(3, batchifier_state)
    batch_1 = fetch(8, batchifier_state)
    # ensure batches are different
    self.assertFalse(jnp.allclose(batch_0, batch_1))
def test_split_batchify_init_non_divisible_size(self):
    data = jnp.arange(0, 105)
    init, fetch = split_batchify_data((data,), 10)
    rng_key = jax.random.PRNGKey(0)
    num_batches, batchifier_state = init(rng_key)
    self.assertEqual(10, num_batches)
    # ensure each data index appears at most once in the batchifier state
    self.assertTrue(np.alltrue(np.unique(batchifier_state, return_counts=True)[1] < 2))
def test_split_batchify_init(self):
    data = jnp.arange(0, 100)
    init, fetch = split_batchify_data((data,), 10)
    rng_key = jax.random.PRNGKey(0)
    num_batches, batchifier_state = init(rng_key)
    self.assertEqual(10, num_batches)
    self.assertEqual(jnp.size(data), jnp.size(batchifier_state))
    self.assertTrue(np.allclose(np.unique(batchifier_state), data))
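# A minimal usage sketch of the batchifier protocol exercised by the tests
# above, assuming the (init, fetch) pair returned by split_batchify_data
# behaves as the tests suggest: init(rng_key) -> (num_batches, batchifier_state)
# and fetch(i, batchifier_state) -> tuple of batch arrays. The function and
# variable names below are illustrative only.
def _batchify_usage_sketch(rng_key, data, batch_size):
    init, fetch = split_batchify_data((data,), batch_size)
    num_batches, batchifier_state = init(rng_key)
    for i in range(num_batches):
        batch, = fetch(i, batchifier_state)  # unpack the single-array tuple
        # ... consume `batch` of shape (batch_size, ...) here ...
    return num_batches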
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)
    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data, batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model, guide, optimizer, ELBO(),
                dp_scale=0.01, clipping_threshold=20.,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val
            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)

            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch

            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)

        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state, num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10)) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(svi_state, test_batchifier_state, num_test_batches, test_rng)

            print("Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)".format(
                i, test_loss, test_acc, train_loss, t_end - t_start))

    # parameters for logistic regression may be scaled arbitrarily. normalize
    # w (and scale intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)

    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true, jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data

    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # for evaluating accuracy with the true parameters, we scale them to the
    # same scale as the found posterior (this gives better results than the
    # normalized parameters, probably due to numerical instabilities)
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true, intercept_true, rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print("avg accuracy on test set: with true parameters: {} ; with found posterior: {}\n".format(
        acc_true, acc_post))
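# The `normalize` helper used above is not defined in this excerpt. A minimal
# sketch of what it plausibly does (an assumption, not confirmed by the
# excerpt): rescale a weight vector to unit L2 norm so that logistic
# regression parameter directions can be compared independently of their
# arbitrary scale.
def _normalize_sketch(w):
    return w / jnp.linalg.norm(w)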
def main(args):
    N = args.num_samples
    k = args.num_components
    d = args.dimensions

    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)

    X_train, X_test, latent_vals = create_toy_data(toy_data_rng, N, d)
    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): fix the parameters in the models
    def fix_params(model_fn, k):
        def fixed_params_fn(obs, **kwargs):
            return model_fn(k, obs, **kwargs)
        return fixed_params_fn

    model_fixed = fix_params(model, k)
    guide_fixed = fix_params(guide, k)

    svi = DPSVI(
        model_fixed, guide_fixed, optimizer, ELBO(),
        dp_scale=0.01, clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(fetch_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)

        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()
        t_end = time.time()

        if i % 100 == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(
                svi_state, test_batchifier_state, num_test_batches
            )
            print("Epoch {}: loss = {} (on training set = {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start
            ))

    params = svi.get_params(svi_state)
    print(params)
    posterior_modes = params['mus_loc']
    posterior_pis = dist.Dirichlet(jnp.exp(params['alpha_log'])).mean
    print("MAP estimate of mixture weights: {}".format(posterior_pis))
    print("MAP estimate of mixture modes : {}".format(posterior_modes))

    acc = compute_assignment_accuracy(
        X_test, latent_vals[1], latent_vals[2], posterior_modes, posterior_pis
    )
    print("assignment accuracy: {}".format(acc))
def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        (X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model, guide, optimizer, ELBO(),
                dp_scale=args.sigma, clipping_threshold=args.clip_threshold,
                d=args.dimensions, num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.delta, args.clip_threshold, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)

        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state, num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if i % (args.num_epochs // 10) == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state, num_test_batches)
            print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])

    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
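# `analytical_solution` and `ml_estimate` are not shown in this excerpt. A
# hedged sketch of plausible implementations, assuming a conjugate setup with
# a standard normal prior on `mu` and unit observation noise (an assumption
# suggested, but not confirmed, by the comparison printed above):
def _analytical_solution_sketch(X):
    N = X.shape[0]
    # posterior precision = prior precision (1) + one unit of precision per observation
    mu_loc = jnp.sum(X, axis=0) / (1. + N)
    mu_std = 1. / jnp.sqrt(1. + N)
    return mu_loc, mu_std

def _ml_estimate_sketch(X):
    N = X.shape[0]
    # maximum-likelihood mean and its standard error under unit observation noise
    return jnp.mean(X, axis=0), 1. / jnp.sqrt(N)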