def adabelief(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
) -> base.GradientTransformation:
    """The AdaBelief optimiser.

    AdaBelief is an adaptive learning rate optimiser that focuses on fast
    convergence, generalisation, and stability. It adapts the step size
    depending on its "belief" in the gradient direction — the optimiser
    adaptively scales the step size by the difference between the predicted
    and observed gradients. AdaBelief is a modified version of Adam and
    contains the same number of parameters.

    References:
      [Zhuang et al, 2020](https://arxiv.org/abs/2010.07468)

    Args:
      learning_rate: this is a fixed global scaling factor.
      b1: the exponential decay rate to track the first moment of past
        gradients.
      b2: the exponential decay rate to track the second moment of past
        gradients.
      eps: a small constant applied to denominator outside of the square root
        (as in the Adam paper) to avoid dividing by zero when rescaling.
      eps_root: term added to the second moment of the prediction error
        inside the square root, to improve numerical stability. If
        backpropagating gradients through the gradient transformation
        (e.g. for meta-learning), this must be non-zero. Defaults to 0.0
        to preserve the previous behaviour of this function.

    Returns:
      the corresponding `GradientTransformation`.
    """
    return combine.chain(
        # eps_root is forwarded so callers can stabilise sqrt gradients;
        # the 0.0 default keeps existing call sites bitwise unchanged.
        transform.scale_by_belief(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        _scale_by_learning_rate(learning_rate),
    )
def adabelief(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-16,
    eps_root: float = 1e-16,
) -> base.GradientTransformation:
    """The AdaBelief optimiser.

    AdaBelief is an adaptive learning rate optimiser that focuses on fast
    convergence, generalisation, and stability. It adaptively rescales each
    update by its "belief" in the current gradient direction: steps are
    scaled by the discrepancy between the predicted and the observed
    gradient. The method is a variant of Adam with an identical parameter
    count.

    References:
      Zhuang et al, 2020: https://arxiv.org/abs/2010.07468

    Args:
      learning_rate: this is a fixed global scaling factor.
      b1: the exponential decay rate to track the first moment of past
        gradients.
      b2: the exponential decay rate to track the second moment of past
        gradients.
      eps: term added to the denominator to improve numerical stability.
      eps_root: term added to the second moment of the prediction error to
        improve numerical stability. If backpropagating gradients through
        the gradient transformation (e.g. for meta-learning), this must be
        non-zero.

    Returns:
      the corresponding `GradientTransformation`.
    """
    # First rescale by the belief statistics, then apply the (possibly
    # scheduled) learning rate.
    belief_scaling = transform.scale_by_belief(
        b1=b1, b2=b2, eps=eps, eps_root=eps_root)
    lr_scaling = _scale_by_learning_rate(learning_rate)
    return combine.chain(belief_scaling, lr_scaling)
def adabelief(
    learning_rate: float,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
) -> GradientTransformation:
    """The AdaBelief optimiser.

    AdaBelief is a modification of Adam that rescales each step by its
    "belief" in the gradient direction, i.e. by the difference between the
    predicted and observed gradients.

    References:
      Zhuang et al, 2020: https://arxiv.org/abs/2010.07468

    Args:
      learning_rate: a fixed global scaling factor.
      b1: the exponential decay rate to track the first moment of past
        gradients.
      b2: the exponential decay rate to track the second moment of past
        gradients.
      eps: a small constant applied to the denominator to avoid dividing by
        zero when rescaling.

    Returns:
      the corresponding `GradientTransformation`.
    """
    return combine.chain(
        transform.scale_by_belief(b1=b1, b2=b2, eps=eps),
        # Negate so the chained update is a descent step of size
        # `learning_rate`.
        transform.scale(-learning_rate),
    )