Code Example #1
File: alias.py  Project: ksachdeva/optax
def adabelief(learning_rate: ScalarOrSchedule,
              b1: float = 0.9,
              b2: float = 0.999,
              eps: float = 1e-8) -> base.GradientTransformation:
    """The AdaBelief optimiser.

  AdaBelief is an adaptive learning rate optimiser that focuses on fast
  convergence, generalisation, and stability. It adapts the step size depending
  on its "belief" in the gradient direction — the optimiser adaptively scales
  the step size by the difference between the predicted and observed gradients.
  AdaBelief is a modified version of Adam and contains the same number of
  parameters.

  References:
    [Zhuang et al, 2020](https://arxiv.org/abs/2010.07468)

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_belief(b1=b1, b2=b2, eps=eps),
      _scale_by_learning_rate(learning_rate),
  )
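The snippet above only constructs the `GradientTransformation`; it does not show how the optimiser is driven. A minimal usage sketch, assuming the public `optax.adabelief` alias and a toy quadratic loss (neither appears in the snippet), could look like this:

# Minimal usage sketch: drive the returned GradientTransformation with
# optax's standard init / update / apply_updates API.
import jax
import jax.numpy as jnp
import optax

params = jnp.array([1.0, 2.0, 3.0])
optimiser = optax.adabelief(learning_rate=1e-3)
opt_state = optimiser.init(params)

def loss(p):
  return jnp.sum(p ** 2)  # toy quadratic loss standing in for a real model

for _ in range(100):
  grads = jax.grad(loss)(params)
  updates, opt_state = optimiser.update(grads, opt_state)
  params = optax.apply_updates(params, updates)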
Code Example #2
File: alias.py  Project: n2cholas/optax
def adabelief(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-16,
    eps_root: float = 1e-16) -> base.GradientTransformation:
  """The AdaBelief optimiser.

  AdaBelief is an adaptive learning rate optimiser that focuses on fast
  convergence, generalisation, and stability. It adapts the step size depending
  on its "belief" in the gradient direction — the optimiser adaptively scales
  the step size by the difference between the predicted and observed gradients.
  AdaBelief is a modified version of Adam and contains the same number of
  parameters.

  References:
    Zhuang et al, 2020: https://arxiv.org/abs/2010.07468

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the second moment of the prediction error to
      improve numerical stability. If backpropagating gradients through the
      gradient transformation (e.g. for meta-learning), this must be non-zero.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_belief(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      _scale_by_learning_rate(learning_rate),
  )
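To make the docstring's "difference between the predicted and observed gradients" concrete, here is a simplified, single-array sketch of one AdaBelief step following Zhuang et al. (2020). It is illustrative only: it is not optax's `scale_by_belief` implementation, which operates over whole pytrees of parameters and may place the eps terms slightly differently.

# Illustrative sketch of the "belief" statistic tracked by scale_by_belief.
import jax.numpy as jnp

def adabelief_step(g, m, s, t, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16):
  """One AdaBelief update direction for a single array; t is the 1-indexed step."""
  m = b1 * m + (1 - b1) * g             # first moment: predicted gradient
  s = b2 * s + (1 - b2) * (g - m) ** 2  # belief: variance of the prediction error
  m_hat = m / (1 - b1 ** t)             # bias-corrected first moment
  s_hat = s / (1 - b2 ** t)             # bias-corrected belief
  update = m_hat / (jnp.sqrt(s_hat + eps_root) + eps)
  return update, m, s

When the observed gradient agrees with the prediction, (g - m) is small, the belief s stays small, and the effective step is large; when they disagree, the step shrinks. The `eps_root` term keeps the square root differentiable at zero, which is why the docstring requires it to be non-zero when backpropagating through the transformation.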
Code Example #3
def adabelief(learning_rate: float,
              b1: float = 0.9,
              b2: float = 0.999,
              eps: float = 1e-8) -> GradientTransformation:
    # Rescale updates by the AdaBelief belief statistic, then flip the sign
    # and scale by the fixed learning rate so the updates descend the loss.
    return combine.chain(
        transform.scale_by_belief(b1=b1, b2=b2, eps=eps),
        transform.scale(-learning_rate),
    )
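Example #3 accepts only a fixed float and negates it with `transform.scale`, whereas examples #1 and #2 route a `ScalarOrSchedule` through the private `_scale_by_learning_rate` helper. A hypothetical reconstruction of such a helper using optax's public API (an assumption, not code taken from either project) could look like:

# Hypothetical sketch of a scale-by-learning-rate helper; the real private
# helper in the projects above may differ.
from typing import Callable, Union
import optax

ScalarOrSchedule = Union[float, Callable[[int], float]]

def _scale_by_learning_rate(learning_rate: ScalarOrSchedule) -> optax.GradientTransformation:
  # Callables of the step count are treated as schedules; plain floats become
  # a fixed scale. The sign flip makes the chained updates descend the loss.
  if callable(learning_rate):
    return optax.scale_by_schedule(lambda count: -learning_rate(count))
  return optax.scale(-learning_rate)

Dispatching on whether the argument is callable lets the same `adabelief` alias accept either a constant step size or a per-step schedule without changing its signature.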