Example #1
0
class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
    """Checks optax aliases against jax.experimental.optimizers references.

    Each named case pairs an optax optimizer, the equivalent
    jax.experimental optimizer, and a relative tolerance for comparison.
    """

    def setUp(self):
        super().setUp()
        # Two-leaf pytrees used as parameters and as fixed per-step updates.
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        # Constant learning rate.
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999,
                            1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(
            LR, decay=.9, eps=0.1), optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum', alias.rmsprop(LR, decay=.9, eps=0.1,
                                           momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(
            LR,
            0.,
            0.,
        ), optimizers.adagrad(LR, 0.), 1e-5),
        # Scheduled learning rate on the optax side. The names must be unique:
        # parameterized.named_parameters rejects duplicate case names, and the
        # original reused 'sgd', 'adam', etc. for this second group.
        ('sgd_sched', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam_sched', alias.adam(LR_SCHED, 0.9, 0.999,
                                  1e-8), optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop_sched', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum_sched',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad_sched', alias.adagrad(
            LR_SCHED,
            0.,
            0.,
        ), optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):
        """Runs STEPS updates with both optimizers and compares parameters."""
        # Reference trajectory: jax.experimental/optimizers.py triple API.
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # optax trajectory: init/update, then apply_updates to the pytree.
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            # NOTE(review): `update` is assumed to be optax's update module
            # (providing apply_updates) imported elsewhere — confirm.
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence within the per-case relative tolerance.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
Example #2
0
def args_to_op(optimizer_string, lr, mom=0.9, var=0.999, eps=1e-7):
    """Build an optimizer from its (case-insensitive) name.

    Hyperparameters that a given optimizer does not use are ignored.
    Raises KeyError for an unknown optimizer name.
    """
    name = optimizer_string.lower()
    if name in ("gd", "sgd"):
        return op.sgd(lr)
    if name == "momentum":
        return op.momentum(lr, mom)
    if name == "adam":
        return op.adam(lr, mom, var, eps)
    # Mirror the dict-lookup failure of the original implementation.
    raise KeyError(name)
Example #3
0
def _define_optimizer_JAX(dataset, learn_rate):
    if 'purchases' == dataset:
        return optimizers.adam(learn_rate)
        # return optimizers.sgd(learn_rate)       # temporarily error....
    elif 'fashion_mnist' == dataset:
        return optimizers.sgd(learn_rate)
    elif 'cifar10' == dataset:
        return optimizers.sgd(learn_rate)
    else:
        assert False, ('Error: unknown dataset - {}'.format(dataset))
Example #4
0
def train_devise(
    weights,
    X,
    y,
    label_embeddings,
    margin,
    n_epochs=2,
    learning_rate=0.001,
    batch_size=16,
):
    """Train DeViSE-style weights by SGD on a similarity hinge loss.

    Args:
        weights: Initial parameters passed to the optimizer.
        X: Input embeddings (array-like), also closed over by the loss.
        y: Labels (array-like), also closed over by the loss.
        label_embeddings: Label embedding vectors used by the loss builder.
        margin: Hinge-loss margin.
        n_epochs: Passes over the data loader.
        learning_rate: SGD step size.
        batch_size: Mini-batch size for the data loader.

    Returns:
        The trained weights.
    """
    opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    opt_state = opt_init(weights)
    loader = torch_util.get_dataloader(np.array(X),
                                       np.array(y),
                                       batch_size=batch_size)

    # The loss closes over the FULL X and y, not per-batch data.
    loss_fn = get_similarity_based_hinge_loss(label_embeddings, X, y, margin)
    gradient_fn = jax.jit(jax.grad(loss_fn))

    for i in tqdm.tqdm(range(n_epochs)):
        for (j, (embeddings_np, labels_np)) in enumerate(loader):
            # NOTE(review): `embeddings`/`labels` are built but never used;
            # gradient_fn(weights) only sees the closed-over full dataset, so
            # every inner step computes the same full-batch gradient.
            # Presumably the batch was meant to feed the loss — confirm.
            embeddings, labels = jnp.array(embeddings_np), jnp.array(labels_np)
            grads = gradient_fn(weights)
            opt_state = opt_update(j, grads, opt_state)
            weights = get_params(opt_state)
    return weights
Example #5
0
def fit_mixture(data,
                num_components=3,
                verbose=False,
                num_samples=5000) -> LogisticMixtureParams:
    """Fit a logistic mixture to `data` by SGD on the mixture log-pdf.

    Args:
        data: Samples to fit; anything np.array() accepts (e.g. a pandas
            column).
        num_components: Number of mixture components.
        verbose: If True, print the components and log score every 500 steps.
        num_samples: Number of gradient steps to run.

    Returns:
        The fitted parameters as a LogisticMixtureParams.
    """
    # the data might be something weird, like a pandas dataframe column;
    # turn it into a regular old numpy array
    data_as_np_array = np.array(data)
    step_size = 0.01
    components = initialize_components(num_components)
    (init_fun, update_fun, get_params) = sgd(step_size)
    opt_state = init_fun(components)
    for i in tqdm(range(num_samples)):
        components = get_params(opt_state)
        # Maximize the log-pdf by descending its negation.
        grads = -grad_mixture_logpdf(data_as_np_array, components)
        if np.any(np.isnan(grads)):
            # Typo fixed: "Encoutered" -> "Encountered".
            print("Encountered nan gradient, stopping early")
            print(grads)
            print(components)
            break
        grads = clip_grads(grads, 1.0)
        opt_state = update_fun(i, grads, opt_state)
        if i % 500 == 0 and verbose:
            pprint(components)
            score = mixture_logpdf(data_as_np_array, components)
            print(f"Log score: {score:.3f}")
    return structure_mixture_params(components)
