def rmsprop(
    learning_rate: ScalarOrSchedule,
    decay: float = 0.9,
    eps: float = 1e-8,
    initial_scale: float = 0.,
    centered: bool = False,
    momentum: Optional[float] = None,
    nesterov: bool = False
) -> base.GradientTransformation:  # pylint: disable=line-too-long
  """A flexible RMSProp optimiser.

  RMSProp is an SGD variant with learning rate adaptation. The `learning_rate`
  used for each weight is scaled by a suitable estimate of the magnitude of the
  gradients on previous steps. Several variants of RMSProp can be found
  in the literature. This alias provides an easy to configure RMSProp
  optimiser that can be used to switch between several of these variants.

  References:
    Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
    Graves, 2013: https://arxiv.org/abs/1308.0850

  Args:
    learning_rate: a global scaling factor. As the `ScalarOrSchedule`
      annotation indicates, this is either a fixed scalar or a schedule; it is
      turned into a transform by `_scale_by_learning_rate`.
    decay: the decay used to track the magnitude of previous gradients.
    eps: a small numerical constant to avoid dividing by zero when rescaling.
    initial_scale: (default `0.`), initialisation of accumulators tracking the
      magnitude of previous gradients. PyTorch uses `0`, TF1 uses `1`. When
      reproducing results from a paper, verify the value used by the authors.
    centered: (default `False`), whether the second moment or the variance of
      the past gradients is used to rescale the latest gradients.
    momentum: (default `None`), the `decay` rate used by the momentum term;
      when it is set to `None`, momentum is not used at all.
    nesterov: (default `False`), whether nesterov momentum is used. Only has
      an effect when `momentum` is not `None`.

  Returns:
    the corresponding `GradientTransformation`.
  """
  # pylint: enable=line-too-long
  # `centered=True` rescales by the running standard deviation (variance of
  # past gradients); otherwise by the running root-mean-square (second moment).
  # Everything after this first transform is identical for both variants.
  if centered:
    rescale = transform.scale_by_stddev(
        decay=decay, eps=eps, initial_scale=initial_scale)
  else:
    rescale = transform.scale_by_rms(
        decay=decay, eps=eps, initial_scale=initial_scale)
  return combine.chain(
      rescale,
      _scale_by_learning_rate(learning_rate),
      # `base.identity()` stands in as a no-op when momentum is disabled, so
      # the chain always has three elements.
      (transform.trace(decay=momentum, nesterov=nesterov)
       if momentum is not None else base.identity()),
  )
def rmsprop(
    learning_rate: ScalarOrSchedule,
    decay: float = 0.9,
    eps: float = 1e-8,
    centered: bool = False,
    initial_scale: float = 0.,
    momentum: Optional[float] = None,
    nesterov: bool = False,
) -> base.GradientTransformation:
  """RMSProp optimiser.

  Rescales gradients by a running estimate of their magnitude, then scales the
  result by the (negated) learning rate, optionally followed by momentum.

  NOTE(review): this module defines `rmsprop` more than once; as the later
  `def`, this one shadows the earlier definition. It therefore accepts the same
  options, but keeps `centered` as the 4th positional parameter so existing
  positional callers of this signature keep working. Consider removing the
  duplicate definition.

  Args:
    learning_rate: a global scaling factor, either a fixed scalar or a
      schedule (per the `ScalarOrSchedule` annotation); it is turned into a
      transform by `_scale_by_learning_rate`.
    decay: the decay used to track the magnitude of previous gradients.
    eps: a small numerical constant to avoid dividing by zero when rescaling.
    centered: (default `False`), whether the second moment (`False`) or the
      variance (`True`) of the past gradients is used to rescale the latest
      gradients.
    initial_scale: (default `0.`), initialisation of the accumulators tracking
      the magnitude of previous gradients.
    momentum: (default `None`), the `decay` rate used by the momentum term;
      when `None`, no momentum transform is added.
    nesterov: (default `False`), whether nesterov momentum is used. Only has
      an effect when `momentum` is not `None`.

  Returns:
    the corresponding `GradientTransformation`.
  """
  if centered:
    rescale = transform.scale_by_stddev(
        decay=decay, eps=eps, initial_scale=initial_scale)
  else:
    rescale = transform.scale_by_rms(
        decay=decay, eps=eps, initial_scale=initial_scale)
  transforms = [rescale, _scale_by_learning_rate(learning_rate)]
  # Only append momentum when requested: with `momentum=None` the chain keeps
  # exactly the two transforms it had before, so the optimizer state structure
  # is unchanged for existing callers.
  if momentum is not None:
    transforms.append(transform.trace(decay=momentum, nesterov=nesterov))
  return combine.chain(*transforms)