Example #1
  def testUnpackPackRoundTrip(self):
    opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
    params = [{'w': np.random.randn(1, 2), 'bias': np.random.randn(2)}]
    expected = opt_init(params)
    ans = optimizers.pack_optimizer_state(
        optimizers.unpack_optimizer_state(expected))
    self.assertEqual(ans, expected)
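
The round trip above matters because an OptimizerState is not directly serializable: unpack_optimizer_state converts it into a plain pytree and pack_optimizer_state rebuilds the state from that pytree. A minimal checkpointing sketch built on the same pair of calls, assuming the same optimizers module as in the example (the helper names and file path are illustrative):

import pickle

def save_opt_state(opt_state, path='ckpt.pkl'):
  # Unpack into a plain pytree so the standard pickle module can handle it.
  with open(path, 'wb') as f:
    pickle.dump(optimizers.unpack_optimizer_state(opt_state), f)

def load_opt_state(path='ckpt.pkl'):
  # Re-pack the pytree into an OptimizerState that opt_update/get_params accept.
  with open(path, 'rb') as f:
    return optimizers.pack_optimizer_state(pickle.load(f))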
Example #2
def main(unused_argv):
  # Load data and preprocess it.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                          permute_train=True)

  # Build the network.
  init_fn, f, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(_LEARNING_RATE, 0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // _BATCH_SIZE

  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, _BATCH_SIZE, _TRAIN_EPOCHS)):

    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary(
      'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
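
The example above relies on imports and constants defined elsewhere in its file. A plausible header for running it, with the uncertain pieces marked (the experimental modules moved to jax.example_libraries in newer JAX releases):

import jax.numpy as np
from jax import grad, jit, random
from jax.experimental import optimizers, stax  # jax.example_libraries.* on newer JAX
from jax.nn import log_softmax                 # assumed source of log_softmax
import neural_tangents as nt

# Also assumed from the surrounding example: the `datasets` and `util` helper
# modules and the _LEARNING_RATE, _BATCH_SIZE and _TRAIN_EPOCHS constants.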
Example #3
def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
  # Minimize f starting from x with momentum gradient descent.
  opt_init, opt_update, get_params = optimizers.momentum(step_size, mass)

  @jit
  def update(i, opt_state):
    # One optimizer step: read the current point and apply the gradient of f.
    x = get_params(opt_state)
    return opt_update(i, grad(f)(x), opt_state)

  opt_state = opt_init(x)
  for i in range(num_steps):
    opt_state = update(i, opt_state)
  return get_params(opt_state)
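
A quick usage sketch for the helper above, minimizing a small quadratic; the objective, starting point, step size and step count are illustrative:

import jax.numpy as jnp

# The minimum of sum((x - 3)^2) is x = [3., 3., 3.].
x_star = minimize(lambda x: jnp.sum((x - 3.0) ** 2), jnp.zeros(3),
                  num_steps=500, step_size=0.1)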
Example #4
  def loss(params, batch):
    inputs, targets = batch
    logits = predict_fun(params, inputs)
    return -jnp.sum(logits * targets)

  def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=-1)
    predicted_class = jnp.argmax(predict_fun(params, inputs), axis=-1)
    return jnp.mean(predicted_class == target_class)

  def synth_batches():
    # Yield an endless stream of random images paired with one-hot labels.
    rng = npr.RandomState(0)
    while True:
      images = rng.rand(*input_shape).astype('float32')
      labels = rng.randint(num_classes, size=(batch_size, 1))
      onehot_labels = labels == jnp.arange(num_classes)
      yield images, onehot_labels

  opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
  batches = synth_batches()

  @jit
  def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  opt_state = opt_init(init_params)
  for i in range(num_steps):
    opt_state = update(i, opt_state, next(batches))
  trained_params = get_params(opt_state)
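
The loop above takes predict_fun, init_params, and the shape and step constants from its enclosing function. A minimal driver sketch under assumed settings (everything below is illustrative rather than part of the original example):

from jax import random
from jax.experimental import stax  # jax.example_libraries.stax on newer JAX

# Illustrative hyperparameters; the original snippet receives these from its caller.
num_classes, batch_size, num_steps, step_size = 10, 32, 100, 0.001
input_shape = (batch_size, 784)

# LogSoftmax output pairs with the -sum(logits * targets) cross-entropy loss above.
init_fun, predict_fun = stax.serial(
    stax.Dense(128), stax.Relu,
    stax.Dense(num_classes), stax.LogSoftmax)
_, init_params = init_fun(random.PRNGKey(0), input_shape)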