示例#1
0
 def _resource_apply_sparse(self, grad, var, indices):
     """Run the sparse RMSProp kernel on the resource variable `var`.

     The centered kernel is used when `self._centered` is truthy; otherwise
     the plain kernel is used.

     Args:
       grad: Gradient values; every hyperparameter tensor is cast to its
         dtype before being handed to the kernel.
       var: Variable to update; only its `handle` is used here.
       indices: Indices forwarded to the kernel together with `grad`.

     Returns:
       Whatever the selected `training_ops` kernel call returns.
     """
     rms_slot = self.get_slot(var, "rms")
     momentum_slot = self.get_slot(var, "momentum")

     # The centered kernel takes one extra accumulator ("mg") ahead of "rms".
     if self._centered:
         mean_grad_slot = self.get_slot(var, "mg")
         kernel = training_ops.resource_sparse_apply_centered_rms_prop
         slot_handles = [
             mean_grad_slot.handle,
             rms_slot.handle,
             momentum_slot.handle,
         ]
     else:
         kernel = training_ops.resource_sparse_apply_rms_prop
         slot_handles = [rms_slot.handle, momentum_slot.handle]

     target_dtype = grad.dtype
     hyper_args = [
         math_ops.cast(self._learning_rate_tensor, target_dtype),
         math_ops.cast(self._decay_tensor, target_dtype),
         math_ops.cast(self._momentum_tensor, target_dtype),
         math_ops.cast(self._epsilon_tensor, target_dtype),
     ]
     positional = [var.handle] + slot_handles + hyper_args + [grad, indices]
     return kernel(*positional, use_locking=self._use_locking)
示例#2
0
 def _resource_apply_sparse(self, grad, var, indices, state):
     """Run the sparse RMSProp kernel on the resource variable `var`.

     All slots ("rms", "momentum" and, when centered, "mg") and all
     hyperparameters are read from `state`.

     Args:
       grad: Gradient values forwarded to the kernel.
       var: Variable to update; its `handle` is passed to the kernel and its
         base dtype selects the dtype of the hyperparameters.
       indices: Indices forwarded to the kernel together with `grad`.
       state: Object providing `get_slot` and `get_hyper` for this update.

     Returns:
       Whatever the selected `training_ops` kernel call returns.
     """
     dtype = var.dtype.base_dtype
     rms = state.get_slot(var, "rms")
     mom = state.get_slot(var, "momentum")
     learning_rate = state.get_hyper("learning_rate", dtype)
     rho = state.get_hyper("rho", dtype)
     momentum = state.get_hyper("momentum", dtype)
     # The kernel's epsilon argument is a literal 0.
     # NOTE(review): presumably epsilon is accounted for elsewhere (e.g. in
     # the "rms" slot's initial value) -- confirm against slot creation.
     epsilon = 0
     if self._centered:
         # Read "mg" from the same `state` object as the other two slots so
         # that every accumulator for this update comes from one source.
         mg = state.get_slot(var, "mg")
         return training_ops.resource_sparse_apply_centered_rms_prop(
             var.handle,
             mg.handle,
             rms.handle,
             mom.handle,
             learning_rate,
             rho,
             momentum,
             epsilon,
             grad,
             indices,
             use_locking=self._use_locking)
     return training_ops.resource_sparse_apply_rms_prop(
         var.handle,
         rms.handle,
         mom.handle,
         learning_rate,
         rho,
         momentum,
         epsilon,
         grad,
         indices,
         use_locking=self._use_locking)