Example #6
0
    def test_optimize_rotation(self):
        """Gradient-descend opt_rot_rmsd until the two conformers align."""
        opt_init, opt_update, get_params = optimizers.sgd(5e-2)

        base_conf = onp.array([
            [1.0, 0.2, 3.3],    # H
            [-0.6, -1.1, -0.9], # C
            [3.4, 5.5, 0.2],    # H
            [3.6, 5.6, 0.6],    # H
        ], dtype=onp.float64)

        x0 = base_conf
        # Same geometry, randomly displaced by up to 10 units per coordinate.
        x1 = base_conf + onp.random.rand(base_conf.shape[0], 3) * 10

        grad_fn = jax.jit(jax.grad(rmsd.opt_rot_rmsd, argnums=(0,)))
        opt_state = opt_init(x0)

        for step in range(1500):
            g = grad_fn(get_params(opt_state), x1)[0]
            opt_state = opt_update(step, g, opt_state)

        x_final = get_params(opt_state)

        assert rmsd.opt_rot_rmsd(x_final, x1) < 0.1
Example #7
0
    def testMaxLearningRate(self, train_shape, network, out_logits,
                            fn_and_kernel):
        """Checks predict.max_learning_rate brackets the divergence threshold.

        Trains with SGD at 0.5x and 3x the predicted maximal learning rate:
        the smaller rate must reduce the loss markedly, the larger one must
        blow it up (or produce NaNs).
        """
        # Derive reproducible PRNG keys from a fixed stateless seed.
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        if len(train_shape) == 2:
            # Enlarge small dense shapes; otherwise use a fixed image shape.
            train_shape = (train_shape[0] * 5, train_shape[1] * 10)
        else:
            train_shape = (16, 8, 8, 3)
        x_train = np.asarray(normal(train_shape, seed=split))

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        # Random binary targets in {0, 1}.
        y_train = np.asarray(
            stateless_uniform(shape=(train_shape[0], out_logits),
                              seed=split,
                              minval=0,
                              maxval=1) < 0.5, np.float32)
        # Regress to an MSE loss.
        # NOTE(review): `f` and `get_params` are bound late — they are only
        # defined inside the lr_factor loop, before the lambda is first called.
        loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train)**2)
        grad_loss = jit(grad(loss))

        def get_loss(opt_state):
            # Loss of the parameters currently held in the optimizer state.
            return loss(get_params(opt_state), x_train)

        steps = 20

        for lr_factor in [0.5, 3.]:
            params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                           out_logits)
            g_dd = ntk(x_train, None, 'ntk')

            # Scale the theoretical maximal stable learning rate.
            step_size = predict.max_learning_rate(
                g_dd, y_train_size=y_train.size) * lr_factor
            opt_init, opt_update, get_params = optimizers.sgd(step_size)
            opt_state = opt_init(params)

            init_loss = get_loss(opt_state)

            for i in range(steps):
                params = get_params(opt_state)
                opt_state = opt_update(i, grad_loss(params, x_train),
                                       opt_state)

            trained_loss = get_loss(opt_state)
            loss_ratio = trained_loss / (init_loss + 1e-12)
            if lr_factor == 3.:
                # Above the threshold training should diverge (NaN also ok).
                if not math.isnan(loss_ratio):
                    self.assertGreater(loss_ratio, 10.)
            else:
                # Below the threshold training should converge markedly.
                self.assertLess(loss_ratio, 0.1)
  def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                          fn_and_kernel, momentum, learning_rate, t, loss):
    """Compares NTK function-space prediction with actual SGD/momentum training."""
    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    # Empirical NTK on train/train and test/train input pairs.
    g_dd = ntk(x_train, None, 'ntk')
    g_td = ntk(x_test, x_train, 'ntk')

    # Regress to an MSE loss.
    loss_fn = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
    grad_loss = jit(grad(lambda params, x: loss_fn(f(params, x), y_train)))

    # A 4-dim kernel carries an explicit logit axis to trace over.
    trace_axes = () if g_dd.ndim == 4 else (-1,)
    if loss == 'mse_analytic':
      if momentum is not None:
        # The analytic MSE solution has no momentum counterpart.
        raise absltest.SkipTest(momentum)
      predictor = predict.gradient_descent_mse(g_dd, y_train,
                                               learning_rate=learning_rate,
                                               trace_axes=trace_axes)
    elif loss == 'mse':
      predictor = predict.gradient_descent(loss_fn, g_dd, y_train,
                                           learning_rate=learning_rate,
                                           momentum=momentum,
                                           trace_axes=trace_axes)
    else:
      raise NotImplementedError(loss)

    predictor = jit(predictor)

    # Initial network outputs in function space.
    fx_train_0 = f(params, x_train)
    fx_test_0 = f(params, x_test)

    self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
    self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
    if loss == 'mse_analytic':
      self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td, y_train)

    # Train the actual network with matching optimizer settings.
    if momentum is None:
      opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    else:
      opt_init, opt_update, get_params = optimizers.momentum(learning_rate,
                                                             momentum)

    opt_state = opt_init(params)
    for i in range(t):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    params = get_params(opt_state)

    # Trained-network outputs vs. predictor outputs at time t must agree.
    fx_train_nn, fx_test_nn = f(params, x_train), f(params, x_test)
    fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

    self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
    self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)
