def _parameter_scale(self, var):
    """Estimate the scale of the parameters from their current values.

    The estimate is whatever `reduce_rms(var)` yields, bounded below by
    `self._epsilon2`. The lower bound gives a parameter a chance to escape
    0 if it was zero-initialized. (The original note quotes this minimum
    as 0.001; the actual value is whatever `self._epsilon2` holds.)

    Instead of using the value, we could impute the scale from the shape,
    as initializers do.

    Args:
      var: a variable or Tensor.

    Returns:
      a Scalar
    """
    measured_scale = reduce_rms(var)
    scale_floor = self._epsilon2
    return mtf.maximum(measured_scale, scale_floor)
def apply_grad(self, grad, var):
    """Build the update operations that apply `grad` to `var`.

    Slot variables (all zero-initialized and non-trainable) are created on
    `var.mesh`, named after `var.name`:
      * "_slot_vr" / "_slot_vc" when `self._factored_dims(var.shape)`
        returns a pair of dimensions (factored second-moment estimate),
      * "_slot_v" otherwise (full second-moment estimate),
      * "_slot_m" additionally when `self._beta1` is truthy (momentum).

    Args:
      grad: the gradient for `var`, or None.
      var: the variable to update.

    Returns:
      a list of assignment operations (slot updates followed by the update
      of `var` itself); an empty list if `grad` is None.
    """
    if grad is None:
        tf.logging.warning("Gradient is None for variable %s" % var)
        return []

    def _zero_slot(suffix, shape):
        # Every accumulator shares the same mesh, initializer and
        # trainability; only the name suffix and the shape differ.
        return mtf.get_variable(
            var.mesh, var.name + suffix, shape,
            initializer=tf.zeros_initializer(), trainable=False)

    grad = mtf.to_float(grad)
    factored = self._factored_dims(var.shape)

    # Create the slots before entering the variable scope below.
    if factored:
        d0, d1 = factored
        row_shape = var.shape - d0
        col_shape = var.shape - d1
        row_acc = _zero_slot("_slot_vr", row_shape)
        col_acc = _zero_slot("_slot_vc", col_shape)
    else:
        full_acc = _zero_slot("_slot_v", var.shape)
    if self._beta1:
        momentum = _zero_slot("_slot_m", var.shape)

    with tf.variable_scope(var.name + "/adafactor"):
        grad_sq = mtf.square(grad) + self._epsilon1
        decay = self._decay_rate
        current_value = mtf.to_float(var.value)
        # Optionally make the step size relative to the parameter scale.
        step_scale = (
            self._parameter_scale(current_value) * self._learning_rate
            if self._multiply_by_parameter_scale
            else self._learning_rate)
        mix = 1.0 - decay
        ops = []

        if factored:
            # Track row and column means of the squared gradient instead
            # of a full-size accumulator.
            row_mean_sq = mtf.reduce_mean(grad_sq, output_shape=row_shape)
            col_mean_sq = mtf.reduce_mean(grad_sq, output_shape=col_shape)
            new_row_acc = row_acc * decay + row_mean_sq * mix
            new_col_acc = col_acc * decay + col_mean_sq * mix
            ops.append(mtf.assign(row_acc, new_row_acc))
            ops.append(mtf.assign(col_acc, new_col_acc))
            row_acc_mean = mtf.reduce_mean(new_row_acc, reduced_dim=d1)
            row_factor = mtf.rsqrt(new_row_acc / row_acc_mean)
            col_factor = mtf.rsqrt(new_col_acc)
            step = grad * row_factor * col_factor
        else:
            new_full_acc = full_acc * decay + grad_sq * mix
            ops.append(mtf.assign(full_acc, new_full_acc))
            step = grad * mtf.rsqrt(new_full_acc)

        if self._clipping_threshold is not None:
            # Shrink the step only when reduce_rms(step) exceeds the
            # threshold; the divisor never drops below 1.0.
            rms_ratio = reduce_rms(step) / self._clipping_threshold
            step /= mtf.maximum(1.0, rms_ratio)

        delta = step * step_scale
        if self._beta1:
            # Exponential moving average of the scaled step.
            decayed_part = momentum * tf.constant(self._beta1)
            fresh_part = delta * tf.constant(1.0 - self._beta1)
            delta = decayed_part + fresh_part
            ops.append(mtf.assign(momentum, delta))

        # It is critical to use assign_sub instead of
        # mtf.assign(var - subtrahend) for the case of bfloat16 activations,
        # so as to avoid repeatedly rounding the slice value, which results
        # in poor quality.
        ops.append(mtf.assign_sub(var, delta))
        return ops