Example #1
  def test_update_requires_params(self):
    weight_decay = 0.1
    # Pytree mask marking which leaves should have weight decay applied.
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    input_updates = jax.tree_util.tree_map(lambda x: x/10., params)

    # Expected result: weight decay is added only where the mask is True.
    correct_updates = jax.tree_util.tree_multimap(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, input_updates, params)

    # Wrap the weight-decay transform so it only touches the masked leaves.
    init_fn, update_fn = wrappers.masked(
        transform.additive_weight_decay(weight_decay), mask)
    update_fn = self.variant(update_fn)

    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    params = update.apply_updates(params, updates)

    # Test repeated application
    new_correct_updates = jax.tree_util.tree_multimap(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, correct_updates, params)
    updates, state = update_fn(correct_updates, state, params)
    chex.assert_tree_all_close(updates, new_correct_updates)
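
The mask in this test is written out by hand to mirror the structure of params. In practice the mask is usually derived from the parameter tree itself; the sketch below is one hypothetical way to build such a mask with jax.tree_util.tree_map, decaying only leaves with two or more dimensions (the shape-based criterion is an assumption for illustration, not part of the test above).

import jax
import jax.numpy as jnp

def make_decay_mask(params):
  # True for leaves with at least two dimensions (e.g. weight matrices),
  # False for vectors and scalars such as biases.
  return jax.tree_util.tree_map(lambda p: jnp.ndim(p) >= 2, params)

example_params = {'w': jnp.ones((3, 4)), 'b': jnp.zeros((4,))}
decay_mask = make_decay_mask(example_params)  # {'w': True, 'b': False}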
Example #2
def adamw(learning_rate: float,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          weight_decay: float = 1e-4) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps),
        transform.additive_weight_decay(weight_decay),
        transform.scale(-learning_rate),
    )
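
A minimal usage sketch for the GradientTransformation returned above, assuming the standard optax init/update/apply_updates conventions; note that params must be passed to update because additive_weight_decay reads the current parameter values.

import jax.numpy as jnp
import optax

tx = adamw(learning_rate=1e-3)
params = {'w': jnp.ones((3,))}
opt_state = tx.init(params)

grads = {'w': jnp.array([0.1, 0.2, 0.3])}  # e.g. produced by jax.grad
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)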
Example #3
def adamw(learning_rate: ScalarOrSchedule,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          eps_root: float = 0.0,
          weight_decay: float = 1e-4) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        transform.additive_weight_decay(weight_decay),
        _scale_by_learning_rate(learning_rate),
    )
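
Because learning_rate is now a ScalarOrSchedule, it can also be a callable mapping the step count to a scale. A hypothetical exponentially decaying schedule written as a plain Python function:

# Any callable step -> scalar works as a schedule; this one is illustrative only.
schedule = lambda step: 1e-3 * (0.99 ** step)
tx = adamw(learning_rate=schedule, weight_decay=1e-4)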
Example #4
def lamb(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> GradientTransformation:
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      transform.additive_weight_decay(weight_decay),
      transform.scale_by_trust_ratio(),
      transform.scale(-learning_rate),
  )
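
LAMB differs from the adamw chains above only by the extra scale_by_trust_ratio step, which rescales each leaf's update by the ratio of parameter norm to update norm. A rough per-array illustration of that idea (ignoring the zero-norm safeguards the real transform applies):

import jax.numpy as jnp

def trust_ratio_scaled(param, update):
  # The step for this layer is scaled so its size is proportional to the
  # magnitude of the layer's weights: update * (||param|| / ||update||).
  return update * (jnp.linalg.norm(param) / jnp.linalg.norm(update))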
Example #5
def adamw(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    weight_decay: float = 1e-4
) -> base.GradientTransformation:
  """Adam with weight decay regularization.

  AdamW uses weight decay to regularise learning towards small weights, as
  this leads to better generalisation. In SGD a similar effect can be achieved
  by adding an L2 penalty to the loss; however, L2 regularisation does not
  behave as intended for adaptive gradient algorithms such as Adam.

  WARNING: Sometimes you may want to skip weight decay for the BatchNorm scale
  parameters or for the biases. You can use `optax.masked` to make your own
  AdamW variant where `additive_weight_decay` is applied only to a subset of
  `params` (see the sketch after this function).

  References:
    Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101

  Args:
    learning_rate: a fixed global scaling factor, or a schedule mapping the
      step count to a scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to the denominator outside of the square
      root (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: a small constant (default `0`) applied to the denominator inside
      the square root (as in RMSProp) to avoid dividing by zero when rescaling.
      This is needed, for instance, when computing (meta-)gradients through
      Adam.
    weight_decay: strength of the weight decay regularization.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      transform.additive_weight_decay(weight_decay),
      _scale_by_learning_rate(learning_rate),
  )
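
Following the WARNING in the docstring, here is a sketch of a hypothetical AdamW variant that skips weight decay on masked-out leaves. It assumes the same building blocks are exposed at the top level as optax.chain, optax.masked, optax.scale_by_adam, optax.additive_weight_decay and optax.scale (newer releases expose the decay transform as add_decayed_weights).

import optax

def adamw_with_mask(learning_rate, mask, b1=0.9, b2=0.999, eps=1e-8,
                    weight_decay=1e-4):
  # `mask` is a pytree of booleans matching `params`; weight decay is applied
  # only where it is True (e.g. False for biases and BatchNorm scales).
  return optax.chain(
      optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
      optax.masked(optax.additive_weight_decay(weight_decay), mask),
      optax.scale(-learning_rate),
  )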