Example #1
0
 def _resource_apply_dense(self, grad, var):
   """Applies `grad` to `var` via the fused FTRL resource kernel.

   Dispatches to `resource_apply_ftrl_v2` when an L2-shrinkage strength is
   configured, otherwise to the plain `resource_apply_ftrl` op.
   """
   dtype = var.dtype.base_dtype
   accum_slot = self.get_slot(var, "accum")
   linear_slot = self.get_slot(var, "linear")
   # Hyperparameter tensors must match the variable's base dtype.
   lr = math_ops.cast(self._learning_rate_tensor, dtype)
   l1 = math_ops.cast(self._l1_regularization_strength_tensor, dtype)
   l2 = math_ops.cast(self._l2_regularization_strength_tensor, dtype)
   lr_power = math_ops.cast(self._learning_rate_power_tensor, dtype)
   if self._l2_shrinkage_regularization_strength > 0.0:
     # The v2 kernel additionally consumes the L2 shrinkage coefficient.
     l2_shrinkage = math_ops.cast(
         self._l2_shrinkage_regularization_strength_tensor, dtype)
     return training_ops.resource_apply_ftrl_v2(
         var.handle, accum_slot.handle, linear_slot.handle, grad, lr, l1,
         l2, l2_shrinkage, lr_power, use_locking=self._use_locking)
   return training_ops.resource_apply_ftrl(
       var.handle, accum_slot.handle, linear_slot.handle, grad, lr, l1, l2,
       lr_power, use_locking=self._use_locking)
Example #2
0
 def _resource_apply_dense(self, grad, var):
     """Dense FTRL update for `var`, choosing the v1 or v2 fused kernel."""
     dtype = var.dtype.base_dtype
     accum = self.get_slot(var, 'accumulator')
     linear = self.get_slot(var, 'linear')
     # Resolve per-step hyperparameters as tensors of the variable's dtype.
     lr = self._decayed_lr(dtype)
     lr_power = self._get_hyper('learning_rate_power', dtype)
     l1 = self._get_hyper('l1_regularization_strength', dtype)
     l2 = self._get_hyper('l2_regularization_strength', dtype)
     if not self._l2_shrinkage_regularization_strength > 0.0:
         return training_ops.resource_apply_ftrl(
             var.handle, accum.handle, linear.handle, grad, lr, l1, l2,
             lr_power, use_locking=self._use_locking)
     # The v2 kernel additionally takes the L2 shrinkage coefficient.
     shrinkage = math_ops.cast(
         self._l2_shrinkage_regularization_strength, dtype)
     return training_ops.resource_apply_ftrl_v2(
         var.handle, accum.handle, linear.handle, grad, lr, l1, l2,
         shrinkage, lr_power, use_locking=self._use_locking)
