def _apply_sparse(self, grad, var):
    """Build a sparse RMSProp update op for `var`.

    Reads the "rms" and "momentum" slots (plus the "mg" slot when
    `self._centered` is set), casts the learning-rate, decay, momentum and
    epsilon tensors to `var`'s base dtype, and forwards everything to the
    matching sparse training op.

    Args:
      grad: Sparse gradient whose `.values` and `.indices` are passed to the
        training op (presumably a `tf.IndexedSlices` -- confirm at callers).
      var: The variable to update.

    Returns:
      The result of `training_ops.sparse_apply_centered_rms_prop` when
      `self._centered` is true, otherwise the result of
      `training_ops.sparse_apply_rms_prop`.
    """
    ms_slot = self.get_slot(var, "rms")
    momentum_slot = self.get_slot(var, "momentum")

    # Only the choice of op and the leading slot arguments differ between
    # the two variants; the hyperparameter/gradient tail is shared.
    if not self._centered:
        sparse_op = training_ops.sparse_apply_rms_prop
        leading = (var, ms_slot, momentum_slot)
    else:
        # The centered variant additionally takes the "mg" slot.
        sparse_op = training_ops.sparse_apply_centered_rms_prop
        leading = (var, self.get_slot(var, "mg"), ms_slot, momentum_slot)

    base_dtype = var.dtype.base_dtype
    # Positional order required by the op: lr, decay, momentum, epsilon.
    hyper_tensors = (
        self._learning_rate_tensor,
        self._decay_tensor,
        self._momentum_tensor,
        self._epsilon_tensor,
    )
    casted = tuple(math_ops.cast(t, base_dtype) for t in hyper_tensors)

    positional = leading + casted + (grad.values, grad.indices)
    return sparse_op(*positional, use_locking=self._use_locking)
def _apply_sparse(self, grad, var, state):
    """Build a sparse RMSProp update op for `var` using per-call `state`.

    Reads the "rms" and "momentum" slots (plus the "mg" slot when
    `self._centered` is set) from `state`, fetches the "learning_rate",
    "rho" and "momentum" hyperparameters from `state` in `var`'s base dtype,
    and forwards everything to the matching sparse training op.

    Args:
      grad: Sparse gradient whose `.values` and `.indices` are passed to the
        training op (presumably a `tf.IndexedSlices` -- confirm at callers).
      var: The variable to update.
      state: Object providing `get_slot(var, name)` and
        `get_hyper(name, dtype)`.

    Returns:
      The result of `training_ops.sparse_apply_centered_rms_prop` when
      `self._centered` is true, otherwise the result of
      `training_ops.sparse_apply_rms_prop`.
    """
    ms_slot = state.get_slot(var, "rms")
    momentum_slot = state.get_slot(var, "momentum")

    # Only the choice of op and the leading slot arguments differ between
    # the two variants; the hyperparameter/gradient tail is shared.
    if not self._centered:
        sparse_op = training_ops.sparse_apply_rms_prop
        leading = (var, ms_slot, momentum_slot)
    else:
        # The centered variant additionally takes the "mg" slot.
        sparse_op = training_ops.sparse_apply_centered_rms_prop
        leading = (var, state.get_slot(var, "mg"), ms_slot, momentum_slot)

    base_dtype = var.dtype.base_dtype
    # Positional order required by the op: lr, rho, momentum, epsilon.
    hypers = tuple(
        state.get_hyper(hyper_name, base_dtype)
        for hyper_name in ("learning_rate", "rho", "momentum"))
    # NOTE(review): epsilon is passed to the op as a literal 0. Presumably
    # epsilon is applied elsewhere (e.g. when the "rms" slot is created) --
    # confirm; otherwise rows whose rms value is 0 could divide by zero.
    epsilon = 0

    positional = leading + hypers + (epsilon, grad.values, grad.indices)
    return sparse_op(*positional, use_locking=self._use_locking)