def _resource_apply_dense(self, grad, handle):
  var = handle
  grad = tf.to_float(grad)
  grad_squared = tf.square(grad) + self._epsilon1
  grad_squared_mean = tf.reduce_mean(grad_squared)
  decay_rate = self._call_if_callable(self._decay_rate)
  update_scale = self._call_if_callable(self._learning_rate)
  update_scale = tf.convert_to_tensor(update_scale, name="update_scale")
  update_scale = tf.cast(update_scale, grad_squared_mean.dtype.base_dtype)
  old_val = var
  if var.dtype.base_dtype == tf.bfloat16:
    old_val = tf.to_float(self._parameter_encoding.decode(old_val))
  if self._multiply_by_parameter_scale:
    update_scale *= tf.to_float(self._parameter_scale(old_val))
  # HACK: Make things dependent on grad.
  # This confounds the XLA rewriter and keeps it from fusing computations
  # across different variables. This fusion is bad for HBM usage, since
  # it causes the gradients to persist in memory.
  decay_rate += grad_squared_mean * 1e-30
  update_scale += grad_squared_mean * 1e-30
  # END HACK
  mixing_rate = 1.0 - decay_rate
  shape = var.get_shape().as_list()
  updates = []
  if self._should_use_factored_second_moment_estimate(shape):
    # Factored estimate: track only row and column means of the squared
    # gradients instead of a full per-parameter second moment.
    grad_squared_row_mean = tf.reduce_mean(grad_squared, -1)
    grad_squared_col_mean = tf.reduce_mean(grad_squared, -2)
    vr = self.get_slot(var, "vr")
    new_vr = decay_rate * vr + mixing_rate * grad_squared_row_mean
    vc = self.get_slot(var, "vc")
    new_vc = decay_rate * vc + mixing_rate * grad_squared_col_mean
    vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
    vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
    updates = [vr_update, vc_update]
    # Rank-1 reconstruction of the second moment: V ~= (vr vc^T) / mean(vr),
    # so rsqrt(V) factors into a per-row and a per-column scaling.
    long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
    r_factor = tf.rsqrt(new_vr / long_term_mean)
    c_factor = tf.rsqrt(new_vc)
    x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2)
  else:
    # Unfactored fallback (e.g. for vectors): full second-moment estimate.
    v = self.get_slot(var, "v")
    new_v = decay_rate * v + mixing_rate * grad_squared
    v_update = tf.assign(v, new_v, use_locking=self._use_locking)
    updates = [v_update]
    x = grad * tf.rsqrt(new_v)
  if self._clipping_threshold is not None:
    # Update clipping: rescale the update if its RMS exceeds the threshold.
    clipping_denom = tf.maximum(
        1.0, reduce_rms(x) / self._clipping_threshold)
    x /= clipping_denom
  subtrahend = update_scale * x
  if self._beta1:
    # Optional first-moment (momentum) accumulator, as in Adam. The float
    # version is used for the update; the slot is stored in var's dtype.
    m = self.get_slot(var, "m")
    new_m = self._beta1 * tf.to_float(m) + (1.0 - self._beta1) * subtrahend
    subtrahend = new_m
    new_m = common_layers.cast_like(new_m, var)
    updates.append(tf.assign(m, new_m, use_locking=self._use_locking))
  new_val = tf.to_float(old_val) - subtrahend
  if var.dtype.base_dtype == tf.bfloat16:
    new_val = self._parameter_encoding.encode(
        new_val, self._quantization_noise)
  if self._simulated_quantize_bits:
    new_val = quantization.simulated_quantize(
        var - subtrahend, self._simulated_quantize_bits,
        self._quantization_noise)
  new_val = tf.cast(new_val, var.dtype)
  var_update = tf.assign(var, new_val, use_locking=self._use_locking)
  updates = [var_update] + updates
  return tf.group(*updates)
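# ---------------------------------------------------------------------------
# Illustrative sketch (an addition, not part of the optimizer above): the
# factored branch stores only row and column means of the squared gradients
# and reconstructs a rank-1 second-moment estimate. A minimal NumPy rendition
# of the same math for a 2-D weight matrix; epsilon1=1e-30 is an assumed
# default for self._epsilon1, and the function name is hypothetical.
import numpy as np

def factored_update_direction(grad, vr, vc, decay_rate, epsilon1=1e-30):
  """One factored second-moment step for a 2-D gradient of shape (m, n).

  vr has shape (m,), vc has shape (n,); both are running means of the
  squared gradients over columns and rows respectively.
  """
  grad_squared = np.square(grad) + epsilon1
  mixing_rate = 1.0 - decay_rate
  new_vr = decay_rate * vr + mixing_rate * grad_squared.mean(axis=-1)
  new_vc = decay_rate * vc + mixing_rate * grad_squared.mean(axis=-2)
  # Rank-1 reconstruction: V ~= outer(new_vr, new_vc) / mean(new_vr),
  # so grad / sqrt(V) factors into per-row and per-column scalings.
  long_term_mean = new_vr.mean(axis=-1, keepdims=True)
  r_factor = 1.0 / np.sqrt(new_vr / long_term_mean)
  c_factor = 1.0 / np.sqrt(new_vc)
  x = grad * r_factor[:, None] * c_factor[None, :]
  return x, new_vr, new_vc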
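# ---------------------------------------------------------------------------
# Hedged usage sketch: how _resource_apply_dense is typically reached,
# assuming the surrounding class is tensor2tensor's AdafactorOptimizer (the
# import path, constructor arguments, and TF1-style graph setup below are
# assumptions, not shown in the snippet above).
import tensorflow.compat.v1 as tf
from tensor2tensor.utils.adafactor import AdafactorOptimizer

tf.disable_v2_behavior()
w = tf.get_variable("w", shape=[128, 64])
loss = tf.reduce_sum(tf.square(w))
opt = AdafactorOptimizer(multiply_by_parameter_scale=True, beta1=0.0)
# minimize() computes gradients and routes each variable's gradient through
# _resource_apply_dense (or the sparse variant) to build the update ops.
train_op = opt.minimize(loss)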