示例#3
0
 def _resource_apply_sparse(self, grad, var, indices):
     """Run the sparse RMSProp kernel on the resource variable `var`.

     Hyperparameters are fetched in the variable's base dtype. The centered
     kernel is used when `self.centered` is truthy; otherwise the plain
     kernel is used.

     Args:
       grad: Gradient values forwarded to the kernel.
       var: Variable to update; only its `handle` and dtype are used here.
       indices: Indices forwarded to the kernel together with `grad`.

     Returns:
       Whatever the selected `training_ops` kernel call returns.
     """
     dtype = var.dtype.base_dtype
     step_size = self._decayed_lr(dtype)
     rms_slot = self.get_slot(var, "rms")
     momentum_slot = self.get_slot(var, "momentum")
     decay = self._get_hyper("rho", dtype)
     momentum_coeff = self._get_hyper("momentum", dtype)
     eps = self._get_hyper("epsilon", dtype)

     # Plain variant first: it needs no "mg" accumulator.
     if not self.centered:
         return training_ops.resource_sparse_apply_rms_prop(
             var.handle,
             rms_slot.handle,
             momentum_slot.handle,
             step_size,
             decay,
             momentum_coeff,
             eps,
             grad,
             indices,
             use_locking=self._use_locking)

     mean_grad_slot = self.get_slot(var, "mg")
     return training_ops.resource_sparse_apply_centered_rms_prop(
         var.handle,
         mean_grad_slot.handle,
         rms_slot.handle,
         momentum_slot.handle,
         step_size,
         decay,
         momentum_coeff,
         eps,
         grad,
         indices,
         use_locking=self._use_locking)
示例#4
0
 def _resource_apply_sparse(self, grad, var, indices):
   """Run the sparse RMSProp kernel on the resource variable `var`.

   Hyperparameters are fetched in the variable's base dtype. The centered
   kernel is used when `self._centered` is truthy; otherwise the plain
   kernel is used.

   Args:
     grad: Gradient values forwarded to the kernel.
     var: Variable to update; only its `handle` and dtype are used here.
     indices: Indices forwarded to the kernel together with `grad`.

   Returns:
     Whatever the selected `training_ops` kernel call returns.
   """
   dtype = var.dtype.base_dtype
   step_size = self._decayed_lr(dtype)
   rms_slot = self.get_slot(var, "rms")
   mom_slot = self.get_slot(var, "momentum")
   hyper_args = [
       step_size,
       self._get_hyper("rho", dtype),
       self._get_hyper("momentum", dtype),
       self._get_hyper("epsilon", dtype),
   ]

   # Both kernels share the same argument layout except for the extra
   # "mg" handle that the centered kernel takes ahead of "rms".
   if self._centered:
     mg_slot = self.get_slot(var, "mg")
     kernel = training_ops.resource_sparse_apply_centered_rms_prop
     slot_handles = [mg_slot.handle, rms_slot.handle, mom_slot.handle]
   else:
     kernel = training_ops.resource_sparse_apply_rms_prop
     slot_handles = [rms_slot.handle, mom_slot.handle]

   positional = [var.handle] + slot_handles + hyper_args + [grad, indices]
   return kernel(*positional, use_locking=self._use_locking)
示例#5
0
 def _resource_apply_sparse(self, grad, var, indices):
   """Run the sparse RMSProp kernel on the resource variable `var`.

   The optimizer's hyperparameter tensors are cast to the gradient dtype.
   The centered kernel is used when `self._centered` is truthy; otherwise
   the plain kernel is used.

   Args:
     grad: Gradient values; its dtype is the cast target for the
       hyperparameter tensors.
     var: Variable to update; only its `handle` is used here.
     indices: Indices forwarded to the kernel together with `grad`.

   Returns:
     Whatever the selected `training_ops` kernel call returns.
   """
   target_dtype = grad.dtype
   rms_acc = self.get_slot(var, "rms")
   mom_acc = self.get_slot(var, "momentum")
   lr = math_ops.cast(self._learning_rate_tensor, target_dtype)
   decay = math_ops.cast(self._decay_tensor, target_dtype)
   momentum = math_ops.cast(self._momentum_tensor, target_dtype)
   epsilon = math_ops.cast(self._epsilon_tensor, target_dtype)

   # Plain variant first: it needs no "mg" accumulator.
   if not self._centered:
     return training_ops.resource_sparse_apply_rms_prop(
         var.handle,
         rms_acc.handle,
         mom_acc.handle,
         lr,
         decay,
         momentum,
         epsilon,
         grad,
         indices,
         use_locking=self._use_locking)

   mg_acc = self.get_slot(var, "mg")
   return training_ops.resource_sparse_apply_centered_rms_prop(
       var.handle,
       mg_acc.handle,
       rms_acc.handle,
       mom_acc.handle,
       lr,
       decay,
       momentum,
       epsilon,
       grad,
       indices,
       use_locking=self._use_locking)
