예제 #1
    def testIssue758(self):
        # code from https://github.com/google/jax/issues/758
        # this is more of a scan + jacfwd/jacrev test, but it lives here to use the
        # optimizers.py code

        def harmonic_bond(conf, params):
            return jnp.sum(conf * params)

        opt_init, opt_update, get_params = optimizers.sgd(5e-2)

        x0 = np.array([0.5], dtype=np.float64)

        def minimize_structure(test_params):
            energy_fn = functools.partial(harmonic_bond, params=test_params)
            grad_fn = grad(energy_fn, argnums=(0, ))
            opt_state = opt_init(x0)

            def apply_carry(carry, _):
                i, x = carry
                g = grad_fn(get_params(x))[0]
                new_state = opt_update(i, g, x)
                new_carry = (i + 1, new_state)
                return new_carry, _

            carry_final, _ = lax.scan(apply_carry, (0, opt_state),
                                      jnp.zeros((75, 0)))
            trip, opt_final = carry_final
            assert trip == 75
            return opt_final

        initial_params = jnp.array(0.5)

        def loss(test_params):
            opt_final = minimize_structure(test_params)
            return 1.0 - get_params(opt_final)[0]

        loss_opt_init, loss_opt_update, loss_get_params = optimizers.sgd(5e-2)

        J1 = jacrev(loss, argnums=(0, ))(initial_params)
        J2 = jacfwd(loss, argnums=(0, ))(initial_params)
        self.assertAllClose(J1, J2, rtol=1e-6)
예제 #2
def main(unused_argv):
    """Train a small MNIST network with SGD and compare it to the NTK prediction.

    The network is trained on an MSE loss for `_TRAIN_TIME // _LEARNING_RATE`
    steps. Separately, the empirical NTK at initialization is used to build a
    function-space predictor, and a summary of both results is printed for
    the train and test sets.

    Args:
        unused_argv: Ignored.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'mnist', _TRAIN_SIZE, _TEST_SIZE)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                                       stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(_LEARNING_RATE)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    # NOTE(review): this call was truncated after its first argument, so any
    # trailing keyword arguments (e.g. a batch size) are unknown. It now uses
    # `nt.batch`'s defaults -- confirm against the intended configuration.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0))
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(_TRAIN_TIME // _LEARNING_RATE)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # The loop reads `params` before each update, so after the last iteration
    # it is one step behind `state`. Refresh it so the summary below reflects
    # all `train_steps` updates.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(_TRAIN_TIME, fx_train, fx_test, g_td)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
예제 #3
  def testTracedStepSize(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_size = 0.1

    init_fun, _, _ = optimizers.sgd(step_size)
    opt_state = init_fun(x0)

    def update(opt_state, step_size):
      _, update_fun, get_params = optimizers.sgd(step_size)
      x = get_params(opt_state)
      g = grad(loss)(x)
      return update_fun(0, g, opt_state)

    update(opt_state, 0.9)  # doesn't crash
예제 #4
 def update(opt_state, step_size):
     """Apply one SGD update with the given step size.

     Builds an SGD optimizer from `step_size`, reads the parameters out of
     `opt_state`, differentiates `loss` (defined outside this function) at
     those parameters and applies the update with step index 0.
     """
     _, apply_step, read_params = optimizers.sgd(step_size)
     gradient = grad(loss)(read_params(opt_state))
     return apply_step(0, gradient, opt_state)
예제 #5
def main(_):
  """Train an MNIST classifier with plain SGD or differentially private SGD.

  `FLAGS.dpsgd` selects the private update rule. After every epoch the test
  loss and accuracy are printed and, for private training, the privacy
  budget spent so far.

  Args:
    _: Ignored.

  Raises:
    NotImplementedError: If `FLAGS.microbatches` is set.
  """

  if FLAGS.microbatches:
    raise NotImplementedError(
        'Microbatches < batch size not currently supported')

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
  # A trailing partial batch counts as one more batch.
  num_batches = num_complete_batches + bool(leftover)
  key = random.PRNGKey(FLAGS.seed)

  def data_stream():
    """Yield (images, labels) batches forever, reshuffling on every pass."""
    rng = npr.RandomState(FLAGS.seed)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  batches = data_stream()

  opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

  def update(_, i, opt_state, batch):
    """Non-private step; the first argument is unused and only mirrors
    `private_update`'s signature."""
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  def private_update(rng, i, opt_state, batch):
    """Private step: the gradient comes from `private_grad`."""
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    # `opt_update` takes the step index first, exactly as in `update` above.
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

  _, init_params = init_random_params(key, (-1, 28, 28, 1))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  steps_per_epoch = 60000 // FLAGS.batch_size
  print('\nStarting training...')
  for epoch in range(1, FLAGS.epochs + 1):
    start_time = time.time()
    for _ in range(num_batches):
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(*next(batches), dummy_dim=True))
      else:
        opt_state = update(
            key, next(itercount), opt_state, shape_as_image(*next(batches)))
    epoch_time = time.time() - start_time
    print(f'Epoch {epoch} in {epoch_time:0.2f} sec')

    # evaluate test accuracy
    params = get_params(opt_state)
    test_acc = accuracy(params, shape_as_image(test_images, test_labels))
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(
        test_loss, 100 * test_acc))

    # determine privacy loss so far
    if FLAGS.dpsgd:
      delta = 1e-5
      num_examples = 60000
      eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
      print(
          f'For delta={delta:.0e}, the current epsilon is: {eps:.2f}')
    else:
      print('Trained with vanilla non-private SGD optimizer')
