Example #1
0
 def _resource_apply_dense(self, grad, var):
     """Apply a dense RMSProp update to a resource variable.

     Uses the centered `training_ops` RMSProp kernel when `self._centered` is
     true, otherwise the plain one. The learning-rate, decay, momentum and
     epsilon tensors are cast to the gradient's base dtype first.

     Args:
       grad: Dense gradient for `var`.
       var: Resource variable being updated; its "rms" and "momentum" slots
         (plus "mg" when centered) are looked up with `self.get_slot`.

     Returns:
       The op produced by the selected `training_ops` kernel.
     """
     ms_slot = self.get_slot(var, "rms")
     momentum_slot = self.get_slot(var, "momentum")
     # Look up "mg" before building the casts, matching the original order.
     mean_grad_slot = self.get_slot(var, "mg") if self._centered else None

     target_dtype = grad.dtype.base_dtype
     lr = math_ops.cast(self._learning_rate_tensor, target_dtype)
     decay = math_ops.cast(self._decay_tensor, target_dtype)
     momentum = math_ops.cast(self._momentum_tensor, target_dtype)
     epsilon = math_ops.cast(self._epsilon_tensor, target_dtype)

     if not self._centered:
         return training_ops.resource_apply_rms_prop(
             var.handle, ms_slot.handle, momentum_slot.handle,
             lr, decay, momentum, epsilon, grad,
             use_locking=self._use_locking)
     return training_ops.resource_apply_centered_rms_prop(
         var.handle, mean_grad_slot.handle, ms_slot.handle,
         momentum_slot.handle, lr, decay, momentum, epsilon, grad,
         use_locking=self._use_locking)
Example #2
0
 def _resource_apply_dense(self, grad, var, state):
     """Apply a dense RMSProp update, reading slots and hypers from `state`.

     Uses the centered `training_ops` RMSProp kernel when `self._centered` is
     true, otherwise the plain one. "learning_rate", "rho" and "momentum" are
     fetched with `state.get_hyper` in `var`'s base dtype.

     NOTE(review): both kernels receive the literal `0` as epsilon, and no
     "epsilon" hyperparameter is read here. Presumably epsilon is accounted
     for elsewhere (e.g. in how the "rms" slot is initialized). Confirm
     against the slot-creation code before changing this.

     Args:
       grad: Dense gradient for `var`.
       var: Resource variable being updated.
       state: Optimizer state object providing `get_slot` and `get_hyper`.

     Returns:
       The op produced by the selected `training_ops` kernel.
     """
     ms_slot = state.get_slot(var, "rms")
     mom_slot = state.get_slot(var, "momentum")
     centered = self._centered
     mg_slot = state.get_slot(var, "mg") if centered else None

     dtype = var.dtype.base_dtype
     lr, rho, momentum = (
         state.get_hyper(name, dtype)
         for name in ("learning_rate", "rho", "momentum"))
     epsilon = 0

     if centered:
         return training_ops.resource_apply_centered_rms_prop(
             var.handle, mg_slot.handle, ms_slot.handle, mom_slot.handle,
             lr, rho, momentum, epsilon, grad,
             use_locking=self._use_locking)
     return training_ops.resource_apply_rms_prop(
         var.handle, ms_slot.handle, mom_slot.handle,
         lr, rho, momentum, epsilon, grad,
         use_locking=self._use_locking)
Example #3
0
 def _resource_apply_dense(self, grad, var):
   """Run one dense RMSProp step on `var` via `training_ops`.

   The learning-rate, decay, momentum and epsilon tensors are cast to the
   gradient's base dtype. When `self._centered` is true, the centered kernel
   is used and the "mg" slot is passed along with "rms" and "momentum".

   Args:
     grad: Dense gradient for `var`.
     var: Resource variable being updated.

   Returns:
     The op produced by the selected `training_ops` kernel.
   """
   rms_slot = self.get_slot(var, "rms")
   mom_slot = self.get_slot(var, "momentum")
   centered = self._centered
   mg_slot = self.get_slot(var, "mg") if centered else None

   dtype = grad.dtype.base_dtype
   lr, decay, momentum, epsilon = [
       math_ops.cast(tensor, dtype)
       for tensor in (self._learning_rate_tensor, self._decay_tensor,
                      self._momentum_tensor, self._epsilon_tensor)]

   if centered:
     update = training_ops.resource_apply_centered_rms_prop(
         var.handle, mg_slot.handle, rms_slot.handle, mom_slot.handle,
         lr, decay, momentum, epsilon, grad,
         use_locking=self._use_locking)
   else:
     update = training_ops.resource_apply_rms_prop(
         var.handle, rms_slot.handle, mom_slot.handle,
         lr, decay, momentum, epsilon, grad,
         use_locking=self._use_locking)
   return update