Example #9
0
    def testIssue758(self):
        """Regression test for google/jax#758: jacfwd/jacrev through lax.scan.

        Unrolls an SGD minimization inside lax.scan and checks that forward-
        and reverse-mode Jacobians of the resulting loss agree.
        """
        # code from https://github.com/google/jax/issues/758
        # this is more of a scan + jacfwd/jacrev test, but it lives here to use
        # the optimizers.py code

        def harmonic_bond(conf, params):
            # Toy bilinear "energy" so gradients are simple and exact.
            return np.sum(conf * params)

        opt_init, opt_update, get_params = optimizers.sgd(5e-2)

        x0 = onp.array([0.5], dtype=onp.float64)
        params = onp.array([0.3], dtype=onp.float64)

        def minimize_structure(test_params):
            # Run 75 SGD steps on the energy via lax.scan; the scan input is
            # an empty (75, 0) array that only fixes the iteration count.
            energy_fn = functools.partial(harmonic_bond, params=test_params)
            grad_fn = grad(energy_fn, argnums=(0, ))
            opt_state = opt_init(x0)

            def apply_carry(carry, _):
                i, x = carry
                g = grad_fn(get_params(x))[0]
                new_state = opt_update(i, g, x)
                new_carry = (i + 1, new_state)
                return new_carry, _

            carry_final, _ = lax.scan(apply_carry, (0, opt_state),
                                      np.zeros((75, 0)))
            trip, opt_final = carry_final
            assert trip == 75
            return opt_final

        initial_params = np.float64(0.5)
        minimize_structure(initial_params)

        def loss(test_params):
            opt_final = minimize_structure(test_params)
            return 1.0 - get_params(opt_final)[0]

        # (An unused second optimizers.sgd(5e-2) triple was removed here.)
        # Forward- and reverse-mode Jacobians of the unrolled loss must agree.
        J1 = jacrev(loss, argnums=(0, ))(initial_params)
        J2 = jacfwd(loss, argnums=(0, ))(initial_params)
        self.assertAllClose(J1, J2, check_dtypes=True, rtol=1e-6)
Example #10
0
    def testMaxLearningRate(self, train_shape, network, out_logits,
                            fn_and_kernel, name):
        """Checks max_learning_rate separates converging from diverging SGD.

        'theoretical' passes the known number of logits; any other name lets
        the routine infer the output count (num_outputs=-1).
        """
        key = random.PRNGKey(0)

        key, split = random.split(key)
        if len(train_shape) == 2:
            # Enlarge small dense shapes; otherwise use a fixed image shape.
            train_shape = (train_shape[0] * 5, train_shape[1] * 10)
        else:
            train_shape = (16, 8, 8, 3)
        x_train = random.normal(split, train_shape)

        key, split = random.split(key)
        # Random binary targets in {0, 1}.
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        for lr_factor in [0.5, 3.]:
            params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                           out_logits)

            # Regress to an MSE loss.
            loss = lambda params, x: \
                0.5 * np.mean((f(params, x) - y_train) ** 2)
            grad_loss = jit(grad(loss))

            g_dd = ntk(x_train, None, 'ntk')

            steps = 20
            # Scale the predicted maximal stable learning rate.
            if name == 'theoretical':
                step_size = predict.max_learning_rate(
                    g_dd, num_outputs=out_logits) * lr_factor
            else:
                step_size = predict.max_learning_rate(
                    g_dd, num_outputs=-1) * lr_factor
            opt_init, opt_update, get_params = optimizers.sgd(step_size)
            opt_state = opt_init(params)

            def get_loss(opt_state):
                # Loss of the parameters currently in the optimizer state.
                return loss(get_params(opt_state), x_train)

            init_loss = get_loss(opt_state)

            for i in range(steps):
                params = get_params(opt_state)
                opt_state = opt_update(i, grad_loss(params, x_train),
                                       opt_state)

            trained_loss = get_loss(opt_state)
            loss_ratio = trained_loss / (init_loss + 1e-12)
            if lr_factor == 3.:
                # Above the threshold training should diverge (NaN also ok).
                if not math.isnan(loss_ratio):
                    self.assertGreater(loss_ratio, 10.)
            else:
                # Below the threshold training should converge markedly.
                self.assertLess(loss_ratio, 0.1)
  def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                 out_logits):
    """Ensemble statistics of trained finite nets must match the NTK posterior."""
    training_steps = 1000
    learning_rate = 0.1
    ensemble_size = 1024

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
        stax.Dense(out_logits, W_std=1.2, b_std=0.05))

    opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    opt_update = jit(opt_update)

    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)
    # Analytic mean/covariance of the NTK-GP after gradient-descent training.
    predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
        kernel_fn,
        x_train,
        y_train,
        learning_rate=learning_rate,
        diag_reg=0.)

    train = (x_train, y_train)
    # One PRNG key per ensemble member.
    ensemble_key = random.split(key, ensemble_size)

    loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y)**2))
    grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

    def train_network(key):
      # Full-batch SGD from a fresh random initialization.
      _, params = init_fn(key, (-1,) + train_shape[1:])
      opt_state = opt_init(params)
      for i in range(training_steps):
        opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)

      return get_params(opt_state)

    # Train the whole ensemble in parallel over the per-member keys.
    params = vmap(train_network)(ensemble_key)
    rtol = 0.08

    for x in [None, 'x_test']:
      with self.subTest(x=x):
        # x=None means evaluate on the training inputs.
        x = x if x is None else x_test
        x_fin = x_train if x is None else x_test
        ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

        # Empirical mean and (trace-averaged) covariance across the ensemble.
        mean_emp = np.mean(ensemble_fx, axis=0, keepdims=True)
        mean_subtracted = ensemble_fx - mean_emp
        cov_emp = np.einsum(
            'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
                mean_subtracted.shape[0] * mean_subtracted.shape[-1])

        ntk = predict_fn_mse_ens(training_steps, x, 'ntk', compute_cov=True)
        self._assertAllClose(mean_emp, ntk.mean, rtol)
        self._assertAllClose(cov_emp, ntk.covariance, rtol)