示例#6
0
 def _resource_apply_sparse(self, grad, var, indices, state):
   """Run the sparse RMSProp kernel on the resource variable `var`.

   All slots ("rms", "momentum" and, when centered, "mg") and all
   hyperparameters are read from `state`.

   Args:
     grad: Gradient values forwarded to the kernel.
     var: Variable to update; its `handle` is passed to the kernel and its
       base dtype selects the dtype of the hyperparameters.
     indices: Indices forwarded to the kernel together with `grad`.
     state: Object providing `get_slot` and `get_hyper` for this update.

   Returns:
     Whatever the selected `training_ops` kernel call returns.
   """
   dtype = var.dtype.base_dtype
   rms = state.get_slot(var, "rms")
   mom = state.get_slot(var, "momentum")
   learning_rate = state.get_hyper("learning_rate", dtype)
   rho = state.get_hyper("rho", dtype)
   momentum = state.get_hyper("momentum", dtype)
   # The kernel's epsilon argument is a literal 0.
   # NOTE(review): presumably epsilon is accounted for elsewhere (e.g. in
   # the "rms" slot's initial value) -- confirm against slot creation.
   epsilon = 0
   if self._centered:
     # Read "mg" from the same `state` object as the other two slots so
     # that every accumulator for this update comes from one source.
     mg = state.get_slot(var, "mg")
     return training_ops.resource_sparse_apply_centered_rms_prop(
         var.handle,
         mg.handle,
         rms.handle,
         mom.handle,
         learning_rate,
         rho,
         momentum,
         epsilon,
         grad,
         indices,
         use_locking=self._use_locking)
   return training_ops.resource_sparse_apply_rms_prop(
       var.handle,
       rms.handle,
       mom.handle,
       learning_rate,
       rho,
       momentum,
       epsilon,
       grad,
       indices,
       use_locking=self._use_locking)
示例#7
0
 def _resource_apply_sparse(self, grad, var, indices):
   """Run the sparse RMSProp kernel on the resource variable `var`.

   Each hyperparameter is fetched with `self._get_hyper` and cast to the
   gradient's base dtype. The centered kernel is used when
   `self._centered` is truthy; otherwise the plain kernel is used.

   Args:
     grad: Gradient values; its base dtype is the cast target for the
       hyperparameters.
     var: Variable to update; only its `handle` is used here.
     indices: Indices forwarded to the kernel together with `grad`.

   Returns:
     Whatever the selected `training_ops` kernel call returns.
   """
   rms_slot = self.get_slot(var, "rms")
   mom_slot = self.get_slot(var, "momentum")
   grad_dtype = grad.dtype.base_dtype

   def hyper_as_grad_dtype(name):
     # Fetch one hyperparameter and cast it to the gradient's base dtype.
     return math_ops.cast(self._get_hyper(name), grad_dtype)

   lr = hyper_as_grad_dtype("learning_rate")
   decay = hyper_as_grad_dtype("rho")
   momentum_coeff = hyper_as_grad_dtype("momentum")
   eps = hyper_as_grad_dtype("epsilon")

   # Plain variant first: it needs no "mg" accumulator.
   if not self._centered:
     return training_ops.resource_sparse_apply_rms_prop(
         var.handle,
         rms_slot.handle,
         mom_slot.handle,
         lr,
         decay,
         momentum_coeff,
         eps,
         grad,
         indices,
         use_locking=self._use_locking)

   mg_slot = self.get_slot(var, "mg")
   return training_ops.resource_sparse_apply_centered_rms_prop(
       var.handle,
       mg_slot.handle,
       rms_slot.handle,
       mom_slot.handle,
       lr,
       decay,
       momentum_coeff,
       eps,
       grad,
       indices,
       use_locking=self._use_locking)