Example #3
0
    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Applies the dense FTRL update to `var` using cached coefficients.

        Coefficients are looked up from `apply_state` by (device, dtype);
        when absent, `_fallback_apply_state` recomputes them.
        """
        device, dtype = var.device, var.dtype.base_dtype
        state = apply_state or {}
        coeffs = (state.get((device, dtype))
                  or self._fallback_apply_state(device, dtype))

        accum = self.get_slot(var, 'accumulator')
        linear = self.get_slot(var, 'linear')

        # Argument prefix shared by the v1 and v2 fused kernels.
        common = (var.handle, accum.handle, linear.handle, grad,
                  coeffs['lr_t'],
                  coeffs['l1_regularization_strength'],
                  coeffs['l2_regularization_strength'])

        if self._l2_shrinkage_regularization_strength > 0.0:
            # The v2 kernel also consumes the L2 shrinkage coefficient.
            args = common + (coeffs['l2_shrinkage_regularization_strength'],
                             coeffs['learning_rate_power'])
            return training_ops.resource_apply_ftrl_v2(
                *args, use_locking=self._use_locking)
        args = common + (coeffs['learning_rate_power'],)
        return training_ops.resource_apply_ftrl(
            *args, use_locking=self._use_locking)
Example #4
0
 def _resource_apply_dense(self, grad, var):
   """Performs the FTRL dense update, selecting the v1 or v2 fused kernel."""
   base_dtype = var.dtype.base_dtype

   def _cast(tensor):
     # Hyperparameter tensors are cast to the variable's base dtype.
     return math_ops.cast(tensor, base_dtype)

   handles = (var.handle,
              self.get_slot(var, "accum").handle,
              self.get_slot(var, "linear").handle)
   if self._l2_shrinkage_regularization_strength <= 0.0:
     apply_op = training_ops.resource_apply_ftrl
     hyper_args = (_cast(self._learning_rate_tensor),
                   _cast(self._l1_regularization_strength_tensor),
                   _cast(self._l2_regularization_strength_tensor),
                   _cast(self._learning_rate_power_tensor))
   else:
     # The v2 kernel additionally takes the L2 shrinkage strength.
     apply_op = training_ops.resource_apply_ftrl_v2
     hyper_args = (_cast(self._learning_rate_tensor),
                   _cast(self._l1_regularization_strength_tensor),
                   _cast(self._l2_regularization_strength_tensor),
                   _cast(self._l2_shrinkage_regularization_strength_tensor),
                   _cast(self._learning_rate_power_tensor))
   return apply_op(*(handles + (grad,) + hyper_args),
                   use_locking=self._use_locking)
Example #5
0
 def _resource_apply_dense(self, grad, var):
   """Applies `grad` to `var` with the fused FTRL resource op."""
   dtype = var.dtype.base_dtype
   # Per-step hyperparameters resolved in the variable's dtype.
   lr_t = self._decayed_lr(dtype)
   power = self._get_hyper('learning_rate_power', dtype)
   l1_t = self._get_hyper('l1_regularization_strength', dtype)
   l2_t = self._get_hyper('l2_regularization_strength', dtype)
   accum = self.get_slot(var, 'accumulator')
   linear = self.get_slot(var, 'linear')
   shrinkage_active = self._l2_shrinkage_regularization_strength > 0.0
   if shrinkage_active:
     # v2 kernel takes one extra argument: the L2 shrinkage strength.
     shrinkage_t = math_ops.cast(
         self._l2_shrinkage_regularization_strength, dtype)
     return training_ops.resource_apply_ftrl_v2(
         var.handle, accum.handle, linear.handle, grad, lr_t, l1_t, l2_t,
         shrinkage_t, power, use_locking=self._use_locking)
   return training_ops.resource_apply_ftrl(
       var.handle, accum.handle, linear.handle, grad, lr_t, l1_t, l2_t,
       power, use_locking=self._use_locking)
 def _eval(self, var, accum, linear, grad, lr, l1, l2, l2_shrinkage=0,
           lr_power=1, multiply_linear_by_lr=False):
   """Runs one fused FTRL step; returns updated (var, accum, linear) arrays.

   A nonzero `l2_shrinkage` selects the v2 kernel. `multiply_linear_by_lr`
   is only combinable with the v1 kernel (asserted below).
   """
   dtype = np.float32
   var = np.array(var, dtype=dtype)
   accum = np.array(accum, dtype=dtype)
   linear = np.array(linear, dtype=dtype)
   grad = np.array(grad, dtype=dtype)
   # Decide on the kernel before l2_shrinkage is rebound to a tensor below.
   use_v2 = bool(l2_shrinkage)
   with self.session() as session:
     with self.test_scope():
       # Scalar hyperparameters become constant tensors of the same dtype.
       lr, l1, l2, l2_shrinkage, lr_power = (
           constant_op.constant(s, dtype=dtype)
           for s in (lr, l1, l2, l2_shrinkage, lr_power))
       v_var = resource_variable_ops.ResourceVariable(var, dtype=dtype)
       v_accum = resource_variable_ops.ResourceVariable(accum, dtype=dtype)
       v_linear = resource_variable_ops.ResourceVariable(linear, dtype=dtype)
       for v in (v_var, v_accum, v_linear):
         session.run(v.create)
       # multiply_linear_by_lr is unsupported by the v2 kernel.
       assert not (use_v2 and multiply_linear_by_lr)
       if use_v2:
         update = training_ops.resource_apply_ftrl_v2(
             v_var.handle, v_accum.handle, v_linear.handle,
             grad, lr, l1, l2, l2_shrinkage, lr_power,
             multiply_linear_by_lr=multiply_linear_by_lr)
       else:
         update = training_ops.resource_apply_ftrl(
             v_var.handle, v_accum.handle, v_linear.handle,
             grad, lr, l1, l2, lr_power,
             multiply_linear_by_lr=multiply_linear_by_lr)
       session.run(update)
       return (v_var.read_value().eval().reshape(var.shape),
               v_accum.read_value().eval().reshape(accum.shape),
               v_linear.read_value().eval().reshape(linear.shape))