def test_subsample_batchify_fetch_correct_shape_without_replacement(self):
    """A fetched batch has shape (batch_size, feature_dim) when sampling without replacement."""
    data = np.random.normal(size=(105, 3))
    _, fetch = subsample_batchify_data((data, ), 10)
    batchifier_state = jax.random.PRNGKey(2)  # num_batches = 10
    (batch, ) = fetch(6, batchifier_state)
    self.assertEqual((10, 3), jnp.shape(batch))
def test_subsample_batchify_init_non_divisiable_size(self):
    """init rounds the number of batches down when the data size is not a multiple of batch size."""
    data = jnp.arange(0, 105)
    init, _ = subsample_batchify_data((data, ), 10)
    rng_key = jax.random.PRNGKey(0)
    num_batches, batchifier_state = init(rng_key)
    # 105 elements with batch size 10 -> 10 full batches, remainder dropped
    self.assertEqual(10, num_batches)
    self.assertTrue(jnp.allclose(rng_key, batchifier_state))
def test_subsample_batchify_fetch_batches_differ_without_replacement(self):
    """Different batch indices yield different subsamples from the same batchifier state."""
    data = np.arange(105) + 100
    _, fetch = subsample_batchify_data((data, ), 10)
    batchifier_state = jax.random.PRNGKey(2)  # num_batches = 10
    first = fetch(3, batchifier_state)
    second = fetch(8, batchifier_state)
    # ensure batches are different
    self.assertFalse(jnp.allclose(first, second))
def test_subsample_batchify_fetch_with_replacement(self):
    """Batches drawn with replacement contain only values from the data range.

    Data is np.arange(105) + 100, i.e. values in [100, 205); any element
    outside that range would indicate the batchifier sampled out of bounds.
    """
    data = np.arange(105) + 100
    init, fetch = subsample_batchify_data((data, ), 10, with_replacement=True)
    batchifier_state = jax.random.PRNGKey(2)
    num_batches = 10
    for i in range(num_batches):
        batch = fetch(i, batchifier_state)
        batch = batch[0]
        # np.alltrue was deprecated in NumPy 1.25 and removed in 2.0; np.all
        # is the supported spelling with identical semantics here.
        self.assertTrue(
            np.all(batch >= 100) and np.all(
                batch < 205))  # ensure batch was plausibly drawn from data
def test_subsample_batchify_fetch_without_replacement(self):
    """Batches drawn without replacement have no duplicates and stay in the data range.

    Data is np.arange(105) + 100, i.e. distinct values in [100, 205), so
    sampling without replacement must produce each value at most once.
    """
    data = np.arange(105) + 100
    init, fetch = subsample_batchify_data((data, ), 10)
    batchifier_state = jax.random.PRNGKey(2)
    num_batches = 10
    for i in range(num_batches):
        batch = fetch(i, batchifier_state)
        batch = batch[0]
        _, unq_counts = np.unique(batch, return_counts=True)
        # np.alltrue was deprecated in NumPy 1.25 and removed in 2.0; np.all
        # is the supported spelling with identical semantics here.
        self.assertTrue(
            np.all(unq_counts <= 1)
        )  # ensure each item occurs at most once in the batch
        self.assertTrue(
            np.all(batch >= 100) and np.all(
                batch < 205))  # ensure batch was plausibly drawn from data
def _train_model(rng, rng_suite, svi, data, batch_size, num_data, num_epochs, silent=False):
    """Run minibatch SVI training for ``num_epochs`` epochs.

    Args:
        rng: PRNG key, consumed via ``rng_suite``.
        rng_suite: PRNG suite providing ``split`` and ``fold_in``.
        svi: SVI instance exposing ``init``, ``update`` and ``get_params``.
        data: tuple of data arrays to batchify (must be a tuple).
        batch_size: size of each subsampled minibatch.
        num_data: total number of data points; used to normalize the loss.
        num_epochs: number of training epochs; must be >= 1 so that ``loss``
            is defined when returned.
        silent: if True, suppress the per-epoch loss printout.

    Returns:
        Tuple ``(params, loss)``: the final model parameters and the last
        epoch's per-datum loss.

    Raises:
        InferenceException: if the epoch loss becomes NaN.
    """
    rng, svi_rng, init_batch_rng = rng_suite.split(rng, 3)

    # isinstance is the idiomatic type check (and accepts tuple subclasses)
    assert isinstance(data, tuple)
    data = _cast_data_tuple(data)
    init_batching, get_batch = subsample_batchify_data(data, batch_size, rng_suite=rng_suite)
    _, batchify_state = init_batching(init_batch_rng)

    # initialize SVI state from a sample batch
    batch = get_batch(0, batchify_state)
    svi_state = svi.init(svi_rng, *batch)

    @jax.jit
    def train_epoch(num_epoch_iter, svi_state, batchify_state):
        # Runs one full epoch inside a single jitted fori_loop, accumulating
        # the mean per-iteration loss.
        def train_iteration(i, state_and_loss):
            svi_state, loss = state_and_loss
            batch = get_batch(i, batchify_state)
            svi_state, iter_loss = svi.update(svi_state, *batch)
            return (svi_state, loss + iter_loss / num_epoch_iter)

        return jax.lax.fori_loop(0, num_epoch_iter, train_iteration, (svi_state, 0.))

    rng, epochs_rng = rng_suite.split(rng)
    for i in range(num_epochs):
        # fresh batch assignment per epoch, deterministically derived from i
        batchify_rng = rng_suite.fold_in(epochs_rng, i)
        num_batches, batchify_state = init_batching(batchify_rng)

        svi_state, loss = train_epoch(num_batches, svi_state, batchify_state)
        if np.isnan(loss):
            raise InferenceException
        loss /= num_data
        if not silent:
            print("epoch {}: loss {}".format(i, loss))

    return svi.get_params(svi_state), loss
