Example #1
def sgd(learning_rate: float,
        momentum: float = 0.,
        nesterov: bool = False) -> GradientTransformation:
    return combine.chain(
        transform.trace(decay=momentum, nesterov=nesterov),
        transform.scale(-learning_rate),
    )
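
A toy check of the factory above, not from the optax source: the parameter tree and gradient values are assumed, and the snippet relies on the same module context as the example (`update` is the optax module used in Example #2).

import jax.numpy as jnp

opt = sgd(learning_rate=0.1, momentum=0.9)
params = {'w': jnp.array(1.0)}
state = opt.init(params)
updates, state = opt.update({'w': jnp.array(2.0)}, state)
# On the first step the momentum trace is zero, so the scaled update is
# -learning_rate * grad = -0.1 * 2.0 = -0.2.
new_params = update.apply_updates(params, updates)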
Example #2
  def test_chain(self):
    transformations = [
        transform.scale_by_adam(),
        transform.trace(decay=0, nesterov=False),
        transform.scale(-LR)]

    # Apply updates with chain.
    chain_params = self.init_params
    chained_transforms = combine.chain(*transformations)
    state = chained_transforms.init(chain_params)

    @self.variant
    def update_fn(updates, state):
      return chained_transforms.update(updates, state)

    for _ in range(STEPS):
      updates, state = update_fn(self.per_step_updates, state)
      chain_params = update.apply_updates(chain_params, updates)

    # Manually apply sequence of transformations.
    manual_params = self.init_params
    states = [t.init(manual_params) for t in transformations]
    for _ in range(STEPS):
      updates = self.per_step_updates
      new_states = []
      for t, s in zip(transformations, states):
        updates, state = t.update(updates, s)
        new_states.append(state)
      manual_params = update.apply_updates(manual_params, updates)
      states = new_states

    # Check equivalence.
    chex.assert_tree_all_close(manual_params, chain_params, rtol=1e-4)
Example #3
def noisy_sgd(learning_rate: float,
              eta: float = 0.01,
              gamma: float = 0.55,
              seed: int = 0) -> GradientTransformation:
    return combine.chain(
        transform.trace(decay=0., nesterov=False),
        transform.scale(-learning_rate),
        transform.add_noise(eta, gamma, seed),
    )
Example #4
def lars(
    learning_rate: ScalarOrSchedule,
    weight_decay: float = 0.,
    weight_decay_mask: MaskOrFn = True,
    trust_coefficient: float = 0.001,
    eps: float = 0.,
    trust_ratio_mask: MaskOrFn = True,
    momentum: float = 0.9,
    nesterov: bool = False,
) -> base.GradientTransformation:
  """The LARS optimiser.

  LARS is a layer-wise adaptive optimiser introduced to help scale SGD to
  larger batch sizes. LARS later inspired the LAMB optimiser.

  References:
    You et al, 2017: https://arxiv.org/abs/1708.03888

  Args:
    learning_rate: this is a fixed global scaling factor.
    weight_decay (default `0.`): strength of the weight decay regularization.
    weight_decay_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    trust_coefficient: a multiplier for the trust ratio.
    eps: optional additive constant in the trust ratio denominator.
    trust_ratio_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    momentum: the decay rate for momentum.
    nesterov: whether to use Nesterov momentum.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
      wrappers.masked(
          inner=transform.scale_by_trust_ratio(
              trust_coefficient=trust_coefficient, eps=eps),
          mask=trust_ratio_mask),
      _scale_by_learning_rate(learning_rate),
      transform.trace(decay=momentum, nesterov=nesterov),
  )
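
Since both mask arguments accept a boolean pytree (or a callable producing one), here is a minimal sketch, not from the source, of excluding one-dimensional parameters such as biases from weight decay; the parameter names and shapes are assumed, and the call relies on the module context above.

import jax
import jax.numpy as jnp

params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}
# True for leaves that should receive weight decay (here: only matrices).
decay_mask = jax.tree_map(lambda p: p.ndim > 1, params)

opt = lars(learning_rate=0.1, weight_decay=1e-4, weight_decay_mask=decay_mask)
opt_state = opt.init(params)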
Example #5
    def test_apply_every(self):
        # The frequency (in steps) at which the sgd update is applied.
        k = 4
        zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

        # optax sgd
        optax_sgd_params = self.init_params
        sgd = alias.sgd(LR, 0.0)
        state_sgd = sgd.init(optax_sgd_params)

        # optax sgd plus apply every
        optax_sgd_apply_every_params = self.init_params
        sgd_apply_every = combine.chain(
            transform.apply_every(k=k), transform.trace(decay=0,
                                                        nesterov=False),
            transform.scale(-LR))
        state_sgd_apply_every = sgd_apply_every.init(
            optax_sgd_apply_every_params)
        transform_fn = self.variant(sgd_apply_every.update)

        for i in range(STEPS):
            # Apply a step of sgd
            updates_sgd, state_sgd = sgd.update(self.per_step_updates,
                                                state_sgd)
            optax_sgd_params = update.apply_updates(optax_sgd_params,
                                                    updates_sgd)

            # Apply a step of sgd_apply_every
            updates_sgd_apply_every, state_sgd_apply_every = transform_fn(
                self.per_step_updates, state_sgd_apply_every)
            optax_sgd_apply_every_params = update.apply_updates(
                optax_sgd_apply_every_params, updates_sgd_apply_every)

            # Every k steps, check equivalence.
            if i % k == k - 1:
                chex.assert_tree_all_close(optax_sgd_apply_every_params,
                                           optax_sgd_params,
                                           atol=1e-6,
                                           rtol=1e-5)
            # Otherwise, check update is zero.
            else:
                chex.assert_tree_all_close(updates_sgd_apply_every,
                                           zero_update,
                                           atol=0.0,
                                           rtol=0.0)
Example #6
def dpsgd(
    learning_rate: ScalarOrSchedule,
    l2_norm_clip: float,
    noise_multiplier: float,
    seed: int,
    momentum: Optional[float] = None,
    nesterov: bool = False
) -> base.GradientTransformation:
  """The DPSGD optimiser.

  Differential privacy is a standard for privacy guarantees of algorithms
  learning from aggregate databases including potentially sensitive information.
  DPSGD offers protection against a strong adversary with full knowledge of the
  training mechanism and access to the model’s parameters.

  WARNING: This `GradientTransformation` expects input updates to have a batch
  dimension on the 0th axis. That is, this function expects per-example
  gradients as input (which are easy to obtain in JAX using `jax.vmap`).

  References:
    Abadi et al, 2016: https://arxiv.org/abs/1607.00133

  Args:
    learning_rate: this is a fixed global scaling factor.
    l2_norm_clip: maximum L2 norm of the per-example gradients.
    noise_multiplier: ratio of standard deviation to the clipping norm.
    seed: initial seed used for the jax.random.PRNGKey
    momentum: (default `None`), the `decay` rate used by the momentum term,
      when it is set to `None`, then momentum is not used at all.
    nesterov (default `False`): whether nesterov momentum is used.

  Returns:
    A `GradientTransformation`.
  """
  return combine.chain(
      privacy.differentially_private_aggregate(
          l2_norm_clip=l2_norm_clip,
          noise_multiplier=noise_multiplier,
          seed=seed),
      (transform.trace(decay=momentum, nesterov=nesterov)
       if momentum is not None else base.identity()),
      _scale_by_learning_rate(learning_rate)
  )
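
Because this transformation expects per-example gradients on the leading axis, a minimal usage sketch follows; the scalar per-example loss `loss_fn(params, x, y)` and the arrays `params`, `batch_x`, `batch_y` are assumptions, not part of the source.

import jax

tx = dpsgd(learning_rate=0.1, l2_norm_clip=1.0, noise_multiplier=1.1, seed=0)
opt_state = tx.init(params)

# Per-example gradients: vmap the single-example gradient over the batch axis.
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(
    params, batch_x, batch_y)
updates, opt_state = tx.update(per_example_grads, opt_state, params)
params = update.apply_updates(params, updates)  # `update` module as in Example #2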
Example #7
def sgd(learning_rate: ScalarOrSchedule,
        momentum: Optional[float] = None,
        nesterov: bool = False) -> base.GradientTransformation:
    """A canonical Stochastic Gradient Descent optimiser.

  This implements stochastic gradient descent. It also includes support for
  momentum, and nesterov acceleration, as these are standard practice when
  using stochastic gradient descent to train deep neural networks.

  References:
    [Sutskever et al, 2013](http://proceedings.mlr.press/v28/sutskever13.pdf)

  Args:
    learning_rate: this is a fixed global scaling factor.
    momentum: (default `None`), the `decay` rate used by the momentum term,
      when it is set to `None`, then momentum is not used at all.
    nesterov (default `False`): whether nesterov momentum is used.

  Returns:
    A `GradientTransformation`.
  """
  return combine.chain((transform.trace(decay=momentum, nesterov=nesterov)
                        if momentum is not None else base.identity()),
                       _scale_by_learning_rate(learning_rate))
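
A minimal training-step sketch for this alias; `params`, `batch`, and a scalar `loss_fn(params, batch)` are assumed to be defined elsewhere and are not part of the source.

import jax

opt = sgd(learning_rate=1e-2, momentum=0.9, nesterov=True)
opt_state = opt.init(params)

grads = jax.grad(loss_fn)(params, batch)
updates, opt_state = opt.update(grads, opt_state)
params = update.apply_updates(params, updates)  # `update` module as in Example #2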
Example #8
    def test_labels_mismatch(self, use_extra_label, use_fn):
        # The labels from label_fn must be a subset of the keys for the tx.
        params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
        params = jax.tree_map(jnp.asarray, params)
        label_tree = {'a': 0, 'b': [1, 0], 'c': 1}  # prefix of params

        if use_extra_label:
            label_tree['a'] = 3

        transforms = {
            0: alias.sgd(1.),
            1: alias.adam(1., b1=0., b2=0.),
            2: transform.trace(1.0)
        }
        init_fn, update_fn = combine.multi_transform(
            transforms, (lambda _: label_tree) if use_fn else label_tree)

        if use_extra_label:
            with self.assertRaises(ValueError):
                self.variant(init_fn)(params)
        else:
            state = self.variant(init_fn)(params)
            updates = jax.tree_map(lambda x: x / 10.0, params)
            self.variant(update_fn)(updates, state)