Пример #1
0
 def _apply_sparse(self, grad, var):
     """Applies one sparse RMSProp step to the rows of `var` named by `grad`.

     Dispatches to the centered or the plain sparse RMSProp training kernel
     depending on `self._centered`. Every hyperparameter tensor is cast to
     the base dtype of `var` before being handed to the kernel.

     Args:
       grad: Sparse gradient; its `.values` and `.indices` are forwarded to
         the kernel.
       var: The variable being updated.

     Returns:
       Whatever the selected `training_ops` kernel returns.
     """
     rms = self.get_slot(var, "rms")
     mom = self.get_slot(var, "momentum")
     if self._centered:
         # The centered kernel takes the extra "mg" slot between `var` and
         # the "rms" slot.
         kernel = training_ops.sparse_apply_centered_rms_prop
         leading_args = [var, self.get_slot(var, "mg"), rms, mom]
     else:
         kernel = training_ops.sparse_apply_rms_prop
         leading_args = [var, rms, mom]

     # Casts are created in kernel-argument order: lr, decay, momentum, eps.
     target_dtype = var.dtype.base_dtype
     hyper_args = [
         math_ops.cast(hyper, target_dtype)
         for hyper in (self._learning_rate_tensor,
                       self._decay_tensor,
                       self._momentum_tensor,
                       self._epsilon_tensor)
     ]
     return kernel(*(leading_args + hyper_args + [grad.values, grad.indices]),
                   use_locking=self._use_locking)
Пример #2
0
 def _apply_sparse(self, grad, var):
   """Applies one sparse RMSProp step to the rows of `var` named by `grad`.

   Dispatches to the centered or the plain sparse RMSProp training kernel
   depending on `self._centered`. Every hyperparameter tensor is cast to the
   base dtype of `var` before being handed to the kernel.

   Args:
     grad: Sparse gradient; its `.values` and `.indices` are forwarded to the
       kernel.
     var: The variable being updated.

   Returns:
     Whatever the selected `training_ops` kernel returns.
   """
   rms = self.get_slot(var, "rms")
   mom = self.get_slot(var, "momentum")
   if self._centered:
     # The centered kernel takes the extra "mg" slot between `var` and the
     # "rms" slot.
     kernel = training_ops.sparse_apply_centered_rms_prop
     leading_args = [var, self.get_slot(var, "mg"), rms, mom]
   else:
     kernel = training_ops.sparse_apply_rms_prop
     leading_args = [var, rms, mom]

   # Casts are created in kernel-argument order: lr, decay, momentum, eps.
   target_dtype = var.dtype.base_dtype
   hyper_args = [
       math_ops.cast(hyper, target_dtype)
       for hyper in (self._learning_rate_tensor,
                     self._decay_tensor,
                     self._momentum_tensor,
                     self._epsilon_tensor)
   ]
   return kernel(*(leading_args + hyper_args + [grad.values, grad.indices]),
                 use_locking=self._use_locking)
Пример #3
0
 def _apply_sparse(self, grad, var, state):
     """Applies one sparse RMSProp step to the rows of `var` named by `grad`.

     Slots and hyperparameters are read from `state`; hyperparameters are
     requested in the base dtype of `var`. Dispatches to the centered or the
     plain sparse RMSProp training kernel depending on `self._centered`.

     Args:
       grad: Sparse gradient; its `.values` and `.indices` are forwarded to
         the kernel.
       var: The variable being updated.
       state: Object providing `get_slot` / `get_hyper` for this optimizer.

     Returns:
       Whatever the selected `training_ops` kernel returns.
     """
     rms = state.get_slot(var, "rms")
     mom = state.get_slot(var, "momentum")
     if self._centered:
         # The centered kernel takes the extra "mg" slot between `var` and
         # the "rms" slot.
         kernel = training_ops.sparse_apply_centered_rms_prop
         leading_args = [var, state.get_slot(var, "mg"), rms, mom]
     else:
         kernel = training_ops.sparse_apply_rms_prop
         leading_args = [var, rms, mom]

     base_dtype = var.dtype.base_dtype
     hyper_args = [state.get_hyper(name, base_dtype)
                   for name in ("learning_rate", "rho", "momentum")]
     # NOTE(review): epsilon is passed to the kernel as a literal 0 rather
     # than read from `state`; presumably it is accounted for elsewhere
     # (e.g. in how the "rms" slot is initialized) -- confirm.
     hyper_args.append(0)
     return kernel(*(leading_args + hyper_args + [grad.values, grad.indices]),
                   use_locking=self._use_locking)
Пример #4
0
 def _apply_sparse(self, grad, var, state):
   """Applies one sparse RMSProp step to the rows of `var` named by `grad`.

   Slots and hyperparameters are read from `state`; hyperparameters are
   requested in the base dtype of `var`. Dispatches to the centered or the
   plain sparse RMSProp training kernel depending on `self._centered`.

   Args:
     grad: Sparse gradient; its `.values` and `.indices` are forwarded to the
       kernel.
     var: The variable being updated.
     state: Object providing `get_slot` / `get_hyper` for this optimizer.

   Returns:
     Whatever the selected `training_ops` kernel returns.
   """
   rms = state.get_slot(var, "rms")
   mom = state.get_slot(var, "momentum")
   if self._centered:
     # The centered kernel takes the extra "mg" slot between `var` and the
     # "rms" slot.
     kernel = training_ops.sparse_apply_centered_rms_prop
     leading_args = [var, state.get_slot(var, "mg"), rms, mom]
   else:
     kernel = training_ops.sparse_apply_rms_prop
     leading_args = [var, rms, mom]

   base_dtype = var.dtype.base_dtype
   hyper_args = [state.get_hyper(name, base_dtype)
                 for name in ("learning_rate", "rho", "momentum")]
   # NOTE(review): epsilon is passed to the kernel as a literal 0 rather than
   # read from `state`; presumably it is accounted for elsewhere (e.g. in how
   # the "rms" slot is initialized) -- confirm.
   hyper_args.append(0)
   return kernel(*(leading_args + hyper_args + [grad.values, grad.indices]),
                 use_locking=self._use_locking)