def _parameter_scale(self, var):
  """Estimate the scale of the parameters from the current values.

  We include a minimum value of 0.001 to give it a chance to escape 0
  if it was zero-initialized.

  Instead of using the value, we could impute the scale from the shape,
  as initializers do.

  Args:
    var: a variable or Tensor.

  Returns:
    a Scalar
  """
  return tf.math.maximum(
      py_utils.ReduceRms(var), tf.constant(self._epsilon2, var.dtype))
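
# Illustration only (not part of the optimizer): a minimal sketch of why the
# floor above matters, using the 0.001 value mentioned in the docstring as
# self._epsilon2. A zero-initialized variable has ReduceRms(var) == 0, so
# without the floor the per-parameter update scale (parameter_scale * lr)
# would also be 0 and the variable could never escape its initialization:
#
#   var = tf.zeros([4, 8])
#   scale = tf.math.maximum(py_utils.ReduceRms(var), tf.constant(1e-3))
#   # scale == 1e-3, so update_scale = scale * lr stays non-zero.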
def try_apply_dense(self, grad, var):
  assert grad is not None

  cond = tf.constant(True)
  is_finite_checks = []
  stats = {}

  grad_dtype = var.dtype  # TODO(lepikhin): add to params
  grad = tf.cast(grad, grad_dtype)
  factored_dims = self._factored_dims(var.shape.as_list())
  if factored_dims:
    vr = self.get_slot(var, 'vr')
    vc = self.get_slot(var, 'vc')
  else:
    v = self.get_slot(var, 'v')
  if self._beta1:
    m = self.get_slot(var, 'm')

  def _Upd(c, k, x):
    stats[k] = x
    is_finite_checks.append(tf.reduce_all(tf.math.is_finite(x)))
    return c

  with tf.variable_scope(var.name[:-2] + '/Adafactor'):
    grad_squared = tf.math.square(grad) + tf.cast(self._epsilon1, grad_dtype)
    cond = _Upd(cond, 'grad_squared', grad_squared)  # 0 (factored)
    decay_rate = tf.cast(self._decay_rate, var.dtype)
    old_val = tf.identity(var)  # TODO(lepikhin): introduce gradient dtype
    assert self._multiply_by_parameter_scale
    lr = GetLrValue(self._learning_rate)
    if self._multiply_by_parameter_scale:
      parameter_scale = self._parameter_scale(old_val)
      cond = _Upd(cond, 'parameter_scale', parameter_scale)  # 1 (factored)
      update_scale = self._parameter_scale(old_val) * tf.cast(lr, grad_dtype)
    else:
      update_scale = lr
    mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
    update_scale = tf.cast(update_scale, grad_dtype)
    if factored_dims:
      d0, d1 = factored_dims
      vr_axis, vc_axis = d0, d1
      grad_squared_row_mean = tf.reduce_mean(grad_squared, axis=vr_axis)
      grad_squared_col_mean = tf.reduce_mean(grad_squared, axis=vc_axis)
      # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
      new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
      # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
      new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
      cond = _Upd(cond, 'new_vr', new_vr)  # 2 (factored)
      cond = _Upd(cond, 'new_vc', new_vc)  # 3 (factored)
      # vr_update = _Wrap(tf.assign, vr, new_vr)
      # vc_update = _Wrap(tf.assign, vc, new_vc)
      # updates.extend([vr_update, vc_update])
      long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
      r_factor = tf.math.rsqrt(new_vr / long_term_mean)
      c_factor = tf.math.rsqrt(new_vc)
      mult = tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
          c_factor, vc_axis)
      cond = _Upd(cond, 'mult', mult)  # 4 (factored)
      x = grad * mult
    else:
      new_v = v * decay_rate + grad_squared * mixing_rate
      cond = _Upd(cond, 'new_v', new_v)
      # v_update = _Wrap(tf.assign, v, new_v)
      # updates.append(v_update)
      x = grad * tf.math.rsqrt(new_v)

    assert self._clipping_threshold is not None
    if self._clipping_threshold is not None:
      clipping_denom = tf.maximum(
          tf.constant(1.0, grad_dtype),
          py_utils.ReduceRms(x) /
          tf.constant(self._clipping_threshold, grad_dtype))
      x /= clipping_denom
    cond = _Upd(cond, 'x', x)
    subtrahend = x * update_scale
    if self._beta1:
      new_m = (
          m * tf.constant(self._beta1, dtype=grad_dtype) +
          subtrahend * tf.constant(1.0 - self._beta1, dtype=grad_dtype))
      subtrahend = new_m
      cond = _Upd(cond, 'new_m', new_m)
      # updates.append(_Wrap(tf.assign, m, new_m))

    # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
    # for the case of bfloat16 activations, so as to avoid repeatedly
    # rounding the slice value, which results in poor quality.
    cond = _Upd(cond, 'subtrahend', subtrahend)  # 5 (factored)
    # var_update = _Wrap(tf.assign_sub, var, subtrahend)
    # updates.append(var_update)
    return is_finite_checks, stats
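
# Illustration only (not part of the optimizer): a minimal NumPy sketch of the
# factored second-moment estimate used by the factored branch above. For a 2-D
# gradient, only per-row and per-column means of grad**2 are kept, and the full
# matrix is reconstructed as row_mean[i] * col_mean[j] / mean(row_mean); the
# reconstruction is exact when grad**2 is rank-1. Axis naming here follows the
# plain row/column picture and is not tied to the vr/vc slot convention above;
# the function name and epsilon1 default are illustrative only.
def _demo_factored_rsqrt_update(grad, epsilon1=1e-30):
  import numpy as np  # local import to keep the illustration self-contained
  g2 = np.square(grad) + epsilon1  # analogue of grad_squared; grad is 2-D
  row_mean = g2.mean(axis=1)       # one statistic per row
  col_mean = g2.mean(axis=0)       # one statistic per column
  # Factored 1/sqrt of the reconstructed second moment
  # v_hat[i, j] = row_mean[i] * col_mean[j] / row_mean.mean().
  r_factor = 1.0 / np.sqrt(row_mean / row_mean.mean())
  c_factor = 1.0 / np.sqrt(col_mean)
  return grad * r_factor[:, None] * c_factor[None, :]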
def _resource_apply_dense(self, grad, var):
  if grad is None:
    tf.logging.warning('Gradient is None for variable %s' % var.name)
    return []

  grad_dtype = var.dtype  # TODO(lepikhin): add to params
  grad = tf.cast(grad, grad_dtype)
  factored_dims = self._factored_dims(var.shape.as_list())
  if factored_dims:
    vr = self.get_slot(var, 'vr')
    vc = self.get_slot(var, 'vc')
  else:
    v = self.get_slot(var, 'v')
  if self._beta1:
    m = self.get_slot(var, 'm')

  cond = tf.constant(True)

  def _Upd(c, x):
    if not self._cond_is_finite:
      return c
    c = tf.math.logical_and(c, tf.reduce_all(tf.math.is_finite(x)))
    c = tf.math.logical_and(
        c, tf.reduce_all(tf.math.logical_not(tf.math.is_inf(x))))
    return c

  def _Wrap(fn, x, y):
    if not self._cond_is_finite:
      return fn(x, y)
    return tf.cond(cond, lambda: fn(x, y), lambda: x)

  with tf.variable_scope(var.name[:-2] + '/Adafactor'):
    grad_squared = tf.math.square(grad) + tf.cast(self._epsilon1, grad_dtype)
    cond = _Upd(cond, grad_squared)
    decay_rate = tf.cast(self._decay_rate, var.dtype)
    old_val = tf.identity(var)  # TODO(lepikhin): introduce gradient dtype
    lr = GetLrValue(self._learning_rate)
    if self._multiply_by_parameter_scale:
      update_scale = self._parameter_scale(old_val) * tf.cast(lr, grad_dtype)
    else:
      update_scale = lr
    mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
    update_scale = tf.cast(update_scale, grad_dtype)
    updates = []
    if factored_dims:
      d0, d1 = factored_dims
      vr_axis, vc_axis = d0, d1
      grad_squared_row_mean = tf.reduce_mean(grad_squared, axis=vr_axis)
      grad_squared_col_mean = tf.reduce_mean(grad_squared, axis=vc_axis)
      # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
      new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
      # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
      new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
      cond = _Upd(cond, new_vr)
      cond = _Upd(cond, new_vc)
      vr_update = _Wrap(tf.assign, vr, new_vr)
      vc_update = _Wrap(tf.assign, vc, new_vc)
      updates.extend([vr_update, vc_update])
      long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
      r_factor = tf.math.rsqrt(new_vr / long_term_mean)
      c_factor = tf.math.rsqrt(new_vc)
      x = grad * tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
          c_factor, vc_axis)
    else:
      new_v = v * decay_rate + grad_squared * mixing_rate
      cond = _Upd(cond, new_v)
      v_update = _Wrap(tf.assign, v, new_v)
      updates.append(v_update)
      x = grad * tf.math.rsqrt(new_v)
    if self._clipping_threshold is not None:
      clipping_denom = tf.maximum(
          tf.constant(1.0, grad_dtype),
          py_utils.ReduceRms(x) /
          tf.constant(self._clipping_threshold, grad_dtype))
      x /= clipping_denom
    subtrahend = x * update_scale
    if self._beta1:
      new_m = (
          m * tf.constant(self._beta1, dtype=grad_dtype) +
          subtrahend * tf.constant(1.0 - self._beta1, dtype=grad_dtype))
      subtrahend = new_m
      cond = _Upd(cond, new_m)
      updates.append(_Wrap(tf.assign, m, new_m))

    # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
    # for the case of bfloat16 activations, so as to avoid repeatedly
    # rounding the slice value, which results in poor quality.
    cond = _Upd(cond, subtrahend)
    var_update = _Wrap(tf.assign_sub, var, subtrahend)
    updates.append(var_update)
    return tf.group(*updates)
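
# Illustration only (not part of the optimizer): a minimal sketch of the
# _Upd/_Wrap guard above when self._cond_is_finite is enabled. Every computed
# statistic folds a finiteness check into one boolean, and each assign is
# wrapped in tf.cond so that the Adafactor step degrades to a no-op (returning
# the unchanged variable) if any intermediate value is NaN or Inf. The helper
# name below is illustrative only.
def _demo_guarded_assign(var, new_value):
  ok = tf.reduce_all(tf.math.is_finite(new_value))
  return tf.cond(ok, lambda: tf.assign(var, new_value), lambda: var)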