Example #12
0
def get_optimizer(optimizer, sched, b1=0.9, b2=0.999):
  """Look up a jax.experimental optimizer triple by name.

  Args:
    optimizer: Case-insensitive optimizer name ('adagrad', 'adam', 'rmsprop',
      'momentum' or 'sgd').
    sched: Step size (or schedule) passed to the optimizer.
    b1: Adam first-moment decay; used only by 'adam'.
    b2: Adam second-moment decay; used only by 'adam'.

  Returns:
    The (init, update, get_params) optimizer triple.

  Raises:
    ValueError: If the optimizer name is unknown.
  """
  name = optimizer.lower()  # hoisted: lower-case once, not per branch
  if name == 'adagrad':
    return optimizers.adagrad(sched)
  elif name == 'adam':
    return optimizers.adam(sched, b1, b2)
  elif name == 'rmsprop':
    return optimizers.rmsprop(sched)
  elif name == 'momentum':
    return optimizers.momentum(sched, 0.9)
  elif name == 'sgd':
    return optimizers.sgd(sched)
  # ValueError (a subclass of the former bare Exception) pinpoints the
  # problem as a bad argument value; existing `except Exception` still works.
  raise ValueError('Invalid optimizer: {}'.format(optimizer))
def main(unused_argv):
  """Trains a small MLP on MNIST and compares it against the NTK prediction."""
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

  # Build the network
  init_fn, apply_fn, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an mse loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0),
                 batch_size=4, device_count=0)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

  # Get initial values of the network in function space.
  fx_train = apply_fn(params, x_train)
  fx_test = apply_fn(params, x_test)

  # Train the network.
  # Discrete steps that amount to FLAGS.train_time of continuous time.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                     loss)
  util.print_summary('test', y_test, apply_fn(params, x_test), fx_test,
                     loss)
Example #14
0
def optimizer(name="adam",
              momentum_mass=0.9, rmsprop_gamma=0.9, rmsprop_eps=1e-8,
              adam_b1=0.9, adam_b2=0.997, adam_eps=1e-8):
  """Return the optimizer, by name."""
  # NOTE(review): `learning_rate` is a free variable, not a parameter —
  # presumably a module-level global (or a bug). TODO confirm and consider
  # passing it explicitly.
  if name == "sgd":
    return optimizers.sgd(learning_rate)
  if name == "momentum":
    return optimizers.momentum(learning_rate, mass=momentum_mass)
  if name == "rmsprop":
    return optimizers.rmsprop(
        learning_rate, gamma=rmsprop_gamma, eps=rmsprop_eps)
  if name == "adam":
    return optimizers.adam(learning_rate, b1=adam_b1, b2=adam_b2, eps=adam_eps)
  raise ValueError("Unknown optimizer %s" % str(name))
Example #15
0
  def testTracedStepSize(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_size = 0.1

    init_fun, _, _ = optimizers.sgd(step_size)
    opt_state = init_fun(x0)

    @jit
    def update(opt_state, step_size):
      _, update_fun, get_params = optimizers.sgd(step_size)
      x = get_params(opt_state)
      g = grad(loss)(x)
      return update_fun(0, g, opt_state)

    update(opt_state, 0.9)  # doesn't crash
def main(unused_argv):
    """Trains an MLP on MNIST and compares against the analytic NTK predictor."""
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, f = stax.serial(layers.Dense(4096), stax.Tanh, layers.Dense(10))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)
    # (Removed a leftover `import ipdb; ipdb.set_trace()` debugging stop that
    # would halt the script here.)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network.
    # Discrete steps that amount to FLAGS.train_time of continuous time.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
Example #17
0
  def testTracedStepSize(self):
    """sgd() must accept a traced step size inside jit.

    NOTE(review): this uses an older optimizers API in which sgd() returns an
    (init_fun, update_fun) pair and get_params is a module-level function —
    confirm against the pinned jax version.
    """
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100  # unused here; kept as in the original
    step_size = 0.1

    init_fun, _ = optimizers.sgd(step_size)
    opt_state = init_fun(x0)

    @jit
    def update(opt_state, step_size):
      # Re-creating sgd() inside jit makes step_size an abstract tracer.
      _, update_fun = optimizers.sgd(step_size)
      x = optimizers.get_params(opt_state)
      g = grad(loss)(x, None)
      return update_fun(0, g, opt_state)

    update(opt_state, 0.9)  # doesn't crash
Example #18
0
def relaxed_gd(X, phi_y, W_0, lr, beta_, lambda_, tol, N_max, verbose=True):
    """Run gradient descent on relaxed_loss."""
    print(f"Lambda {lambda_}")
    # Predictor maps n_features_ inputs to dim_H_ outputs.
    dim_H_ = phi_y.shape[1]
    n_features_ = X.shape[1]
    predictor_shape = (
        dim_H_,
        n_features_,
    )
    assert W_0.shape == predictor_shape

    opt_init, opt_update, get_params = optimizers.sgd(lr)
    opt_state = opt_init(W_0)

    def step(opt_state, X, phi_y, beta, lambda_, i):
        # One SGD step on the relaxed loss; returns (loss value, new state).
        value, grads = jax.value_and_grad(relaxed_predict_loss)(
            get_params(opt_state), X, phi_y, beta, lambda_)
        if (i < 10 or i % 100 == 0) and verbose:
            print(f"Step {i} value: \t\t{value.item()}")
        opt_state = opt_update(i, grads, opt_state)
        return value, opt_state

    old_value = jnp.inf
    diff = jnp.inf
    step_index = 0

    # NOTE(review): the loop only stops at N_max; `tol`/`diff` are computed
    # and reported in `status` but never trigger early stopping — confirm
    # whether a `diff <= tol: break` was intended.
    while step_index < N_max:
        value, opt_state = step(opt_state, X, phi_y, beta_, lambda_,
                                step_index)

        diff = abs(old_value - value)
        old_value = value
        step_index += 1

    # Termination diagnostics returned alongside the fitted predictor.
    status = {
        "max_steps": step_index >= N_max,
        "tol_max": tol >= diff.item(),
        "diff": diff.item(),
        "step_index": step_index,
        "final_value": value.item(),
    }
    if verbose:
        print(status)

    return get_params(opt_state), status