Example #4
0
 def _resource_apply_dense(self, grad, var):
     """Apply a dense RMSProp update to a resource variable.

     The learning rate comes from `self._decayed_lr`, and "rho", "momentum"
     and "epsilon" come from `self._get_hyper`, all in `var`'s base dtype.
     `self.centered` selects the centered or plain `training_ops` kernel.

     Args:
       grad: Dense gradient for `var`.
       var: Resource variable being updated.

     Returns:
       The op produced by the selected `training_ops` kernel.
     """
     dtype = var.dtype.base_dtype
     learning_rate = self._decayed_lr(dtype)
     ms_slot = self.get_slot(var, "rms")
     mom_slot = self.get_slot(var, "momentum")
     rho, momentum, epsilon = [
         self._get_hyper(key, dtype) for key in ("rho", "momentum", "epsilon")]

     if not self.centered:
         return training_ops.resource_apply_rms_prop(
             var.handle, ms_slot.handle, mom_slot.handle,
             learning_rate, rho, momentum, epsilon, grad,
             use_locking=self._use_locking)

     mg_slot = self.get_slot(var, "mg")
     return training_ops.resource_apply_centered_rms_prop(
         var.handle, mg_slot.handle, ms_slot.handle, mom_slot.handle,
         learning_rate, rho, momentum, epsilon, grad,
         use_locking=self._use_locking)
Example #5
0
 def _resource_apply_dense(self, grad, var):
     """Apply a dense RMSProp update to a resource variable.

     Reads the "learning_rate", "rho", "momentum" and "epsilon" hypers with
     `self._get_hyper` and casts each to the gradient's base dtype. Then it
     calls the centered or plain `training_ops` kernel according to
     `self._centered`.

     Args:
       grad: Dense gradient for `var`.
       var: Resource variable being updated.

     Returns:
       The op produced by the selected `training_ops` kernel.
     """
     ms_slot = self.get_slot(var, "rms")
     mom_slot = self.get_slot(var, "momentum")
     grad_dtype = grad.dtype.base_dtype
     learning_rate, rho, momentum, epsilon = [
         math_ops.cast(self._get_hyper(name), grad_dtype)
         for name in ("learning_rate", "rho", "momentum", "epsilon")]

     if not self._centered:
         return training_ops.resource_apply_rms_prop(
             var.handle, ms_slot.handle, mom_slot.handle,
             learning_rate, rho, momentum, epsilon, grad,
             use_locking=self._use_locking)

     mg_slot = self.get_slot(var, "mg")
     return training_ops.resource_apply_centered_rms_prop(
         var.handle, mg_slot.handle, ms_slot.handle, mom_slot.handle,
         learning_rate, rho, momentum, epsilon, grad,
         use_locking=self._use_locking)
Example #6
0
 def _resource_apply_dense(self, grad, var):
   """Apply a dense RMSProp update to a resource variable.

   Gathers the learning rate (from `self._decayed_lr`) and the "rho",
   "momentum" and "epsilon" hypers (from `self._get_hyper`) in `var`'s base
   dtype. It then calls the centered kernel when `self._centered` is true,
   otherwise the plain `training_ops` RMSProp kernel.

   Args:
     grad: Dense gradient for `var`.
     var: Resource variable being updated.

   Returns:
     The op produced by the selected `training_ops` kernel.
   """
   dtype = var.dtype.base_dtype
   lr_value = self._decayed_lr(dtype)
   rms_slot = self.get_slot(var, "rms")
   mom_slot = self.get_slot(var, "momentum")
   rho_value = self._get_hyper("rho", dtype)
   momentum_value = self._get_hyper("momentum", dtype)
   epsilon_value = self._get_hyper("epsilon", dtype)

   if not self._centered:
     return training_ops.resource_apply_rms_prop(
         var.handle, rms_slot.handle, mom_slot.handle,
         lr_value, rho_value, momentum_value, epsilon_value, grad,
         use_locking=self._use_locking)

   mg_slot = self.get_slot(var, "mg")
   return training_ops.resource_apply_centered_rms_prop(
       var.handle, mg_slot.handle, rms_slot.handle, mom_slot.handle,
       lr_value, rho_value, momentum_value, epsilon_value, grad,
       use_locking=self._use_locking)
