def testUnpackPackRoundTrip(self):
  opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
  params = [{'w': np.random.randn(1, 2), 'bias': np.random.randn(2)}]
  expected = opt_init(params)
  ans = optimizers.pack_optimizer_state(
      optimizers.unpack_optimizer_state(expected))
  self.assertEqual(ans, expected)
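# A standalone sketch of the same round trip outside the test class, assuming
# the `optimizers` module above is jax.example_libraries.optimizers (the import
# path is not shown in the snippet). unpack_optimizer_state turns the opaque
# OptimizerState into an ordinary pytree (useful for checkpointing), and
# pack_optimizer_state rebuilds the OptimizerState from it.
import numpy as np
from jax.example_libraries import optimizers

opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
state = opt_init([{'w': np.random.randn(1, 2), 'bias': np.random.randn(2)}])
unpacked = optimizers.unpack_optimizer_state(state)   # pytree with JoinPoint markers
repacked = optimizers.pack_optimizer_state(unpacked)  # OptimizerState again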
def main(unused_argv):
  # Load data and preprocess it.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                          permute_train=True)

  # Build the network.
  init_fn, f, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(_LEARNING_RATE, 0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // _BATCH_SIZE
  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, _BATCH_SIZE, _TRAIN_EPOCHS)):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary(
      'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
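# The snippet above relies on module-level names defined elsewhere in the
# example file (imports, flag-backed constants, and the `datasets` / `util`
# helper modules that ship alongside it). The block below is an assumed,
# illustrative reconstruction of that setup; the constant values are
# placeholders, not the original example's settings.
from jax import grad, jit, random
from jax.example_libraries import optimizers
from jax.nn import log_softmax
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
# `datasets` and `util` are helper modules from the example's own directory.

_LEARNING_RATE = 1.0   # placeholder value (assumed)
_BATCH_SIZE = 256      # placeholder value (assumed)
_TRAIN_EPOCHS = 10     # placeholder value (assumed)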
def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
  # Minimize a scalar-valued function f starting from x with momentum SGD.
  opt_init, opt_update, get_params = optimizers.momentum(step_size, mass)

  @jit
  def update(i, opt_state):
    x = get_params(opt_state)
    return opt_update(i, grad(f)(x), opt_state)

  opt_state = opt_init(x)
  for i in range(num_steps):
    opt_state = update(i, opt_state)
  return get_params(opt_state)
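# A hedged usage sketch for the helper above: the quadratic objective, starting
# point, and step size below are illustrative choices, not values from the
# source. It assumes grad, jit, and optimizers are imported as in the
# surrounding code (e.g. from jax and jax.example_libraries).
import jax.numpy as jnp

x_min = minimize(lambda x: jnp.sum(x ** 2), jnp.ones(3),
                 num_steps=1000, step_size=0.1)
print(x_min)  # expected to end up close to the zero vector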
def loss(params, batch):
  # Negative log-likelihood of the one-hot targets under the model outputs.
  inputs, targets = batch
  logits = predict_fun(params, inputs)
  return -jnp.sum(logits * targets)

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=-1)
  predicted_class = jnp.argmax(predict_fun(params, inputs), axis=-1)
  return jnp.mean(predicted_class == target_class)

def synth_batches():
  # Yield an endless stream of random synthetic image / one-hot label batches.
  rng = npr.RandomState(0)
  while True:
    images = rng.rand(*input_shape).astype('float32')
    labels = rng.randint(num_classes, size=(batch_size, 1))
    onehot_labels = labels == jnp.arange(num_classes)
    yield images, onehot_labels

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
batches = synth_batches()

@jit
def update(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

opt_state = opt_init(init_params)
for i in range(num_steps):
  opt_state = update(i, opt_state, next(batches))
trained_params = get_params(opt_state)
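# A hedged stand-in for the free names the snippet above depends on
# (predict_fun, init_params, input_shape, num_classes, batch_size, step_size,
# num_steps). The original example builds a larger model; the small dense
# network and constants here are illustrative assumptions only.
import numpy.random as npr
import jax.numpy as jnp
from jax import grad, jit, random
from jax.example_libraries import optimizers, stax

num_classes = 10
batch_size = 32
input_shape = (batch_size, 784)
step_size = 0.1
num_steps = 100

init_fun, predict_fun = stax.serial(
    stax.Dense(128), stax.Relu,
    stax.Dense(num_classes), stax.LogSoftmax)
_, init_params = init_fun(random.PRNGKey(0), input_shape)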