Example #19
0
def privately_train(rng, params, predict, X, y):
  """Generic train function called for each slice.

  Responsible for, given an rng key, a set of parameters to be trained, some inputs X and some outputs y,
  finetuning the params on X and y according to some internally defined training configuration.
  """
  # BUG FIX: `locals().update(...)` does not create local variables inside a
  # function in Python 3 (locals() is only a snapshot), so the intended
  # hyperparameters were never bound locally. Unpack them explicitly.
  # NOTE(review): assumes the dict keys match these names — confirm.
  l2_norm_clip = private_training_parameters['l2_norm_clip']
  noise_multiplier = private_training_parameters['noise_multiplier']
  batch_size = private_training_parameters['batch_size']
  step_size = private_training_parameters['step_size']
  iterations = private_training_parameters['iterations']

  def clipped_grad(params, l2_norm_clip, single_example_batch):
    # Per-example gradient, rescaled so its global L2 norm is <= l2_norm_clip.
    grads = grad(loss)(params, predict, single_example_batch)
    nonempty_grads, tree_def = tree_flatten(grads)
    total_grad_norm = np.linalg.norm([np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
    divisor = np.max((total_grad_norm / l2_norm_clip, 1.))
    normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
    return tree_unflatten(tree_def, normalized_nonempty_grads)

  def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier, batch_size):
    # DP-SGD gradient: clip per-example grads, sum, add noise, average.
    # Add batch dimension for when each example is separated
    batch = (np.expand_dims(batch[0], 1), np.expand_dims(batch[1], 1))
    clipped_grads = vmap(clipped_grad, (None, None, 0))(params, l2_norm_clip, batch)
    clipped_grads_flat, grads_treedef = tree_flatten(clipped_grads)
    aggregated_clipped_grads = [np.sum(g, 0) for g in clipped_grads_flat]
    rngs = random.split(rng, len(aggregated_clipped_grads))
    noised_aggregated_clipped_grads = [g + l2_norm_clip * noise_multiplier * random.normal(r, g.shape) for r, g in zip(rngs, aggregated_clipped_grads)]
    normalized_noised_aggregated_clipped_grads = [g / batch_size for g in noised_aggregated_clipped_grads]
    return tree_unflatten(grads_treedef, normalized_noised_aggregated_clipped_grads)

  @jit
  def private_update(rng, i, opt_state, batch):
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # fresh noise key per step
    grads = private_grad(params, batch, rng, l2_norm_clip, noise_multiplier, batch_size)
    return opt_update(i, grads, opt_state)

  opt_init, opt_update, get_params = optimizers.sgd(step_size)
  opt_state = opt_init(params)

  batches = data_stream(rng, batch_size, X, y)
  itercount = itertools.count()
  for _ in range(iterations):
    temp, rng = random.split(rng)
    opt_state = private_update(temp, next(itercount), opt_state, next(batches))

  return get_params(opt_state)
Example #20
0
    def run_optimiser(Niters, l_rate, x_data, y_data, params_IC):
        """Run Niters optimisation steps via lax.scan; return (params, losses)."""
        if optimiser_type == "sgd":
            opt_init, opt_update, get_params = optimizers.sgd(l_rate)
        elif optimiser_type == "adam":
            opt_init, opt_update, get_params = optimizers.adam(l_rate)
        else:
            raise ValueError("Optimiser not added.")

        @progress_bar_scan(Niters)
        def body(carry, step):
            # One optimisation step; the loss value becomes the scan output.
            value, grads = val_and_grad_loss(get_params(carry), x_data,
                                             y_data)
            carry = opt_update(step, grads, carry)
            return carry, value

        final_state, loss_array = lax.scan(body, opt_init(params_IC),
                                           jnp.arange(Niters))
        return get_params(final_state), loss_array
Example #21
0
    def test_sgd(self):
        """optix.sgd must match jax.experimental optimizers.sgd step-for-step."""
        # Reference trajectory: experimental/optimizers.py.
        reference = self.init_params
        opt_init, opt_update, get_params = optimizers.sgd(LR)
        state = opt_init(reference)
        for step in range(STEPS):
            state = opt_update(step, self.per_step_updates, state)
            reference = get_params(state)

        # Candidate trajectory: experimental/optix.py.
        candidate = self.init_params
        sgd = optix.sgd(LR, 0.0)
        state = sgd.init(candidate)
        for _ in range(STEPS):
            updates, state = sgd.update(self.per_step_updates, state)
            candidate = optix.apply_updates(candidate, updates)

        # Leaf-wise numerical equivalence of the two parameter pytrees.
        for x, y in zip(tree_leaves(reference), tree_leaves(candidate)):
            np.testing.assert_allclose(x, y, rtol=1e-5)
Example #22
0
    def get_optimizer(self, optim=None, stage='learn', step_size=None):
        """Pick a jax optimizer triple from the integer code `optim`.

        Falls back to the instance's configured optimizer for the given
        stage and to its default step size when arguments are omitted.
        """
        if optim is None:
            optim = self.optim_learn if stage == 'learn' else self.optim_proj
        if step_size is None:
            step_size = self.step_size

        if optim == 1:
            if self.verb > 2:
                print("With momentum optimizer")
            triple = momentum(step_size=step_size, mass=0.95)
        elif optim == 2:
            if self.verb > 2:
                print("With rmsprop optimizer")
            triple = rmsprop(step_size, gamma=0.9, eps=1e-8)
        elif optim == 3:
            if self.verb > 2:
                print("With adagrad optimizer")
            triple = adagrad(step_size, momentum=0.9)
        elif optim == 4:
            if self.verb > 2:
                print("With Nesterov optimizer")
            triple = nesterov(step_size, 0.9)
        elif optim == 5:
            if self.verb > 2:
                print("With SGD optimizer")
            triple = sgd(step_size)
        else:
            # Any other code falls back to adam.
            if self.verb > 2:
                print("With adam optimizer")
            triple = adam(step_size)

        # Same (opt_init, opt_update, get_params) tuple as before.
        return triple
Example #23
0
def gamma2kappa(g1, g2, kappa_shape, obj, step_size=0.01, n_iter=1000):
    """Recover (kE, kB) convergence maps from shear by SGD on `obj`.

    Args:
        g1, g2: Observed shear components, passed through to `obj`.
        kappa_shape: Shape of each convergence-map array (initialized to 0).
        obj: Objective obj(params, g1, g2) -> scalar, with params = (kE, kB).
        step_size: SGD learning rate. BUG FIX: this argument was previously
            ignored — the optimizer was hard-coded to 0.001.
        n_iter: Number of gradient steps.

    Returns:
        Tuple (kEhat, kBhat) of fitted arrays with shape `kappa_shape`.
    """
    # Set up optimizer starting from zero maps.
    init_kE = np.zeros(kappa_shape)
    init_kB = np.zeros(kappa_shape)
    init_params = (init_kE, init_kB)
    opt_init, opt_update, get_params = optimizers.sgd(step_size=step_size)
    opt_state = opt_init(init_params)

    @jit
    def update(i, opt_state):
        params = get_params(opt_state)
        gradient = grad(obj)(params, g1, g2)
        return opt_update(i, gradient, opt_state)

    # Optimization loop.
    for t in range(n_iter):
        opt_state = update(t, opt_state)

    # Read final parameters from the state (also correct when n_iter == 0,
    # where the old code left `params` undefined).
    kEhat, kBhat = get_params(opt_state)
    return kEhat, kBhat
Example #24
0
def get_optimizer(
    learning_rate: float = 1e-4, optimizer="sdg", optimizer_kwargs: dict = None
) -> JaxOptimizer:
    """Return a `JaxOptimizer` dataclass for a JAX optimizer

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type (Allowed types: "adam",
            "adamax", "adagrad", "rmsprop", "sdg"). Defaults to "sdg".
        optimizer_kwargs (dict, optional): Additional keyword arguments
            that are passed to the optimizer. Defaults to None.

    Returns:
        JaxOptimizer
    """
    # Imports are local so jax stays an optional dependency of this module.
    try:  # newer jax exposes config directly on the package
        from jax import config  # pylint:disable=import-outside-toplevel
    except ImportError:
        from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)
    from jax import jit  # pylint:disable=import-outside-toplevel

    try:  # jax >= 0.2.25 renamed jax.experimental.optimizers
        from jax.example_libraries import optimizers  # pylint:disable=import-outside-toplevel
    except ImportError:
        from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    optimizer = optimizer.lower()
    if optimizer == "adam":
        opt_init, opt_update, get_params = optimizers.adam(learning_rate, **optimizer_kwargs)
    elif optimizer == "adagrad":
        opt_init, opt_update, get_params = optimizers.adagrad(learning_rate, **optimizer_kwargs)
    elif optimizer == "adamax":
        opt_init, opt_update, get_params = optimizers.adamax(learning_rate, **optimizer_kwargs)
    elif optimizer == "rmsprop":
        opt_init, opt_update, get_params = optimizers.rmsprop(learning_rate, **optimizer_kwargs)
    else:
        # Any unrecognized name (incl. the documented default "sdg") -> SGD.
        opt_init, opt_update, get_params = optimizers.sgd(learning_rate, **optimizer_kwargs)

    opt_update = jit(opt_update)

    return JaxOptimizer(opt_init, opt_update, get_params)