Example #7
0
 def _resource_apply_dense(self, grad, var, state):
   """Apply a dense RMSProp update, reading slots and hypers from `state`.

   "learning_rate", "rho" and "momentum" are fetched with `state.get_hyper`
   in `var`'s base dtype. `self._centered` selects the centered or plain
   `training_ops` kernel.

   NOTE(review): epsilon is passed to the kernel as the literal `0`, and no
   "epsilon" hyperparameter is read here. Presumably it is accounted for
   elsewhere (e.g. the "rms" slot's initial value). Confirm before changing.

   Args:
     grad: Dense gradient for `var`.
     var: Resource variable being updated.
     state: Optimizer state object providing `get_slot` and `get_hyper`.

   Returns:
     The op produced by the selected `training_ops` kernel.
   """
   rms_slot = state.get_slot(var, "rms")
   mom_slot = state.get_slot(var, "momentum")
   mg_slot = None
   if self._centered:
     mg_slot = state.get_slot(var, "mg")

   dtype = var.dtype.base_dtype
   lr = state.get_hyper("learning_rate", dtype)
   rho = state.get_hyper("rho", dtype)
   momentum = state.get_hyper("momentum", dtype)

   if mg_slot is None or not self._centered:
     return training_ops.resource_apply_rms_prop(
         var.handle, rms_slot.handle, mom_slot.handle,
         lr, rho, momentum, 0, grad,
         use_locking=self._use_locking)
   return training_ops.resource_apply_centered_rms_prop(
       var.handle, mg_slot.handle, rms_slot.handle, mom_slot.handle,
       lr, rho, momentum, 0, grad,
       use_locking=self._use_locking)
Example #8
0
 def _resource_apply_dense(self, grad, var):
   """Apply a dense RMSProp update to a resource variable.

   Each of the "learning_rate", "rho", "momentum" and "epsilon" hypers is
   read with `self._get_hyper` and cast to the gradient's base dtype. The
   centered kernel is used when `self._centered` is true.

   Args:
     grad: Dense gradient for `var`.
     var: Resource variable being updated.

   Returns:
     The op produced by the selected `training_ops` kernel.
   """
   rms_slot = self.get_slot(var, "rms")
   mom_slot = self.get_slot(var, "momentum")
   dtype = grad.dtype.base_dtype
   hypers = {}
   for name in ("learning_rate", "rho", "momentum", "epsilon"):
     hypers[name] = math_ops.cast(self._get_hyper(name), dtype)
   lr = hypers["learning_rate"]
   rho = hypers["rho"]
   momentum = hypers["momentum"]
   epsilon = hypers["epsilon"]

   if self._centered:
     mg_slot = self.get_slot(var, "mg")
     return training_ops.resource_apply_centered_rms_prop(
         var.handle, mg_slot.handle, rms_slot.handle, mom_slot.handle,
         lr, rho, momentum, epsilon, grad,
         use_locking=self._use_locking)
   return training_ops.resource_apply_rms_prop(
       var.handle, rms_slot.handle, mom_slot.handle,
       lr, rho, momentum, epsilon, grad,
       use_locking=self._use_locking)