示例#8
0
 def _resource_apply_sparse(self, grad, var, indices):
   """Run the sparse RMSProp kernel on the resource variable `var`.

   Each hyperparameter is fetched with `self._get_hyper` and cast to the
   gradient's base dtype. The centered kernel is used when
   `self._centered` is truthy; otherwise the plain kernel is used.

   Args:
     grad: Gradient values; its base dtype is the cast target for the
       hyperparameters.
     var: Variable to update; only its `handle` is used here.
     indices: Indices forwarded to the kernel together with `grad`.

   Returns:
     Whatever the selected `training_ops` kernel call returns.
   """
   rms_acc = self.get_slot(var, "rms")
   mom_acc = self.get_slot(var, "momentum")
   cast_dtype = grad.dtype.base_dtype
   # Order matters: it is the positional order the kernels expect.
   hyper_args = [
       math_ops.cast(self._get_hyper(hyper_name), cast_dtype)
       for hyper_name in ("learning_rate", "rho", "momentum", "epsilon")
   ]

   # The centered kernel takes one extra accumulator ("mg") ahead of "rms".
   if self._centered:
     mg_acc = self.get_slot(var, "mg")
     kernel = training_ops.resource_sparse_apply_centered_rms_prop
     slot_handles = [mg_acc.handle, rms_acc.handle, mom_acc.handle]
   else:
     kernel = training_ops.resource_sparse_apply_rms_prop
     slot_handles = [rms_acc.handle, mom_acc.handle]

   positional = [var.handle] + slot_handles + hyper_args + [grad, indices]
   return kernel(*positional, use_locking=self._use_locking)