def main(args):
    """Train a DP-SVI logistic regression model on toy data and evaluate it.

    Creates synthetic data, trains with differentially-private SVI, periodically
    evaluates test loss/accuracy, and finally compares the learned posterior
    parameters against the true generating parameters.
    """
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)
    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data, batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model, guide, optimizer, ELBO(), dp_scale=0.01,
                clipping_threshold=20., num_obs_total=args.num_samples)

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        # one epoch of DP-SVI updates inside a single jitted fori_loop
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch
            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        # average ELBO loss and sampled-posterior accuracy over the test set
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val
            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch
            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)
            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch
            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model
    # evaluate ~10 times over training; max(1, ...) avoids ZeroDivisionError
    # in the modulo below when num_epochs < 10
    test_interval = max(1, args.num_epochs // 10)
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % test_interval) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(svi_state, test_batchifier_state,
                                            num_test_batches, test_rng)
            print(
                "Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)"
                .format(i, test_loss, test_acc, train_loss, t_end - t_start))

    # parameters for logistic regression may be scaled arbitrarily. normalize
    # w (and scale intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true,
        jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data
    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # for evaluation accuracy with true parameters, we scale them to the same
    # scale as the found posterior. (gives better results than normalized
    # parameters (probably due to numerical instabilities))
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true,
                                              intercept_true, rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print(
        "avg accuracy on test set: with true parameters: {} ; with found posterior: {}\n"
        .format(acc_true, acc_post))
def main(args):
    """Train a DP-SVI Gaussian mixture model on toy data and report MAP estimates.

    Creates synthetic mixture data, trains with differentially-private SVI,
    prints MAP estimates of the mixture weights and modes, and reports the
    cluster-assignment accuracy on the test set.
    """
    N = args.num_samples
    k = args.num_components
    d = args.dimensions

    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, latent_vals = create_toy_data(toy_data_rng, N, d)

    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): fix the parameters in the models
    def fix_params(model_fn, k):
        def fixed_params_fn(obs, **kwargs):
            return model_fn(k, obs, **kwargs)
        return fixed_params_fn

    model_fixed = fix_params(model, k)
    guide_fixed = fix_params(guide, k)

    svi = DPSVI(
        model_fixed, guide_fixed, optimizer, ELBO(),
        dp_scale=0.01, clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(fetch_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        # bug fix: this parameter was previously named `data_idx` while the
        # body referenced the outer `batchifier_state` closure variable, so
        # the per-epoch state passed by the training loop was silently
        # ignored and every epoch reused the initial batch assignment (baked
        # in at jit trace time). Naming the parameter `batchifier_state`
        # shadows the closure and makes the caller's argument take effect.
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(
                svi_state, *batch
            )
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        # average per-datum ELBO loss over the test batches
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)

        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()
        t_end = time.time()

        if i % 100 == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)

            test_loss = eval_test(
                svi_state, test_batchifier_state, num_test_batches
            )
            print("Epoch {}: loss = {} (on training set = {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start
            ))

    params = svi.get_params(svi_state)
    print(params)
    posterior_modes = params['mus_loc']
    posterior_pis = dist.Dirichlet(jnp.exp(params['alpha_log'])).mean
    print("MAP estimate of mixture weights: {}".format(posterior_pis))
    print("MAP estimate of mixture modes : {}".format(posterior_modes))

    acc = compute_assignment_accuracy(
        X_test, latent_vals[1], latent_vals[2], posterior_modes, posterior_pis
    )
    print("assignment accuracy: {}".format(acc))
def main(args):
    """Train a DP-SVI Gaussian mean model on toy data and compare estimates.

    Creates synthetic data, reports the privacy epsilon for the chosen DP
    parameters, trains with differentially-private SVI, and compares the
    learned posterior mean against the true value, the analytical solution,
    and the maximum-likelihood estimate.
    """
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)

    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples,
                                               args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        (X_train, ), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test, ),
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model, guide, optimizer, ELBO(),
                dp_scale=args.sigma, clipping_threshold=args.clip_threshold,
                d=args.dimensions, num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    # bug fix: arguments were previously passed as (sigma, clip_threshold,
    # delta, q), so the "delta" and "C" labels printed each other's values;
    # reordered to match the label order in the format string.
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.delta, args.clip_threshold, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        # one epoch of DP-SVI updates inside a single jitted fori_loop
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        # average per-datum ELBO loss over the test batches
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    # evaluate ~10 times over training; max(1, ...) avoids ZeroDivisionError
    # in the modulo below when num_epochs < 10
    test_interval = max(1, args.num_epochs // 10)
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % test_interval == 0):
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state,
                                  num_test_batches)
            print(
                "Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])

    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))