Example #9
0
    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply a dense RMSprop update to `var`.

        Coefficients are looked up in `apply_state` under the key
        `(var.device, var.dtype.base_dtype)`. When that entry is missing or
        falsy, they come from `self._fallback_apply_state`.

        When `self._momentum` is truthy, the `training_ops` RMSProp kernels
        (centered or plain, per `self.centered`) perform the update.
        Otherwise the update is built from primitive ops:

            rms <- rho * rms + (1 - rho) * grad**2
            mg  <- rho * mg + (1 - rho) * grad           (centered only)
            var <- var - lr * grad / (sqrt(denom) + epsilon)

        Here `denom` is `rms`. In the centered case it is `rms - mg**2`,
        clamped at zero for floating-point dtypes.

        Args:
          grad: Dense gradient for `var`.
          var: Resource variable to update.
          apply_state: Optional dict mapping `(device, dtype)` to a dict of
            coefficients; this method reads the keys "lr_t", "rho",
            "momentum", "epsilon" and "one_minus_rho".

        Returns:
          The kernel op on the momentum path, or the `.op` of the final
          assignment to `var` on the non-momentum path.
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))

        rms = self.get_slot(var, "rms")
        if self._momentum:
            mom = self.get_slot(var, "momentum")
            if self.centered:
                mg = self.get_slot(var, "mg")
                return training_ops.resource_apply_centered_rms_prop(
                    var.handle,
                    mg.handle,
                    rms.handle,
                    mom.handle,
                    coefficients["lr_t"],
                    coefficients["rho"],
                    coefficients["momentum"],
                    coefficients["epsilon"],
                    grad,
                    use_locking=self._use_locking)
            else:
                return training_ops.resource_apply_rms_prop(
                    var.handle,
                    rms.handle,
                    mom.handle,
                    coefficients["lr_t"],
                    coefficients["rho"],
                    coefficients["momentum"],
                    coefficients["epsilon"],
                    grad,
                    use_locking=self._use_locking)
        else:
            rms_t = (coefficients["rho"] * rms +
                     coefficients["one_minus_rho"] * math_ops.square(grad))
            rms_t = state_ops.assign(rms, rms_t, use_locking=self._use_locking)
            denom_t = rms_t
            if self.centered:
                mg = self.get_slot(var, "mg")
                mg_t = coefficients["rho"] * mg + coefficients[
                    "one_minus_rho"] * grad
                mg_t = state_ops.assign(mg,
                                        mg_t,
                                        use_locking=self._use_locking)
                denom_t = rms_t - math_ops.square(mg_t)
                if var_dtype.is_floating:
                    # rms_t - mg_t**2 is a variance-style estimate. It is
                    # non-negative in exact arithmetic when both slots start
                    # at zero (TODO confirm slot initialization). Rounding can
                    # still leave it slightly negative, e.g. for
                    # near-constant gradients, and sqrt() would then return
                    # NaN. Clamping at zero leaves non-negative values as is.
                    denom_t = math_ops.maximum(denom_t, 0.)
            var_t = var - coefficients["lr_t"] * grad / (
                math_ops.sqrt(denom_t) + coefficients["epsilon"])
            return state_ops.assign(var, var_t,
                                    use_locking=self._use_locking).op
Example #10
0
 def _resource_apply_dense(self, grad, var):
     """Apply a dense RMSprop update to `var`.

     Hyperparameters are taken in `var`'s base dtype. The learning rate comes
     from `self._decayed_lr`; "rho", "momentum" and "epsilon" come from
     `self._get_hyper`.

     When `self._momentum` is truthy, the `training_ops` RMSProp kernels
     (centered or plain, per `self.centered`) perform the update. Otherwise
     the update is built from primitive ops:

         rms <- rho * rms + (1 - rho) * grad**2
         mg  <- rho * mg + (1 - rho) * grad           (centered only)
         var <- var - lr * grad / (sqrt(denom) + epsilon)

     Here `denom` is `rms`. In the centered case it is `rms - mg**2`, clamped
     at zero for floating-point dtypes.

     Args:
       grad: Dense gradient for `var`.
       var: Resource variable to update.

     Returns:
       The kernel op on the momentum path, or the `.op` of the final
       assignment to `var` on the non-momentum path.
     """
     var_dtype = var.dtype.base_dtype
     lr_t = self._decayed_lr(var_dtype)
     rms = self.get_slot(var, "rms")
     rho = self._get_hyper("rho", var_dtype)
     momentum = self._get_hyper("momentum", var_dtype)
     epsilon = self._get_hyper("epsilon", var_dtype)
     if self._momentum:
         mom = self.get_slot(var, "momentum")
         if self.centered:
             mg = self.get_slot(var, "mg")
             return training_ops.resource_apply_centered_rms_prop(
                 var.handle,
                 mg.handle,
                 rms.handle,
                 mom.handle,
                 lr_t,
                 rho,
                 momentum,
                 epsilon,
                 grad,
                 use_locking=self._use_locking)
         else:
             return training_ops.resource_apply_rms_prop(
                 var.handle,
                 rms.handle,
                 mom.handle,
                 lr_t,
                 rho,
                 momentum,
                 epsilon,
                 grad,
                 use_locking=self._use_locking)
     else:
         rms_t = rho * rms + (1. - rho) * math_ops.square(grad)
         rms_t = state_ops.assign(rms, rms_t, use_locking=self._use_locking)
         denom_t = rms_t
         if self.centered:
             mg = self.get_slot(var, "mg")
             mg_t = rho * mg + (1. - rho) * grad
             mg_t = state_ops.assign(mg,
                                     mg_t,
                                     use_locking=self._use_locking)
             denom_t = rms_t - math_ops.square(mg_t)
             if var_dtype.is_floating:
                 # rms_t - mg_t**2 is a variance-style estimate. It is
                 # non-negative in exact arithmetic when both slots start at
                 # zero (TODO confirm slot initialization). Rounding can still
                 # leave it slightly negative, e.g. for near-constant
                 # gradients, and sqrt() would then return NaN. Clamping at
                 # zero leaves non-negative values unchanged.
                 denom_t = math_ops.maximum(denom_t, 0.)
         var_t = var - lr_t * grad / (math_ops.sqrt(denom_t) + epsilon)
         return state_ops.assign(var, var_t,
                                 use_locking=self._use_locking).op