Example #25
0
def run_vmc(steps, step_size, diag_shift, n_samples):
    """Run VMC for `steps` iterations with SR and a jax SGD optimizer.

    NOTE(review): `ha`, `sa`, `mpi_rank`, `HEADER_STRING` and `output` are
    free names — presumably module-level globals; confirm.
    """
    opt = jaxopt.sgd(step_size)
    # opt = nk.optimizer.Sgd(step_size)

    sr = nk.optimizer.SR(lsq_solver="BDCSVD", diag_shift=diag_shift)
    sr.store_rank_enabled = False  # not supported by BDCSVD
    sr.store_covariance_matrix_enabled = True

    vmc = nk.Vmc(
        hamiltonian=ha,
        sampler=sa,
        optimizer=opt,
        n_samples=n_samples,
        n_discard=min(n_samples // 10, 200),  # cap warm-up discards at 200
        sr=sr,
    )

    if mpi_rank == 0:
        # Only rank 0 prints; other MPI ranks stay silent.
        print(vmc.info())
        print(HEADER_STRING)

    # Log after every iteration.
    for step in vmc.iter(steps, 1):
        output(vmc, step)
Example #26
0
def test_validate_optimizers():
    """Make sure we correctly check/standardize the optimizers.

    Checks that ``validate_optimizers`` rejects a bare callable, a
    too-short list, and a single optimizer where a list is expected,
    while returning a correctly sized list unchanged.
    """
    from jax import jit  # pylint:disable=import-outside-toplevel
    from jax.experimental import optimizers  # pylint:disable=import-outside-toplevel

    from pyepal.models.nt import JaxOptimizer  # pylint:disable=import-outside-toplevel

    opt_init, opt_update, get_params = optimizers.sgd(1e-3)
    opt_update = jit(opt_update)

    optimizer = JaxOptimizer(opt_init, opt_update, get_params)
    # Renamed from `optimizers`: the previous name shadowed the
    # `jax.experimental.optimizers` module imported above.
    optimizer_list = [
        JaxOptimizer(opt_init, opt_update, get_params),
        JaxOptimizer(opt_init, opt_update, get_params),
    ]
    # A bare function (not a JaxOptimizer) must be rejected.
    with pytest.raises(ValueError):
        validate_optimizers(opt_init, 2)

    # A list shorter than the requested count must be rejected.
    with pytest.raises(ValueError):
        validate_optimizers([optimizer], 2)

    # A correctly sized list is returned unchanged.
    assert validate_optimizers(optimizer_list, 2) == optimizer_list
    # A single optimizer where a list is expected must be rejected.
    with pytest.raises(ValueError):
        validate_optimizers(optimizer, 2)
Example #27
0
 def update(opt_state, step_size):
     """Apply one plain-SGD step to ``opt_state`` and return the new state."""
     # Only the update function is needed here; the init function is discarded.
     _, sgd_update, params_of = optimizers.sgd(step_size)
     gradient = grad(loss)(params_of(opt_state))
     return sgd_update(0, gradient, opt_state)
# Render the current matplotlib figure in the Streamlit app, then free it.
st.pyplot()
plt.close()

st.header("Comparison against finite SGD-NNs")
# NOTE: the bare triple-quoted string below is Streamlit "magic" — it is
# rendered as markdown in the app, so its text must not be altered here.
"""
Finally, let's bring back the practioner loved Neural Networks for a comparison. 
"""

# Interactive learning-rate slider: range [1e-4, 1.0], default 0.1.
learning_rate = st.slider("Learning rate",
                          1e-4,
                          1.0,
                          0.1,
                          step=1e-4,
                          format="%.4f")

# Build a plain-SGD optimizer triple and JIT the hot functions.
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_update = jit(opt_update)
# Mean-squared-error loss over the network output (apply_fn defined earlier).
loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y)**2))
grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

