Example #1
def test_add_decayed_weights(self):
    # Define a transform that adds decayed weights.
    # We can define a mask either as a pytree, or as a function that
    # returns the pytree. Below we define the pytree directly.
    mask = (True, dict(a=True, b=False))
    tx = transform.add_decayed_weights(0.1, mask=mask)
    # Define input updates and weights.
    updates = (jnp.zeros((2,), dtype=jnp.float32),
               dict(
                   a=jnp.zeros((2,), dtype=jnp.float32),
                   b=jnp.zeros((2,), dtype=jnp.float32),
               ))
    weights = (jnp.ones((2,), dtype=jnp.float32),
               dict(
                   a=jnp.ones((2,), dtype=jnp.float32),
                   b=jnp.ones((2,), dtype=jnp.float32),
               ))
    # This mask means that decayed weights are added to the first tuple entry
    # and to `a`, but not to `b`.
    expected_tx_updates = (0.1 * jnp.ones((2,), dtype=jnp.float32),
                           dict(
                               a=0.1 * jnp.ones((2,), dtype=jnp.float32),
                               b=jnp.zeros((2,), dtype=jnp.float32),
                           ))
    # Apply the transform.
    state = tx.init(weights)
    transform_fn = self.variant(tx.update)
    new_updates, _ = transform_fn(updates, state, weights)
    # Assert the output is as expected.
    chex.assert_tree_all_close(new_updates, expected_tx_updates)
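The test above builds the mask as a literal pytree. Below is a minimal sketch of the callable-mask form mentioned in its comments, using the public `optax.add_decayed_weights` alias; the toy parameter tree and the ndim-based rule are illustrative assumptions.

import jax
import jax.numpy as jnp
import optax

# Toy parameters: a 2-D weight matrix and a 1-D bias.
params = {'w': jnp.ones((3, 2)), 'b': jnp.ones((2,))}
grads = {'w': jnp.zeros((3, 2)), 'b': jnp.zeros((2,))}

# Callable mask: decay only parameters with more than one dimension.
mask_fn = lambda p: jax.tree_map(lambda x: x.ndim > 1, p)
tx = optax.add_decayed_weights(0.1, mask=mask_fn)

state = tx.init(params)
new_updates, state = tx.update(grads, state, params)
# new_updates['w'] == grads['w'] + 0.1 * params['w']; new_updates['b'] is unchanged.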
Example #2
def fromage(learning_rate: float,
            min_norm: float = 1e-6) -> base.GradientTransformation:
    """The Frobenius matched gradient descent (Fromage) optimiser.

  Fromage is a learning algorithm that does not require learning rate tuning.
  The optimiser is based on modelling neural network gradients via deep relative
  trust (a distance function on deep neural networks). Fromage is similar to the
  LARS optimiser and can work on a range of standard neural network benchmarks,
  such as natural language Transformers and generative adversarial networks.

  References:
    Bernstein et al, 2020: https://arxiv.org/abs/2002.03432

  Args:
    learning_rate: this is a fixed global scaling factor.
    min_norm: a minimum value that the norm of the gradient updates and the
      norm of the layer parameters can be clipped to, to avoid dividing by
      zero when computing the trust ratio (as in the LARS paper).

  Returns:
    the corresponding `GradientTransformation`.
  """
  mult = 1 / jnp.sqrt(1 + learning_rate**2)
  return combine.chain(
      transform.scale_by_trust_ratio(min_norm),
      _scale_by_learning_rate(learning_rate * mult),
      transform.add_decayed_weights((mult - 1)),
  )
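A hedged usage sketch of the optimiser above in a single toy update step, assuming the public `optax.fromage` alias and an illustrative quadratic loss.

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((2,))}
opt = optax.fromage(learning_rate=0.01)
opt_state = opt.init(params)

# One gradient step on a toy loss; params are passed to update() because the
# trust ratio and the weight decay term both need them.
loss_fn = lambda p: jnp.sum(p['w'] ** 2)
grads = jax.grad(loss_fn)(params)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)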
Example #3
def fromage(learning_rate: float,
            min_norm: float = 1e-6) -> base.GradientTransformation:
    mult = 1 / jnp.sqrt(1 + learning_rate**2)
    return combine.chain(
        transform.scale_by_trust_ratio(min_norm),
        _scale_by_learning_rate(learning_rate * mult),
        transform.add_decayed_weights((mult - 1)),
    )
Example #4
def lamb(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> base.GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        transform.add_decayed_weights(weight_decay),
        transform.scale_by_trust_ratio(),
        _scale_by_learning_rate(learning_rate),
    )
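Because `learning_rate` is a `ScalarOrSchedule`, the optimiser can also be driven by a schedule. A sketch under that assumption, using the public `optax.lamb` and `optax.exponential_decay` helpers with illustrative values.

import jax.numpy as jnp
import optax

# Learning rate given as a schedule (step -> scale) rather than a constant.
schedule = optax.exponential_decay(
    init_value=1e-3, transition_steps=1_000, decay_rate=0.9)

params = {'w': jnp.ones((4, 4))}
opt = optax.lamb(learning_rate=schedule, weight_decay=1e-2)
opt_state = opt.init(params)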
Example #5
def adamw(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-4,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
  """Adam with weight decay regularization.

  AdamW uses weight decay to regularise learning towards small weights, as
  this leads to better generalisation. In SGD you can also use L2 regularisation
  to implement this as an additive loss term, however L2 regularization
  does not behave as intended for adaptive gradient algorithms such as Adam.

  WARNING: Sometimes you may want to skip weight decay for BatchNorm scale or
  for the bias parameters. You can use `optax.masked` to make your own AdamW
  variant where `add_decayed_weights` is applied only to a subset of `params`.

  References:
    Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: (default `0`), a small constant applied to denominator inside the
      square root (as in RMSProp), to avoid dividing by zero when rescaling.
      This is needed for instance when computing (meta-)gradients through Adam.
    mu_dtype: optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    weight_decay: strength of the weight decay regularization.
    mask: a tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the weight decay to, and `False` for those you want to skip. Note
      that the Adam gradient transformations are applied to all parameters.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_adam(
          b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype),
      transform.add_decayed_weights(weight_decay, mask),
      _scale_by_learning_rate(learning_rate),
  )
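A minimal sketch of the masking pattern the WARNING above describes: weight decay is applied only to multi-dimensional (weight-like) parameters, skipping 1-D biases. The toy parameter tree and the ndim rule are assumptions.

import jax
import jax.numpy as jnp
import optax

params = {'dense': {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}}

# Boolean pytree: True where weight decay should be applied.
decay_mask = jax.tree_map(lambda x: x.ndim > 1, params)

opt = optax.adamw(learning_rate=1e-3, weight_decay=1e-4, mask=decay_mask)
opt_state = opt.init(params)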
Example #6
def lars(
    learning_rate: ScalarOrSchedule,
    weight_decay: float = 0.,
    weight_decay_mask: MaskOrFn = True,
    trust_coefficient: float = 0.001,
    eps: float = 0.,
    trust_ratio_mask: MaskOrFn = True,
    momentum: float = 0.9,
    nesterov: bool = False,
) -> base.GradientTransformation:
  """The LARS optimiser.

  LARS is a layer-wise adaptive optimiser introduced to help scale SGD to
  larger batch sizes. LARS later inspired the LAMB optimiser.

  References:
    You et al, 2017: https://arxiv.org/abs/1708.03888

  Args:
    learning_rate: this is a fixed global scaling factor.
    weight_decay (default `0.`): strength of the weight decay regularization.
    weight_decay_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    trust_coefficient: a multiplier for the trust ratio.
    eps: optional additive constant in the trust ratio denominator.
    trust_ratio_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    momentum: the decay rate for momentum.
    nesterov: whether to use Nesterov momentum.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
      wrappers.masked(
          inner=transform.scale_by_trust_ratio(
              trust_coefficient=trust_coefficient, eps=eps),
          mask=trust_ratio_mask),
      _scale_by_learning_rate(learning_rate),
      transform.trace(decay=momentum, nesterov=nesterov),
  )
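A hedged sketch of configuring the two masks above so that 1-D (bias-like) parameters are excluded from both the weight decay and the trust-ratio scaling; the parameter tree is illustrative.

import jax
import jax.numpy as jnp
import optax

params = {'conv': {'w': jnp.ones((3, 3, 8)), 'b': jnp.zeros((8,))}}
# True for weight-like (multi-dimensional) leaves, False for biases.
is_weight = jax.tree_map(lambda x: x.ndim > 1, params)

opt = optax.lars(
    learning_rate=0.1,
    weight_decay=1e-4,
    weight_decay_mask=is_weight,
    trust_ratio_mask=is_weight,
)
opt_state = opt.init(params)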
Example #7
    def test_mask_fn(self):
        params = {
            'a': jnp.ones((1, 2)),
            'b': (jnp.ones((1, )), np.ones((1, 2, 3)))
        }
        mask_fn = lambda p: jax.tree_map(lambda x: x.ndim > 1, p)
        init_fn, update_fn = wrappers.masked(
            transform.add_decayed_weights(0.1), mask_fn)
        update_fn = self.variant(update_fn)

        state = self.variant(init_fn)(params)
        grads = jax.tree_map(lambda x: x * 2, params)
        updates, state = update_fn(grads, state, params)
        np.testing.assert_allclose(updates['a'],
                                   grads['a'] + 0.1 * params['a'])
        np.testing.assert_allclose(updates['b'][0], grads['b'][0])
        np.testing.assert_allclose(updates['b'][1],
                                   grads['b'][1] + 0.1 * params['b'][1])
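A hand-worked check (not part of the original test) of the mask `mask_fn` produces for these params, which matches the three assertions above.

import jax
import jax.numpy as jnp
import numpy as np

params = {
    'a': jnp.ones((1, 2)),
    'b': (jnp.ones((1,)), np.ones((1, 2, 3)))
}
mask = jax.tree_map(lambda x: x.ndim > 1, params)
# mask == {'a': True, 'b': (False, True)}:
# decay is added to 'a' and 'b'[1], while 'b'[0] passes through unchanged.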
Example #8
def lamb(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-6,
    eps_root: float = 0.0,
    weight_decay: float = 0.,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
    """The LAMB optimiser.

  LAMB is a general purpose layer-wise adaptive large batch optimiser designed
  to provide consistent training performance across a wide range of tasks,
  including those that use attention-based models (such as Transformers) and
  ResNet-50. The optimiser is able to work with small and large batch sizes.
  LAMB was inspired by the LARS learning algorithm.

  References:
    You et al, 2019: https://arxiv.org/abs/1904.00962

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: (default `0.0`), a small constant applied to denominator inside
      the square root (as in RMSProp), to avoid dividing by zero when rescaling.
      This is needed for instance when computing (meta-)gradients through Adam.
    weight_decay (default `0.`): strength of the weight decay regularization.
    mask: a tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      transform.add_decayed_weights(weight_decay=weight_decay, mask=mask),
      transform.scale_by_trust_ratio(),
      _scale_by_learning_rate(learning_rate),
  )
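As the `mask` description above notes, the mask may also be a prefix of the params tree, so a single boolean can cover a whole sub-module. A sketch under that assumption; the module names are illustrative.

import jax.numpy as jnp
import optax

params = {
    'encoder': {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))},
    'head': {'w': jnp.ones((4, 2)), 'b': jnp.zeros((2,))},
}
# Prefix mask: decay everything under 'encoder', skip 'head' entirely.
mask = {'encoder': True, 'head': False}

opt = optax.lamb(learning_rate=1e-3, weight_decay=1e-2, mask=mask)
opt_state = opt.init(params)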
Example #9
def lamb(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> base.GradientTransformation:
    """The LAMB optimiser.

  LAMB is a general purpose layer-wise adaptive large batch optimiser designed
  to provide consistent training performance across a wide range of tasks,
  including those that use attention-based models (such as Transformers) and
  ResNet-50. The optimiser is able to work with small and large batch sizes.
  LAMB was inspired by the LARS learning algorithm.

  References:
    [You et al, 2019](https://arxiv.org/abs/1904.00962)

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: (default `0.0`), a small constant applied to denominator inside
      the square root (as in RMSProp), to avoid dividing by zero when rescaling.
      This is needed for instance when computing (meta-)gradients through Adam.
    weight_decay (default `0.`): strength of the weight decay regularization.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      transform.add_decayed_weights(weight_decay),
      transform.scale_by_trust_ratio(),
      _scale_by_learning_rate(learning_rate),
  )
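A hedged sketch of the layer-wise trust ratio that `scale_by_trust_ratio` contributes to the chain above, assuming the public `optax.scale_by_trust_ratio` alias and toy values.

import jax.numpy as jnp
import optax

params = {'w': 2.0 * jnp.ones((3,))}
updates = {'w': 0.5 * jnp.ones((3,))}

tx = optax.scale_by_trust_ratio()
state = tx.init(params)
scaled, _ = tx.update(updates, state, params)
# Per leaf, the update is rescaled by ||params|| / ||updates||, here 4.0,
# so scaled['w'] == 2.0 * jnp.ones((3,)).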
Example #10
def adafactor(
    learning_rate: Optional[ScalarOrSchedule] = None,
    min_dim_size_to_factor: int = 128,
    decay_rate: float = 0.8,
    decay_offset: int = 0,
    multiply_by_parameter_scale: bool = True,
    clipping_threshold: Optional[float] = 1.0,
    momentum: Optional[float] = None,
    dtype_momentum: Any = jnp.float32,
    weight_decay_rate: Optional[float] = None,
    eps: float = 1e-30,
    factored: bool = True,
    weight_decay_mask: MaskOrFn = None,
    ) -> base.GradientTransformation:
  """The Adafactor optimiser.

  Adafactor is an adaptive learning rate optimiser that focuses on fast
  training of large scale neural networks. It saves memory by using a factored
  estimate of the second order moments used to scale gradients.

  References:
    Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235

  Args:
      learning_rate: (float) a step size. Note: the natural scale for
        Adafactor's learning rate is markedly different from Adam's; one does
        not use the 1/sqrt(hidden) correction for this optimiser with
        attention-based models.
      min_dim_size_to_factor: (int) only factor the statistics if two array
        dimensions have at least this size.
      decay_rate: (float) controls second-moment exponential decay schedule.
      decay_offset: (int) for finetuning, one may set this to the starting
        step number of the finetuning phase.
      multiply_by_parameter_scale: (bool) if True, then scale learning_rate by
        parameter norm; if False, provided learning_rate is absolute step size.
      clipping_threshold: (float>=1) optional value; if None, clipping disabled.
      momentum: (float) optional value between 0 and 1; enables momentum and
        uses extra memory if non-None. None by default.
      dtype_momentum: (dtype) dtype of momentum buffers.
      weight_decay_rate: (float) optional rate at which to decay weights.
      eps: (float) regularization constant for root mean squared gradient.
      factored: (bool) whether to use factored second-moment estimates.
      weight_decay_mask: a tree with same structure as (or a prefix of)
        the params PyTree, or a Callable that returns such a pytree given
        the params/updates. The leaves should be booleans, `True`
        for leaves/subtrees you want to apply the transformation to,
        and `False` for those you want to skip.

  Returns:
    the corresponding `GradientTransformation`.
  """
  # The core of the algorithm is a procedure for rescaling gradients
  # by a factored estimate of the root mean squared gradients.
  # This reduces memory compared to algorithms such as Adam or RmsProp,
  # by not having to hold a separate estimate for each weight.
  tx = [
      factorized.scale_by_factored_rms(
          factored, decay_rate, decay_offset, min_dim_size_to_factor, eps)]
  # This basic rescaling is typically combined with one or more of the following
  # transformations (all can be disabled via adafactor's constructor args).
  if clipping_threshold is not None:
    tx.append(clipping.clip_by_block_rms(clipping_threshold))
  if learning_rate is not None:
    tx.append(_scale_by_learning_rate(learning_rate, flip_sign=False))
  if multiply_by_parameter_scale:
    tx.append(transform.scale_by_param_block_rms())
  if momentum is not None:
    tx.append(
        transform.ema(momentum, debias=False, accumulator_dtype=dtype_momentum))
  if weight_decay_rate is not None:
    tx.append(transform.add_decayed_weights(
        weight_decay_rate, mask=weight_decay_mask))
  # In gradient "descent" we follow the negative gradient.
  tx.append(transform.scale(-1))
  return combine.chain(*tx)
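A brief hedged sketch of constructing the optimiser above via the public `optax.adafactor` alias. With the defaults shown, updates are additionally scaled by the parameter norm (`multiply_by_parameter_scale=True`); the learning rate and decay rate below are illustrative, not recommendations.

import jax.numpy as jnp
import optax

# Both dimensions exceed min_dim_size_to_factor, so the second-moment
# statistics for this leaf are stored in factored form.
params = {'w': jnp.ones((256, 512))}

opt = optax.adafactor(learning_rate=1e-3, weight_decay_rate=1e-4)
opt_state = opt.init(params)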