Example #11
0
 def _resource_apply_dense(self, grad, var):
   """Apply a dense RMSprop update to `var`.

   Hyperparameters are taken in `var`'s base dtype. The learning rate comes
   from `self._decayed_lr`; "rho", "momentum" and "epsilon" come from
   `self._get_hyper`.

   When `self._momentum` is truthy, the `training_ops` RMSProp kernels
   (centered or plain, per `self.centered`) perform the update. Otherwise the
   update is built from primitive ops:

       rms <- rho * rms + (1 - rho) * grad**2
       mg  <- rho * mg + (1 - rho) * grad           (centered only)
       var <- var - lr * grad / (sqrt(denom) + epsilon)

   Here `denom` is `rms`. In the centered case it is `rms - mg**2`, clamped at
   zero for floating-point dtypes.

   Args:
     grad: Dense gradient for `var`.
     var: Resource variable to update.

   Returns:
     The kernel op on the momentum path, or the `.op` of the final assignment
     to `var` on the non-momentum path.
   """
   var_dtype = var.dtype.base_dtype
   lr_t = self._decayed_lr(var_dtype)
   rms = self.get_slot(var, "rms")
   rho = self._get_hyper("rho", var_dtype)
   momentum = self._get_hyper("momentum", var_dtype)
   epsilon = self._get_hyper("epsilon", var_dtype)
   if self._momentum:
     mom = self.get_slot(var, "momentum")
     if self.centered:
       mg = self.get_slot(var, "mg")
       return training_ops.resource_apply_centered_rms_prop(
           var.handle,
           mg.handle,
           rms.handle,
           mom.handle,
           lr_t,
           rho,
           momentum,
           epsilon,
           grad,
           use_locking=self._use_locking)
     else:
       return training_ops.resource_apply_rms_prop(
           var.handle,
           rms.handle,
           mom.handle,
           lr_t,
           rho,
           momentum,
           epsilon,
           grad,
           use_locking=self._use_locking)
   else:
     rms_t = rho * rms + (1. - rho) * math_ops.square(grad)
     rms_t = state_ops.assign(rms, rms_t, use_locking=self._use_locking)
     denom_t = rms_t
     if self.centered:
       mg = self.get_slot(var, "mg")
       mg_t = rho * mg + (1. - rho) * grad
       mg_t = state_ops.assign(mg, mg_t, use_locking=self._use_locking)
       denom_t = rms_t - math_ops.square(mg_t)
       if var_dtype.is_floating:
         # rms_t - mg_t**2 is a variance-style estimate. It is non-negative in
         # exact arithmetic when both slots start at zero (TODO confirm slot
         # initialization). Rounding can still leave it slightly negative,
         # e.g. for near-constant gradients, and sqrt() would then return
         # NaN. Clamping at zero leaves non-negative values unchanged.
         denom_t = math_ops.maximum(denom_t, 0.)
     var_t = var - lr_t * grad / (math_ops.sqrt(denom_t) + epsilon)
     return state_ops.assign(var, var_t, use_locking=self._use_locking).op