Example #1
def fromage(learning_rate: float,
            min_norm: float = 1e-6) -> base.GradientTransformation:
    """The Frobenius matched gradient descent (Fromage) optimiser.

  Fromage is a learning algorithm that does not require learning rate tuning.
  The optimiser is based on modelling neural network gradients via deep relative
  trust (a distance function on deep neural networks). Fromage is similar to the
  LARS optimiser and can work on a range of standard neural network benchmarks,
  such as natural language Transformers and generative adversarial networks.

  References:
    Bernstein et al, 2020: https://arxiv.org/abs/2002.03432

  Args:
    learning_rate: this is a fixed global scaling factor.
    min_norm: a minimum value to which the norm of the gradient updates and
      the norm of the layer parameters are clipped, to avoid dividing by zero
      when computing the trust ratio (as in the LARS paper).

  Returns:
    the corresponding `GradientTransformation`.
  """
  mult = 1 / jnp.sqrt(1 + learning_rate**2)
  return combine.chain(
      transform.scale_by_trust_ratio(min_norm),
      _scale_by_learning_rate(learning_rate * mult),
      transform.add_decayed_weights((mult - 1)),
  )
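
A minimal usage sketch for the transformation returned above, assuming these are the optax definitions (as the `jnp`/`combine`/`transform` imports suggest); the parameter and gradient pytrees are placeholders:

import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}
grads = {'w': jnp.full((3, 3), 0.1), 'b': jnp.full((3,), 0.1)}  # stand-in gradients

opt = optax.fromage(learning_rate=0.01)
opt_state = opt.init(params)

# The trust ratio and the decayed-weights term both need the current params.
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
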
Example #2
def fromage(learning_rate: float,
            min_norm: float = 1e-6) -> base.GradientTransformation:
    mult = 1 / jnp.sqrt(1 + learning_rate**2)
    return combine.chain(
        transform.scale_by_trust_ratio(min_norm),
        _scale_by_learning_rate(learning_rate * mult),
        transform.add_decayed_weights((mult - 1)),
    )
Example #3
def lamb(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         weight_decay: float = 0.) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps),
        transform.additive_weight_decay(weight_decay),
        transform.scale_by_trust_ratio(),
        transform.scale(-learning_rate),
    )
Example #4
def lamb(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> base.GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        transform.add_decayed_weights(weight_decay),
        transform.scale_by_trust_ratio(),
        _scale_by_learning_rate(learning_rate),
    )
Example #5
def lars(
    learning_rate: ScalarOrSchedule,
    weight_decay: float = 0.,
    weight_decay_mask: MaskOrFn = True,
    trust_coefficient: float = 0.001,
    eps: float = 0.,
    trust_ratio_mask: MaskOrFn = True,
    momentum: float = 0.9,
    nesterov: bool = False,
) -> base.GradientTransformation:
  """The LARS optimiser.

  LARS is a layer-wise adaptive optimiser introduced to help scale SGD to
  larger batch sizes. LARS later inspired the LAMB optimiser.

  References:
    You et al, 2017: https://arxiv.org/abs/1708.03888

  Args:
    learning_rate: a fixed global scaling factor, or a schedule mapping the
      step count to such a factor.
    weight_decay (default `0.`): strength of the weight decay regularization.
    weight_decay_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    trust_coefficient: a multiplier for the trust ratio.
    eps: optional additive constant in the trust ratio denominator.
    trust_ratio_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    momentum: the decay rate for momentum.
    nesterov: whether to use Nesterov momentum.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
      wrappers.masked(
          inner=transform.scale_by_trust_ratio(
              trust_coefficient=trust_coefficient, eps=eps),
          mask=trust_ratio_mask),
      _scale_by_learning_rate(learning_rate),
      transform.trace(decay=momentum, nesterov=nesterov),
  )
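
A rough illustration of the two mask arguments, assuming a nested parameter pytree; the boolean tree (same structure as the params) marks which leaves receive weight decay and the trust-ratio scaling:

import jax.numpy as jnp
import optax

params = {'dense': {'kernel': jnp.ones((4, 4)), 'bias': jnp.zeros((4,))}}

# Decay and rescale kernels only; biases are skipped by both transforms.
mask = {'dense': {'kernel': True, 'bias': False}}

opt = optax.lars(
    learning_rate=0.1,
    weight_decay=1e-4,
    weight_decay_mask=mask,
    trust_ratio_mask=mask,
)
opt_state = opt.init(params)
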
Example #6
def lamb(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-6,
    eps_root: float = 0.0,
    weight_decay: float = 0.,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
    """The LAMB optimiser.

  LAMB is a general purpose layer-wise adaptive large batch optimiser designed
  to provide consistent training performance across a wide range of tasks,
  including those that use attention-based models (such as Transformers) and
  ResNet-50. The optimiser is able to work with small and large batch sizes.
  LAMB was inspired by the LARS learning algorithm.

  References:
    You et al, 2019: https://arxiv.org/abs/1904.00962

  Args:
    learning_rate: a fixed global scaling factor, or a schedule mapping the
      step count to such a factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to the denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: (default `0.0`), a small constant applied to the denominator
      inside the square root (as in RMSProp), to avoid dividing by zero when
      rescaling.
      This is needed for instance when computing (meta-)gradients through Adam.
    weight_decay (default `0.`): strength of the weight decay regularization.
    mask: a tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      transform.add_decayed_weights(weight_decay=weight_decay, mask=mask),
      transform.scale_by_trust_ratio(),
      _scale_by_learning_rate(learning_rate),
  )
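
Because `learning_rate` is a `ScalarOrSchedule`, a schedule can be passed in place of a float. A short sketch with a cosine decay schedule (the schedule choice and step count are arbitrary):

import optax

# Decay the learning rate from 1e-3 towards zero over 10_000 steps.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
opt = optax.lamb(schedule, weight_decay=1e-2)
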
Example #7
def lamb(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> base.GradientTransformation:
    """The LAMB optimiser.

  LAMB is a general purpose layer-wise adaptive large batch optimiser designed
  to provide consistent training performance across a wide range of tasks,
  including those that use attention-based models (such as Transformers) and
  ResNet-50. The optimiser is able to work with small and large batch sizes.
  LAMB was inspired by the LARS learning algorithm.

  References:
    [You et al, 2019](https://arxiv.org/abs/1904.00962)

  Args:
    learning_rate: a fixed global scaling factor, or a schedule mapping the
      step count to such a factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to the denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: (default `0.0`), a small constant applied to the denominator
      inside the square root (as in RMSProp), to avoid dividing by zero when
      rescaling.
      This is needed for instance when computing (meta-)gradients through Adam.
    weight_decay (default `0.`): strength of the weight decay regularization.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      transform.add_decayed_weights(weight_decay),
      transform.scale_by_trust_ratio(),
      _scale_by_learning_rate(learning_rate),
  )
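
For a plain float learning rate, roughly the same optimiser can be assembled from optax's public transforms, with `optax.scale(-learning_rate)` standing in for the private `_scale_by_learning_rate` helper:

import optax

learning_rate, weight_decay = 1e-3, 1e-2

opt = optax.chain(
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-6, eps_root=0.0),
    optax.add_decayed_weights(weight_decay),
    optax.scale_by_trust_ratio(),
    optax.scale(-learning_rate),  # negate so the updates descend the gradient
)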