train_losses = []
test_losses = []

# `params` and `train`/`test` are assumed to be defined earlier in the app.
opt_state = opt_init(params)

# Full-batch gradient descent; record train/test loss after every step.
for i in range(training_steps):
    opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)

    train_losses += [loss(get_params(opt_state), *train)]
    test_losses += [loss(get_params(opt_state), *test)]
Example #29
0
def main(_):
    """Fine-tune a checkpointed model on uncertain points picked from a pool.

    Loads a (possibly DP-SGD-trained) checkpoint, selects the `n_extra`
    most uncertain pool points according to ``FLAGS.uncertain`` (entropy,
    top-2 probability difference, or random), then fine-tunes on them for
    ``FLAGS.epochs`` epochs, logging test accuracy after each epoch.
    """
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

    # use mean/std of svhn train
    train_images, _, _ = datasets.get_dataset_split(
        name=FLAGS.train_split.split('-')[0],
        split=FLAGS.train_split.split('-')[1],
        shuffle=False)
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    del train_images  # free memory; only the statistics are needed

    # BEGIN: fetch test data and candidate pool
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)

    n_pool = len(pool_images)
    test_images = (test_images -
                   train_mu) / train_std  # normalize w train mu/std
    pool_images = (pool_images -
                   train_mu) / train_std  # normalize w train mu/std

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    # END: fetch test data and candidate pool

    # BEGIN: load ckpt
    opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    if FLAGS.pretrained_dir is not None:
        with gfile.Open(FLAGS.pretrained_dir, 'rb') as fpre:
            pretrained_opt_state = optimizers.pack_optimizer_state(
                pickle.load(fpre))
        fixed_params = get_params(pretrained_opt_state)[:7]

        ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
        # FIX: mode was 'wr', which is not a read mode and would clobber the
        # checkpoint; the file is pickle.load-ed, so open read-binary like
        # the pretrained load above.
        with gfile.Open(ckpt_dir, 'rb') as fckpt:
            opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
        params = get_params(opt_state)
        # combine fixed pretrained params and dpsgd trained last layers
        params = fixed_params + params
        opt_state = opt_init(params)
    else:
        ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
        # FIX: 'wr' -> 'rb' (see note above); the checkpoint is only read.
        with gfile.Open(ckpt_dir, 'rb') as fckpt:
            opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
        params = get_params(opt_state)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    # END: load ckpt

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        """One non-private SGD step (rng argument ignored)."""
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        """One DP-SGD step with clipped, noised per-example gradients."""
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

    # END: setup for dp model

    ### BEGIN: prepare extra points picked from pool data
    # BEGIN: on pool data
    # Embed the pool in batches with all layers but the last, then score
    # with the final (classifier) layer.
    pool_embeddings = [apply_fn_0(params[:-1],
                                  pool_images[b_i:b_i + FLAGS.batch_size]) \
                       for b_i in range(0, n_pool, FLAGS.batch_size)]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    pool_logits = apply_fn_1(params[-1:], pool_embeddings)

    pool_true_labels = np.argmax(pool_labels, axis=1)
    pool_predicted_labels = np.argmax(pool_logits, axis=1)
    pool_correct_indices = \
        onp.where(pool_true_labels == pool_predicted_labels)[0]
    pool_incorrect_indices = \
        onp.where(pool_true_labels != pool_predicted_labels)[0]
    assert len(pool_correct_indices) + \
        len(pool_incorrect_indices) == len(pool_labels)

    pool_probs = stax.softmax(pool_logits)

    # NOTE(review): if FLAGS.uncertain is none of the three handled values,
    # `pool_uncertain_indices` is never bound and the deepcopy below raises
    # NameError — confirm the flag is validated upstream.
    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        # Predictive-entropy criterion: higher entropy = more uncertain.
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        stdout_log.write('all {} entropy: min {}, max {}\n'.format(
            len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

        pool_entropy_sorted_indices = onp.argsort(pool_entropy)
        # take the n_extra most uncertain points
        pool_uncertain_indices = \
            pool_entropy_sorted_indices[::-1][:FLAGS.n_extra]
        stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
            len(pool_entropy[pool_uncertain_indices]),
            onp.min(pool_entropy[pool_uncertain_indices]),
            onp.max(pool_entropy[pool_uncertain_indices])))

    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob: small margin = uncertain prediction.
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        stdout_log.write('all {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff), onp.min(pool_probs_diff),
            onp.max(pool_probs_diff)))

        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:FLAGS.n_extra]
        stdout_log.write('uncertain {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff[pool_uncertain_indices]),
            onp.min(pool_probs_diff[pool_uncertain_indices]),
            onp.max(pool_probs_diff[pool_uncertain_indices])))

    elif FLAGS.uncertain == 2 or FLAGS.uncertain == 'random':
        # Random baseline: pick n_extra points uniformly without replacement.
        pool_uncertain_indices = npr.permutation(n_pool)[:FLAGS.n_extra]

    # END: on pool data
    ### END: prepare extra points picked from pool data

    finetune_images = copy.deepcopy(pool_images[pool_uncertain_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_uncertain_indices])

    stdout_log.write('Starting fine-tuning...\n')
    logging.info('Starting fine-tuning...')
    stdout_log.flush()

    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s', len(finetune_images),
                 FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)

    for epoch in range(1, FLAGS.epochs + 1):

        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images,
                                  Y=finetune_labels,
                                  image_size=32,
                                  image_channels=3,
                                  label_dim=1,
                                  label_format='numeric')

        batches = data.minibatcher(finetune,
                                   FLAGS.batch_size,
                                   transform=augmentation)

        itercount = itertools.count()
        key = random.PRNGKey(FLAGS.seed)

        start_time = time.time()

        for _ in range(num_batches):
            # tmp_time = time.time()
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))
            # stdout_log.write('single update in {:.2f} sec\n'.format(
            #     time.time() - tmp_time))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params,
                                       shape_as_image(test_images,
                                                      test_labels),
                                       return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                     100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=train_mu,
                                            sigma=train_std)

    # END: finetune model with extra data, evaluate and save

    stdout_log.close()
