def testIssue758(self):
  # code from https://github.com/google/jax/issues/758
  # this is more of a scan + jacfwd/jacrev test, but it lives here to use the
  # optimizers.py code
  def harmonic_bond(conf, params):
    return jnp.sum(conf * params)

  opt_init, opt_update, get_params = optimizers.sgd(5e-2)

  x0 = np.array([0.5], dtype=np.float64)

  def minimize_structure(test_params):
    energy_fn = functools.partial(harmonic_bond, params=test_params)
    grad_fn = grad(energy_fn, argnums=(0,))
    opt_state = opt_init(x0)

    def apply_carry(carry, _):
      i, x = carry
      g = grad_fn(get_params(x))[0]
      new_state = opt_update(i, g, x)
      new_carry = (i + 1, new_state)
      return new_carry, _

    carry_final, _ = lax.scan(apply_carry, (0, opt_state), jnp.zeros((75, 0)))
    trip, opt_final = carry_final
    assert trip == 75
    return opt_final

  initial_params = jnp.array(0.5)
  minimize_structure(initial_params)

  def loss(test_params):
    opt_final = minimize_structure(test_params)
    return 1.0 - get_params(opt_final)[0]

  loss_opt_init, loss_opt_update, loss_get_params = optimizers.sgd(5e-2)

  J1 = jacrev(loss, argnums=(0,))(initial_params)
  J2 = jacfwd(loss, argnums=(0,))(initial_params)
  self.assertAllClose(J1, J2, rtol=1e-6)
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset(
      'mnist', _TRAIN_SIZE, _TEST_SIZE)

  # Build the network.
  init_fn, apply_fn, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(_LEARNING_RATE)
  state = opt_init(params)

  # Create an MSE loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0),
                 batch_size=64, device_count=0)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

  # Get initial values of the network in function space.
  fx_train = apply_fn(params, x_train)
  fx_test = apply_fn(params, x_test)

  # Train the network.
  train_steps = int(_TRAIN_TIME // _LEARNING_RATE)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(_TRAIN_TIME, fx_train, fx_test, g_td)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                     loss)
  util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
def testTracedStepSize(self):
  def loss(x):
    return jnp.dot(x, x)

  x0 = jnp.ones(2)
  step_size = 0.1

  init_fun, _, _ = optimizers.sgd(step_size)
  opt_state = init_fun(x0)

  @jit
  def update(opt_state, step_size):
    _, update_fun, get_params = optimizers.sgd(step_size)
    x = get_params(opt_state)
    g = grad(loss)(x)
    return update_fun(0, g, opt_state)

  update(opt_state, 0.9)  # doesn't crash
def update(opt_state, step_size):
  _, update_fun, get_params = optimizers.sgd(step_size)
  x = get_params(opt_state)
  g = grad(loss)(x)
  return update_fun(0, g, opt_state)
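# A minimal, self-contained sketch (not taken from the snippets above) of the
# `optimizers.sgd` triple that every example here relies on: `init_fun` wraps
# parameters into an optimizer state, `update_fun(step, grads, state)` applies
# one SGD step, and `get_params` unwraps the state. The toy quadratic loss,
# the import path, and the step count are assumptions made for illustration.
import jax.numpy as jnp
from jax import grad
from jax.example_libraries import optimizers


def sgd_triple_sketch():
  loss = lambda x: jnp.sum(x ** 2)  # hypothetical toy objective
  init_fun, update_fun, get_params = optimizers.sgd(step_size=0.1)
  opt_state = init_fun(jnp.ones(3))
  for step in range(100):
    grads = grad(loss)(get_params(opt_state))  # gradients at current params
    opt_state = update_fun(step, grads, opt_state)
  return get_params(opt_state)  # parameters driven toward the optimum at zero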
def main(_):
  if FLAGS.microbatches:
    raise NotImplementedError(
        'Microbatches < batch size not currently supported')

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
  num_batches = num_complete_batches + bool(leftover)
  key = random.PRNGKey(FLAGS.seed)

  def data_stream():
    rng = npr.RandomState(FLAGS.seed)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  batches = data_stream()

  opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

  @jit
  def update(_, i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size),
        opt_state)

  _, init_params = init_random_params(key, (-1, 28, 28, 1))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  steps_per_epoch = 60000 // FLAGS.batch_size
  print('\nStarting training...')
  for epoch in range(1, FLAGS.epochs + 1):
    start_time = time.time()
    for _ in range(num_batches):
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(*next(batches), dummy_dim=True))
      else:
        opt_state = update(
            key, next(itercount), opt_state,
            shape_as_image(*next(batches)))
    epoch_time = time.time() - start_time
    print(f'Epoch {epoch} in {epoch_time:0.2f} sec')

    # Evaluate test accuracy.
    params = get_params(opt_state)
    test_acc = accuracy(params, shape_as_image(test_images, test_labels))
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(
        test_loss, 100 * test_acc))

    # Determine privacy loss so far.
    if FLAGS.dpsgd:
      delta = 1e-5
      num_examples = 60000
      eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
      print(f'For delta={delta:.0e}, the current epsilon is: {eps:.2f}')
    else:
      print('Trained with vanilla non-private SGD optimizer')
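# A hedged sketch of what a `private_grad`-style helper typically computes in
# DP-SGD; this is an assumption about the helper called above, not its actual
# implementation: per-example gradients via vmap, per-example l2 clipping,
# Gaussian noise calibrated to the clip norm, then averaging over the batch.
import jax.numpy as jnp
from jax import grad, random, tree_util, vmap


def private_grad_sketch(loss_fn, params, batch, rng, l2_norm_clip,
                        noise_multiplier, batch_size):
  # Per-example gradients: vmap the single-example gradient over the batch.
  grads = vmap(grad(loss_fn), in_axes=(None, 0))(params, batch)
  leaves, treedef = tree_util.tree_flatten(grads)
  # Per-example global l2 norm across all parameter leaves.
  norms = jnp.sqrt(sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1)
                       for g in leaves))
  divisor = jnp.maximum(norms / l2_norm_clip, 1.0)
  # Clip each example's gradient, then sum over the batch dimension.
  clipped = [jnp.sum(g / divisor.reshape((-1,) + (1,) * (g.ndim - 1)), axis=0)
             for g in leaves]
  # Add Gaussian noise scaled by the clip norm and average over the batch.
  rngs = random.split(rng, len(clipped))
  noised = [(g + l2_norm_clip * noise_multiplier * random.normal(r, g.shape))
            / batch_size
            for r, g in zip(rngs, clipped)]
  return tree_util.tree_unflatten(treedef, noised)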
class OptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4., 5.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3., 1.]))

  @chex.all_variants()
  @parameterized.named_parameters(
      # Constant learning rate.
      ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum', alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad', alias.adagrad(LR, 0., 0.), optimizers.adagrad(LR, 0.), 1e-5),
      # Scheduled learning rate on the optax side (distinct case names are
      # required by `named_parameters`).
      ('sgd_scheduled', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam_scheduled', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop_scheduled', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum_scheduled',
       alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad_scheduled', alias.adagrad(LR_SCHED, 0., 0.),
       optimizers.adagrad(LR, 0.), 1e-5),
      ('sm3', alias.sm3(LR), optimizers.sm3(LR), 1e-2),
  )
  def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                    rtol):
    # example_libraries/optimizers.py
    jax_params = self.init_params
    opt_init, opt_update, get_params = jax_optimizer
    state = opt_init(jax_params)
    for i in range(STEPS):
      state = opt_update(i, self.per_step_updates, state)
    jax_params = get_params(state)

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)

    @self.variant
    def step(updates, state):
      return optax_optimizer.update(updates, state)

    for _ in range(STEPS):
      updates, state = step(self.per_step_updates, state)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
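# A minimal sketch of the optax side of the equivalence checked above, using
# only the public API (optax.sgd / optax.apply_updates) rather than the
# internal `alias` and `update` modules imported by the test; the 0.01
# learning rate stands in for the LR constant and is an assumption.
import jax.numpy as jnp
import optax


def optax_sgd_step_sketch():
  params = (jnp.array([1., 2.]), jnp.array([3., 4., 5.]))
  grads = (jnp.array([500., 5.]), jnp.array([300., 3., 1.]))
  tx = optax.sgd(learning_rate=0.01)
  state = tx.init(params)
  updates, state = tx.update(grads, state)  # transform the raw gradients
  return optax.apply_updates(params, updates)  # apply them to the params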