예제 #6
class OptimizersEquivalenceTest(chex.TestCase):
    """Compares optax optimizers with their `optimizers` module counterparts."""

    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4., 5.]))
        self.per_step_updates = (jnp.array([500., 5.]),
                                 jnp.array([300., 3., 1.]))

    @staticmethod
    def _equivalence_cases():
        """Return `(name, optax_optimizer, jax_optimizer, rtol)` tuples.

        The first five cases pass `LR` to the optax optimizer and the next
        five pass `LR_SCHED`; the reference optimizer always receives `LR`.
        """
        # NOTE(review): the `alias.adagrad` calls had lost their arguments.
        # A zero initial accumulator and zero eps are assumed here so that
        # optax matches `optimizers.adagrad(LR, 0.)` -- confirm.
        return (
            ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
            ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
             optimizers.adam(LR, 0.9, 0.999), 1e-4),
            ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
             optimizers.rmsprop(LR, .9, 0.1), 1e-5),
            ('rmsprop_momentum',
             alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
             optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
            ('adagrad',
             alias.adagrad(LR, initial_accumulator_value=0., eps=0.),
             optimizers.adagrad(LR, 0.), 1e-5),
            ('sgd', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
            ('adam', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
             optimizers.adam(LR, 0.9, 0.999), 1e-4),
            ('rmsprop', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
             optimizers.rmsprop(LR, .9, 0.1), 1e-5),
            ('rmsprop_momentum',
             alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
             optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
            ('adagrad',
             alias.adagrad(LR_SCHED, initial_accumulator_value=0., eps=0.),
             optimizers.adagrad(LR, 0.), 1e-5),
            ('sm3', alias.sm3(LR), optimizers.sm3(LR), 1e-2),
        )

    def _assert_equivalent(self, optax_optimizer, jax_optimizer, rtol):
        """Apply the same updates with both optimizers and compare parameters."""
        # example_libraries/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)

    def test_jax_optimizer_equivalent(self, optax_optimizer=None,
                                      jax_optimizer=None, rtol=None):
        """Check that optax and the reference optimizers reach the same parameters.

        Called with no arguments (as a test runner does), every case from
        `_equivalence_cases` is checked in its own sub-test. Called with
        explicit arguments, only that optimizer pair is checked.

        Args:
            optax_optimizer: Object exposing `init` and `update`.
            jax_optimizer: `(opt_init, opt_update, get_params)` triple.
            rtol: Relative tolerance for the final comparison.
        """
        if optax_optimizer is None and jax_optimizer is None and rtol is None:
            cases = enumerate(self._equivalence_cases())
            for index, (name, optax_opt, jax_opt, tol) in cases:
                # Names repeat across the LR / LR_SCHED groups, so the index
                # keeps sub-test labels unique.
                with self.subTest(case=name, index=index):
                    self._assert_equivalent(optax_opt, jax_opt, tol)
            return
        self._assert_equivalent(optax_optimizer, jax_optimizer, rtol)