Example #1
 def _resource_apply_dense(self, grad, handle):
     var = handle
     grad = tf.to_float(grad)
     grad_squared = tf.square(grad) + self._epsilon1
     grad_squared_mean = tf.reduce_mean(grad_squared)
     decay_rate = self._call_if_callable(self._decay_rate)
     update_scale = self._call_if_callable(self._learning_rate)
     update_scale = tf.convert_to_tensor(update_scale, name="update_scale")
     update_scale = tf.cast(update_scale,
                            grad_squared_mean.dtype.base_dtype)
     old_val = var
     if var.dtype.base_dtype == tf.bfloat16:
         old_val = tf.to_float(self._parameter_encoding.decode(old_val))
     if self._multiply_by_parameter_scale:
         update_scale *= tf.to_float(self._parameter_scale(old_val))
     # HACK: Make things dependent on grad.
     # This confounds the XLA rewriter and keeps it from fusing computations
     # across different variables.  This fusion is bad for HBM usage, since
     # it causes the gradients to persist in memory.
     decay_rate += grad_squared_mean * 1e-30
     update_scale += grad_squared_mean * 1e-30
     # END HACK
     mixing_rate = 1.0 - decay_rate
     shape = var.get_shape().as_list()
     updates = []
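     # When the variable's shape allows it, track only row and column means of
     # the squared gradient instead of a full per-element second moment.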
     if self._should_use_factored_second_moment_estimate(shape):
         grad_squared_row_mean = tf.reduce_mean(grad_squared, -1)
         grad_squared_col_mean = tf.reduce_mean(grad_squared, -2)
         vr = self.get_slot(var, "vr")
         new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
         vc = self.get_slot(var, "vc")
         new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
         vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
         vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
         updates = [vr_update, vc_update]
         long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
         r_factor = tf.rsqrt(new_vr / long_term_mean)
         c_factor = tf.rsqrt(new_vc)
         x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(
             c_factor, -2)
     else:
         v = self.get_slot(var, "v")
         new_v = decay_rate * v + mixing_rate * grad_squared
         v_update = tf.assign(v, new_v, use_locking=self._use_locking)
         updates = [v_update]
         x = grad * tf.rsqrt(new_v)
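     # Optionally clip the update so its root-mean-square does not exceed
     # self._clipping_threshold.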
     if self._clipping_threshold is not None:
         clipping_denom = tf.maximum(
             1.0,
             reduce_rms(x) / self._clipping_threshold)
         x /= clipping_denom
     subtrahend = update_scale * x
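     # With beta1 set, smooth the update with an exponential moving average
     # (momentum) slot before applying it.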
     if self._beta1:
         m = self.get_slot(var, "m")
         new_m = self._beta1 * tf.to_float(m) + (1.0 -
                                                 self._beta1) * subtrahend
         subtrahend = new_m
         new_m = common_layers.cast_like(new_m, var)
         updates.append(tf.assign(m, new_m, use_locking=self._use_locking))
     new_val = tf.to_float(old_val) - subtrahend
     if var.dtype.base_dtype == tf.bfloat16:
         new_val = self._parameter_encoding.encode(new_val,
                                                   self._quantization_noise)
     if self._simulated_quantize_bits:
         new_val = quantization.simulated_quantize(
             var - subtrahend, self._simulated_quantize_bits,
             self._quantization_noise)
     new_val = tf.cast(new_val, var.dtype)
     var_update = tf.assign(var, new_val, use_locking=self._use_locking)
     updates = [var_update] + updates
     return tf.group(*updates)
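
The factored branch above stores only row and column means of the squared gradient rather than a full per-element accumulator. A minimal standalone NumPy sketch of that reconstruction for a 2-D gradient (the function name and signature are illustrative, not part of the source):

import numpy as np

def factored_rsqrt(grad, vr, vc, decay_rate, epsilon=1e-30):
    # Update the running row/column means of the squared gradient.
    grad_squared = np.square(grad) + epsilon
    new_vr = decay_rate * vr + (1.0 - decay_rate) * grad_squared.mean(axis=-1)
    new_vc = decay_rate * vc + (1.0 - decay_rate) * grad_squared.mean(axis=-2)
    # Normalizing the row statistics by their mean lets the outer product of
    # r_factor and c_factor approximate rsqrt of the full second moment.
    r_factor = 1.0 / np.sqrt(new_vr / new_vr.mean(axis=-1, keepdims=True))
    c_factor = 1.0 / np.sqrt(new_vc)
    return grad * r_factor[..., np.newaxis] * c_factor[..., np.newaxis, :], new_vr, new_vc
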
Example #2
 def _resource_apply_dense(self, grad, handle):
   var = handle
   grad = tf.to_float(grad)
   grad_squared = tf.square(grad) + self._epsilon1
   grad_squared_mean = tf.reduce_mean(grad_squared)
   decay_rate = self._decay_rate
   update_scale = self._learning_rate
   old_val = var
   if var.dtype.base_dtype == tf.bfloat16:
     old_val = tf.to_float(self._parameter_encoding.decode(old_val))
   if self._multiply_by_parameter_scale:
     update_scale *= tf.to_float(self._parameter_scale(old_val))
   # HACK: Make things dependent on grad.
   # This confounds the XLA rewriter and keeps it from fusing computations
    # across different variables.  This fusion is bad for HBM usage, since
   # it causes the gradients to persist in memory.
   decay_rate += grad_squared_mean * 1e-30
   update_scale += grad_squared_mean * 1e-30
   # END HACK
   mixing_rate = 1.0 - decay_rate
   shape = var.get_shape().as_list()
   updates = []
   if self._should_use_factored_second_moment_estimate(shape):
     grad_squared_row_mean = tf.reduce_mean(grad_squared, -1)
     grad_squared_col_mean = tf.reduce_mean(grad_squared, -2)
     vr = self.get_slot(var, "vr")
     new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
     vc = self.get_slot(var, "vc")
     new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
     vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
     vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
     updates = [vr_update, vc_update]
     long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
     r_factor = tf.rsqrt(new_vr / long_term_mean)
     c_factor = tf.rsqrt(new_vc)
     x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2)
   else:
     v = self.get_slot(var, "v")
     new_v = decay_rate * v + mixing_rate * grad_squared
     v_update = tf.assign(v, new_v, use_locking=self._use_locking)
     updates = [v_update]
     x = grad * tf.rsqrt(new_v)
   if self._clipping_threshold is not None:
     clipping_denom = tf.maximum(1.0, reduce_rms(x) / self._clipping_threshold)
     x /= clipping_denom
   subtrahend = update_scale * x
   if self._beta1:
     m = self.get_slot(var, "m")
     new_m = self._beta1 * tf.to_float(m) + (1.0 - self._beta1) * subtrahend
     subtrahend = new_m
     new_m = common_layers.cast_like(new_m, var)
     updates.append(tf.assign(m, new_m, use_locking=self._use_locking))
   new_val = tf.to_float(old_val) - subtrahend
   if var.dtype.base_dtype == tf.bfloat16:
     new_val = self._parameter_encoding.encode(
         new_val, self._quantization_noise)
   if self._simulated_quantize_bits:
     new_val = quantization.simulated_quantize(
         var - subtrahend, self._simulated_quantize_bits,
         self._quantization_noise)
   var_update = tf.assign(var, new_val, use_locking=self._use_locking)
   updates = [var_update] + updates
   return tf.group(*updates)
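
Both examples also call a reduce_rms helper that is not shown in this listing. A minimal sketch, assuming it computes the root mean square over all elements of a tensor (which is what the clipping logic above relies on):

import tensorflow as tf

def reduce_rms(x):
    # Root mean square over all elements; used above to compare the size of
    # the candidate update against self._clipping_threshold.
    return tf.sqrt(tf.reduce_mean(tf.square(x)))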