def main(_):
  """Train an MNIST classifier with (optionally differentially private) SGD.

  Trains for ``FLAGS.epochs`` epochs; when ``FLAGS.dpsgd`` is set, uses
  per-example clipped and noised gradients and reports the cumulative
  privacy epsilon after each epoch.
  """

  # Per-example gradients require microbatch size == batch size here.
  if FLAGS.microbatches:
    raise NotImplementedError(
        'Microbatches < batch size not currently supported'
    )

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
  num_batches = num_complete_batches + bool(leftover)
  key = random.PRNGKey(FLAGS.seed)

  def data_stream():
    # Infinite generator: reshuffles the training set every epoch.
    rng = npr.RandomState(FLAGS.seed)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  batches = data_stream()

  opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

  @jit
  def update(_, i, opt_state, batch):
    # One non-private SGD step (the rng argument is ignored).
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    # One DP-SGD step: clipped, noised per-example gradients.
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

  # MNIST images are 28x28x1; -1 lets the batch dimension vary.
  _, init_params = init_random_params(key, (-1, 28, 28, 1))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  # NOTE(review): 60000 hard-codes the MNIST train-set size; it is also
  # assumed again as num_examples for the privacy accounting below.
  steps_per_epoch = 60000 // FLAGS.batch_size
  print('\nStarting training...')
  for epoch in range(1, FLAGS.epochs + 1):
    start_time = time.time()
    # pylint: disable=no-value-for-parameter
    for _ in range(num_batches):
      if FLAGS.dpsgd:
        opt_state = \
            private_update(
                key, next(itercount), opt_state,
                shape_as_image(*next(batches), dummy_dim=True))
      else:
        opt_state = update(
            key, next(itercount), opt_state, shape_as_image(*next(batches)))
    # pylint: enable=no-value-for-parameter
    epoch_time = time.time() - start_time
    print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))

    # evaluate test accuracy
    params = get_params(opt_state)
    test_acc = accuracy(params, shape_as_image(test_images, test_labels))
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(
        test_loss, 100 * test_acc))

    # determine privacy loss so far
    if FLAGS.dpsgd:
      delta = 1e-5
      num_examples = 60000
      eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
      print(
          'For delta={:.0e}, the current epsilon is: {:.2f}'.format(delta, eps))
    else:
      print('Trained with vanilla non-private SGD optimizer')