示例#9
0
  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Apply a sparse RMSProp update to the resource variable `var`.

    Coefficients are looked up in `apply_state` under the key
    `(var.device, var.dtype.base_dtype)`; when that yields nothing truthy,
    `self._fallback_apply_state` supplies them instead. With a truthy
    `self._momentum` the update is delegated to a `training_ops` kernel;
    otherwise it is assembled here from assign / scatter-add ops.

    Args:
      grad: Gradient values for the entries selected by `indices`.
      var: Variable to update.
      indices: Indices used for the scatter-add and gather ops (or forwarded
        to the kernel).
      apply_state: Optional mapping from `(device, dtype)` to a coefficient
        dict.

    Returns:
      The kernel call's result on the momentum path; otherwise a
      `control_flow_ops.group` of the variable and accumulator updates.
    """
    state_key = (var.device, var.dtype.base_dtype)
    coeffs = (apply_state or {}).get(state_key)
    if not coeffs:
      coeffs = self._fallback_apply_state(*state_key)

    rms_slot = self.get_slot(var, "rms")

    if self._momentum:
      mom_slot = self.get_slot(var, "momentum")
      if not self.centered:
        return training_ops.resource_sparse_apply_rms_prop(
            var.handle,
            rms_slot.handle,
            mom_slot.handle,
            coeffs["lr_t"],
            coeffs["rho"],
            coeffs["momentum"],
            coeffs["epsilon"],
            grad,
            indices,
            use_locking=self._use_locking)
      mg_slot = self.get_slot(var, "mg")
      return training_ops.resource_sparse_apply_centered_rms_prop(
          var.handle,
          mg_slot.handle,
          rms_slot.handle,
          mom_slot.handle,
          coeffs["lr_t"],
          coeffs["rho"],
          coeffs["momentum"],
          coeffs["epsilon"],
          grad,
          indices,
          use_locking=self._use_locking)

    # No momentum: build the update by hand. The whole "rms" accumulator is
    # first scaled by rho; the scaled squared gradient is then scatter-added
    # at `indices` only after that assign has run.
    squared_grad_term = (grad * grad) * coeffs["one_minus_rho"]
    rms_update = state_ops.assign(rms_slot, rms_slot * coeffs["rho"],
                                  use_locking=self._use_locking)
    with ops.control_dependencies([rms_update]):
      rms_update = self._resource_scatter_add(
          rms_slot, indices, squared_grad_term)
      rms_rows = array_ops.gather(rms_update, indices)
    denominator = rms_rows
    accumulator_updates = [rms_update]

    if self.centered:
      # Same decay-then-scatter pattern for the "mg" accumulator; its square
      # is subtracted from the "rms" rows to form the denominator.
      mg_slot = self.get_slot(var, "mg")
      grad_term = grad * coeffs["one_minus_rho"]
      mg_update = state_ops.assign(mg_slot, mg_slot * coeffs["rho"],
                                   use_locking=self._use_locking)
      with ops.control_dependencies([mg_update]):
        mg_update = self._resource_scatter_add(mg_slot, indices, grad_term)
        mg_rows = array_ops.gather(mg_update, indices)
        denominator = rms_rows - math_ops.square(mg_rows)
      accumulator_updates.append(mg_update)

    delta = coeffs["neg_lr_t"] * grad / (
        math_ops.sqrt(denominator) + coeffs["epsilon"])
    var_update = self._resource_scatter_add(var, indices, delta)
    return control_flow_ops.group(var_update, *accumulator_updates)
示例#10
0
 def _resource_apply_sparse(self, grad, var, indices):
   """Apply a sparse RMSProp update to the resource variable `var`.

   Hyperparameters are fetched in the variable's base dtype. With a truthy
   `self._momentum` the update is delegated to a `training_ops` kernel;
   otherwise it is assembled here from assign / scatter-add ops.

   Args:
     grad: Gradient values for the entries selected by `indices`.
     var: Variable to update.
     indices: Indices used for the scatter-add and gather ops (or forwarded
       to the kernel).

   Returns:
     The kernel call's result on the momentum path; otherwise a
     `control_flow_ops.group` of the variable and accumulator updates.
   """
   dtype = var.dtype.base_dtype
   step_size = self._decayed_lr(dtype)
   rms_slot = self.get_slot(var, "rms")
   decay = self._get_hyper("rho", dtype)
   momentum_coeff = self._get_hyper("momentum", dtype)
   eps = self._get_hyper("epsilon", dtype)

   if self._momentum:
     mom_slot = self.get_slot(var, "momentum")
     if not self.centered:
       return training_ops.resource_sparse_apply_rms_prop(
           var.handle,
           rms_slot.handle,
           mom_slot.handle,
           step_size,
           decay,
           momentum_coeff,
           eps,
           grad,
           indices,
           use_locking=self._use_locking)
     mg_slot = self.get_slot(var, "mg")
     return training_ops.resource_sparse_apply_centered_rms_prop(
         var.handle,
         mg_slot.handle,
         rms_slot.handle,
         mom_slot.handle,
         step_size,
         decay,
         momentum_coeff,
         eps,
         grad,
         indices,
         use_locking=self._use_locking)

   # No momentum: build the update by hand. The whole "rms" accumulator is
   # first scaled by rho; (1 - rho) * grad^2 is then scatter-added at
   # `indices` only after that assign has run.
   squared_grad_term = (grad * grad) * (1. - decay)
   rms_update = state_ops.assign(
       rms_slot, rms_slot * decay, use_locking=self._use_locking)
   with ops.control_dependencies([rms_update]):
     rms_update = self._resource_scatter_add(
         rms_slot, indices, squared_grad_term)
     rms_rows = array_ops.gather(rms_update, indices)
   denominator = rms_rows
   accumulator_updates = [rms_update]

   if self.centered:
     # Same decay-then-scatter pattern for the "mg" accumulator; its square
     # is subtracted from the "rms" rows to form the denominator.
     mg_slot = self.get_slot(var, "mg")
     grad_term = grad * (1. - decay)
     mg_update = state_ops.assign(
         mg_slot, mg_slot * decay, use_locking=self._use_locking)
     with ops.control_dependencies([mg_update]):
       mg_update = self._resource_scatter_add(mg_slot, indices, grad_term)
       mg_rows = array_ops.gather(mg_update, indices)
       denominator = rms_rows - math_ops.square(mg_rows)
     accumulator_updates.append(mg_update)

   delta = -step_size * grad / (math_ops.sqrt(denominator) + eps)
   var_update = self._resource_scatter_add(var, indices, delta)
   return control_flow_ops.group(var_update, *accumulator_updates)
示例#11
0
 def _resource_apply_sparse(self, grad, var, indices):
   """Apply a sparse RMSProp update to the resource variable `var`.

   Hyperparameters are fetched in the variable's base dtype. With a truthy
   `self._momentum` the update is delegated to a `training_ops` kernel;
   otherwise it is assembled here from assign / scatter-add ops.

   Args:
     grad: Gradient values for the entries selected by `indices`.
     var: Variable to update.
     indices: Indices used for the scatter-add and gather ops (or forwarded
       to the kernel).

   Returns:
     The kernel call's result on the momentum path; otherwise a
     `control_flow_ops.group` of the variable and accumulator updates.
   """
   base_dtype = var.dtype.base_dtype
   learning_rate = self._decayed_lr(base_dtype)
   rms_acc = self.get_slot(var, "rms")
   rho_t = self._get_hyper("rho", base_dtype)
   momentum_t = self._get_hyper("momentum", base_dtype)
   epsilon_t = self._get_hyper("epsilon", base_dtype)

   if self._momentum:
     mom_acc = self.get_slot(var, "momentum")
     # The centered kernel takes one extra accumulator ("mg") ahead of "rms".
     if self.centered:
       mg_acc = self.get_slot(var, "mg")
       kernel = training_ops.resource_sparse_apply_centered_rms_prop
       slot_handles = [mg_acc.handle, rms_acc.handle, mom_acc.handle]
     else:
       kernel = training_ops.resource_sparse_apply_rms_prop
       slot_handles = [rms_acc.handle, mom_acc.handle]
     positional = [var.handle] + slot_handles + [
         learning_rate, rho_t, momentum_t, epsilon_t, grad, indices]
     return kernel(*positional, use_locking=self._use_locking)

   # No momentum: decay the whole "rms" accumulator by rho, then scatter-add
   # (1 - rho) * grad^2 at `indices` once the decay assign has run.
   one_minus_rho = 1. - rho_t
   rms_increment = (grad * grad) * one_minus_rho
   rms_op = state_ops.assign(
       rms_acc, rms_acc * rho_t, use_locking=self._use_locking)
   with ops.control_dependencies([rms_op]):
     rms_op = self._resource_scatter_add(rms_acc, indices, rms_increment)
     rms_at_indices = array_ops.gather(rms_op, indices)

   if not self.centered:
     var_op = self._resource_scatter_add(
         var, indices,
         -learning_rate * grad / (math_ops.sqrt(rms_at_indices) + epsilon_t))
     return control_flow_ops.group(var_op, rms_op)

   # Centered: maintain "mg" the same way and subtract its square from the
   # "rms" rows before taking the square root.
   mg_acc = self.get_slot(var, "mg")
   mg_increment = grad * (1. - rho_t)
   mg_op = state_ops.assign(
       mg_acc, mg_acc * rho_t, use_locking=self._use_locking)
   with ops.control_dependencies([mg_op]):
     mg_op = self._resource_scatter_add(mg_acc, indices, mg_increment)
     mg_at_indices = array_ops.gather(mg_op, indices)
     centered_denom = rms_at_indices - math_ops.square(mg_at_indices)
   var_op = self._resource_scatter_add(
       var, indices,
       -learning_rate * grad / (math_ops.sqrt(centered_denom) + epsilon_t))
   return control_flow_ops.group(var_op, rms_op, mg_op)