Example 1
def radam(learning_rate: ScalarOrSchedule,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          threshold: float = 5.0) -> GradientTransformation:
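    # Chain RAdam's variance-rectified moment scaling with a final step that
    # multiplies the updates by the (possibly scheduled) learning rate.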
    return combine.chain(
        transform.scale_by_radam(b1=b1, b2=b2, eps=eps, threshold=threshold),
        _scale_by_learning_rate(learning_rate),
    )
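The alias above only wires gradient transformations together; the training loop drives the returned object itself. A minimal usage sketch, assuming the public `optax.radam` entry point and a toy quadratic loss invented purely for illustration:

import jax
import jax.numpy as jnp
import optax

# Toy quadratic loss, used here only to exercise the optimiser.
def loss(params):
    return jnp.sum((params - 1.0) ** 2)

params = jnp.zeros(3)
opt = optax.radam(learning_rate=1e-2)  # the alias defined above
opt_state = opt.init(params)           # initialises the moment estimates

for _ in range(500):
    grads = jax.grad(loss)(params)
    # scale_by_radam rescales the gradients, then the learning rate is applied.
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

print(params)  # params have moved close to the optimum at [1., 1., 1.]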
Example 2
def radam(learning_rate: ScalarOrSchedule,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          eps_root: float = 0.0,
          threshold: float = 5.0) -> base.GradientTransformation:
    """The Rectified Adam optimiser.

  The adaptive learning rate in Adam has undesirably large variance in early
  stages of training, due to the limited number of training samples used to
  estimate the optimiser's statistics. Rectified Adam addresses this issue
  by analytically reducing the large variance.

  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: (default `0`), a small constant applied to denominator inside the
      square root (as in RMSProp), to avoid dividing by zero when rescaling.
      This is needed for instance when computing (meta-)gradients through Adam.
    threshold: the threshold for variance tractability.

  Returns:
    the corresponding `GradientTransformation`.
  """
    return combine.chain(
        transform.scale_by_radam(b1=b1,
                                 b2=b2,
                                 eps=eps,
                                 eps_root=eps_root,
                                 threshold=threshold),
        _scale_by_learning_rate(learning_rate),
    )
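Because `learning_rate` is typed as `ScalarOrSchedule`, the final `_scale_by_learning_rate` step also accepts a schedule, i.e. a callable from step count to scaling factor. A short sketch of that, combined with a non-zero `eps_root` as the docstring suggests for (meta-)gradient use; the specific schedule and constants are illustrative choices, not taken from the source:

import optax

# A schedule mapping the step count to a decayed scaling factor; the values
# here are arbitrary illustrative choices.
schedule = optax.exponential_decay(
    init_value=1e-2,        # scale used at step 0
    transition_steps=1000,  # over every 1000 steps...
    decay_rate=0.9,         # ...decay the scale by a factor of 0.9
)

# Passing the schedule instead of a float: the chain rescales each update
# with the scheduled value at the current step instead of a constant.
# A non-zero eps_root avoids division by zero inside the square root, which
# the docstring notes is needed for (meta-)gradients through Adam.
opt = optax.radam(learning_rate=schedule, eps_root=1e-16)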