def _resource_apply_sparse(self, grad, var, indices):
   mom = self.get_slot(var, "momentum")
   return training_ops.resource_sparse_apply_momentum(
       var.handle, mom.handle,
       math_ops.cast(self._learning_rate_tensor, grad.dtype),
       grad, indices,
       math_ops.cast(self._momentum_tensor, grad.dtype),
       use_locking=self._use_locking,
       use_nesterov=self._use_nesterov)
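This first variant follows the TF1-style momentum optimizer: self._learning_rate_tensor and self._momentum_tensor are tensors cached by the optimizer's _prepare() method, and each is cast to the gradient dtype before the fused kernel call. As a rough sketch of where those attributes come from, assuming the TF1 Optimizer conventions (not copied from any of the examples here):

  def _prepare(self):
    # Resolve possibly-callable hyperparameters to tensors once per apply
    # step, so _resource_apply_sparse only needs a dtype cast.
    learning_rate = self._call_if_callable(self._learning_rate)
    self._learning_rate_tensor = ops.convert_to_tensor(
        learning_rate, name="learning_rate")
    momentum = self._call_if_callable(self._momentum)
    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")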
Example #2
 def _resource_apply_sparse(self, grad, var, indices, state):
   mom = state.get_slot(var, "momentum")
   return training_ops.resource_sparse_apply_momentum(
       var.handle,
       mom.handle,
       state.get_hyper("learning_rate", var.dtype.base_dtype),
       grad,
       indices,
       state.get_hyper("momentum", var.dtype.base_dtype),
       use_locking=self._use_locking,
       use_nesterov=self._use_nesterov)
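Unlike the previous example, this signature takes an explicit state argument (an intermediate OptimizerV2-style interface, judging from the snippet itself): the slot comes from state.get_slot and the hyperparameters from state.get_hyper, which already returns them in the requested dtype, so the math_ops.cast calls disappear.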
Example #3
 def _resource_apply_sparse(self, grad, var, indices):
     mom = self.get_slot(var, "momentum")
     use_nesterov = bool(self._serialize_hyperparameter("use_nesterov"))
     return training_ops.resource_sparse_apply_momentum(
         var.handle,
         mom.handle,
         math_ops.cast(self._learning_rate_tensor, grad.dtype),
         grad,
         indices,
         math_ops.cast(self._momentum_tensor, grad.dtype),
         use_locking=False,
         use_nesterov=use_nesterov,
     )
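Here _serialize_hyperparameter("use_nesterov") recovers the plain Python value of a registered hyperparameter, so the Nesterov switch is resolved to a bool when the function is built rather than fed to the kernel as a tensor, and use_locking is hardcoded to False instead of being read from the optimizer instance.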
Example #4
 def _resource_apply_sparse(self, grad, var, indices):
     # This method is only needed for momentum optimization.
     learning_rate = self._get_hyper("learning_rate")
     momentum_var = self.get_slot(var, "momentum")
     return training_ops.resource_sparse_apply_momentum(
         var.handle,
         momentum_var.handle,
         math_ops.cast(learning_rate, grad.dtype.base_dtype),
         grad,
         indices,
         math_ops.cast(self._get_hyper("momentum"), grad.dtype.base_dtype),
         use_locking=self._use_locking,
         use_nesterov=self._nesterov)
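This is the Keras OptimizerV2 style: _get_hyper returns the hyperparameter (possibly a variable), which is then cast to the gradient's base dtype, and the Nesterov flag is a plain Python attribute self._nesterov set in the constructor.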
Example #5
 def _resource_apply_sparse(self, grad, var, indices):
   # This method is only needed for momentum optimization.
   var_dtype = var.dtype.base_dtype
   lr_t = self._decayed_lr(var_dtype)
   momentum_var = self.get_slot(var, "momentum")
   return training_ops.resource_sparse_apply_momentum(
       var.handle,
       momentum_var.handle,
       lr_t,
       grad,
       indices,
       self._get_hyper("momentum", var_dtype),
       use_locking=self._use_locking,
       use_nesterov=self._nesterov)
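Compared with the previous example, _decayed_lr(var_dtype) folds any learning-rate schedule or decay into a scalar of the right dtype before the kernel call, and _get_hyper takes the dtype directly, so both explicit casts are gone.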
Example #6
  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    # This method is only needed for momentum optimization.
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    momentum_var = self.get_slot(var, "momentum")
    return training_ops.resource_sparse_apply_momentum(
        var.handle,
        momentum_var.handle,
        coefficients["lr_t"],
        grad,
        indices,
        coefficients["momentum"],
        use_locking=self._use_locking,
        use_nesterov=self.nesterov)
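The apply_state variant matches the current Keras SGD layout: coefficients such as lr_t and momentum are precomputed once per (device, dtype) pair and only looked up here, with _fallback_apply_state covering direct calls where apply_state is None. A sketch of the matching _prepare_local, assuming the Keras SGD convention that the base class fills in lr_t:

  def _prepare_local(self, var_device, var_dtype, apply_state):
    # The base class stores lr_t; add the momentum coefficient next to it
    # so the per-variable apply methods only do dictionary lookups.
    super()._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
        self._get_hyper("momentum", var_dtype))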
Example #7
    def _resource_apply_sparse(self, grad, var, indices):
        momentum_buffer = self.get_slot(var, "momentum")
        learning_rate = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
        momentum = math_ops.cast(self._momentum_tensor, var.dtype.base_dtype)
        nu = math_ops.cast(self._nu_tensor, var.dtype.base_dtype)

        momentum_op = training_ops.resource_sparse_apply_momentum(
            var.handle,
            momentum_buffer.handle,
            nu * (1.0 - momentum) * learning_rate,
            grad,
            indices,
            momentum,
            use_locking=self._use_locking,
            use_nesterov=False,
        )

        with ops.control_dependencies([momentum_op]):
            delta = (nu - 1.0) * learning_rate * grad
            gd_op = resource_variable_ops.resource_scatter_add(var.handle, indices, delta)

        return control_flow_ops.group(momentum_op, gd_op)
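This last variant appears to implement a quasi-hyperbolic-momentum-style rule by splitting the update across two kernels: the fused momentum op runs with an effective learning rate of nu * (1 - momentum) * learning_rate, and the scatter-add then applies delta = (nu - 1) * learning_rate * grad, so the combined sparse update (with accum the updated momentum buffer) is

    var -= learning_rate * ((1 - nu) * grad + nu * (1 - momentum) * accum)

which interpolates between plain SGD (nu = 0) and momentum SGD (nu = 1).

To show where any of these methods plug in end to end, here is a minimal self-contained sketch against the public TF 2.x API. The class name MySparseMomentum, the use of tf.raw_ops, and the hardcoded use_locking=False are this sketch's own choices, not taken from the examples above; it assumes a TF 2.x release where tf.keras.optimizers.Optimizer is still OptimizerV2 (on newer releases, substitute tf.keras.optimizers.legacy.Optimizer):

import tensorflow as tf

class MySparseMomentum(tf.keras.optimizers.Optimizer):
  """Heavy-ball momentum with fused sparse updates (illustrative sketch)."""

  def __init__(self, learning_rate=0.01, momentum=0.9,
               name="MySparseMomentum", **kwargs):
    super().__init__(name, **kwargs)
    self._set_hyper("learning_rate", learning_rate)
    self._set_hyper("momentum", momentum)

  def _create_slots(self, var_list):
    # One accumulator per variable, named "momentum" as in the examples.
    for var in var_list:
      self.add_slot(var, "momentum")

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_dtype = var.dtype.base_dtype
    mom = self.get_slot(var, "momentum")
    return tf.raw_ops.ResourceApplyMomentum(
        var=var.handle, accum=mom.handle,
        lr=self._decayed_lr(var_dtype), grad=grad,
        momentum=self._get_hyper("momentum", var_dtype),
        use_locking=False, use_nesterov=False)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_dtype = var.dtype.base_dtype
    mom = self.get_slot(var, "momentum")
    return tf.raw_ops.ResourceSparseApplyMomentum(
        var=var.handle, accum=mom.handle,
        lr=self._decayed_lr(var_dtype), grad=grad, indices=indices,
        momentum=self._get_hyper("momentum", var_dtype),
        use_locking=False, use_nesterov=False)

  def get_config(self):
    config = super().get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "momentum": self._serialize_hyperparameter("momentum"),
    })
    return config

# Usage sketch: an embedding-style tf.gather produces IndexedSlices
# gradients, which OptimizerV2 routes to _resource_apply_sparse.
emb = tf.Variable(tf.random.normal([100, 8]))
opt = MySparseMomentum(learning_rate=0.1, momentum=0.9)
with tf.GradientTape() as tape:
  loss = tf.reduce_sum(tf.gather(emb, [3, 7]) ** 2)
opt.apply_gradients([(tape.gradient(loss, emb), emb)])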