Example #1
    def _parameter_scale(self, var):
        """Estimate the scale of the parameters from the current values.

    We include a minimum value of 0.001 to give it a chance to escape 0
    if it was zero-initialized.

    Instead of using the value, we could impute the scale from the shape,
    as initializers do.

    Args:
      var: a variable or Tensor.

    Returns:
      a Scalar
    """
        return tf.math.maximum(py_utils.ReduceRms(var),
                               tf.constant(self._epsilon2, var.dtype))
Example #2
    def try_apply_dense(self, grad, var):
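        """Builds the Adafactor update for `var` without applying it.

        Unlike `_resource_apply_dense`, nothing is assigned to the variable
        or its slots: each intermediate tensor is recorded in `stats`, and a
        finiteness check for it is appended to `is_finite_checks`.

        Returns:
          An (is_finite_checks, stats) tuple.
        """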
        assert grad is not None

        cond = tf.constant(True)
        is_finite_checks = []
        stats = {}

        grad_dtype = var.dtype  # TODO(lepikhin): add to params
        grad = tf.cast(grad, grad_dtype)
        factored_dims = self._factored_dims(var.shape.as_list())
        if factored_dims:
            vr = self.get_slot(var, 'vr')
            vc = self.get_slot(var, 'vc')
        else:
            v = self.get_slot(var, 'v')
        if self._beta1:
            m = self.get_slot(var, 'm')

        def _Upd(c, k, x):
            stats[k] = x
            is_finite_checks.append(tf.reduce_all(tf.math.is_finite(x)))
            return c

        with tf.variable_scope(var.name[:-2] + '/Adafactor'):
            grad_squared = tf.math.square(grad) + tf.cast(
                self._epsilon1, grad_dtype)
            cond = _Upd(cond, 'grad_squared', grad_squared)  # 0 (factored)
            decay_rate = tf.cast(self._decay_rate, var.dtype)
            old_val = tf.identity(
                var)  # TODO(lepikhin): introduce gradient dtype
            assert self._multiply_by_parameter_scale
            lr = GetLrValue(self._learning_rate)
            if self._multiply_by_parameter_scale:
                parameter_scale = self._parameter_scale(old_val)
                cond = _Upd(cond, 'parameter_scale',
                            parameter_scale)  # 1 (factored)
                update_scale = parameter_scale * tf.cast(lr, grad_dtype)
            else:
                update_scale = lr
            mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
            update_scale = tf.cast(update_scale, grad_dtype)
            if factored_dims:
                d0, d1 = factored_dims
                vr_axis, vc_axis = d0, d1
                grad_squared_row_mean = tf.reduce_mean(grad_squared,
                                                       axis=vr_axis)
                grad_squared_col_mean = tf.reduce_mean(grad_squared,
                                                       axis=vc_axis)
                # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                cond = _Upd(cond, 'new_vr', new_vr)  # 2 (factored)
                cond = _Upd(cond, 'new_vc', new_vc)  # 3 (factored)
                # vr_update = _Wrap(tf.assign, vr, new_vr)
                # vc_update = _Wrap(tf.assign, vc, new_vc)
                # updates.extend([vr_update, vc_update])
                long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
                r_factor = tf.math.rsqrt(new_vr / long_term_mean)
                c_factor = tf.math.rsqrt(new_vc)
                mult = tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
                    c_factor, vc_axis)
                cond = _Upd(cond, 'mult', mult)  # 4 (factored)
                x = grad * mult
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                cond = _Upd(cond, 'new_v', new_v)
                # v_update = _Wrap(tf.assign, v, new_v)
                # updates.append(v_update)
                x = grad * tf.math.rsqrt(new_v)

            assert self._clipping_threshold is not None

            if self._clipping_threshold is not None:
                clipping_denom = tf.maximum(
                    tf.constant(1.0, grad_dtype),
                    py_utils.ReduceRms(x) /
                    tf.constant(self._clipping_threshold, grad_dtype))
                x /= clipping_denom
            cond = _Upd(cond, 'x', x)
            subtrahend = x * update_scale
            if self._beta1:
                new_m = (m * tf.constant(self._beta1, dtype=grad_dtype) +
                         subtrahend *
                         tf.constant(1.0 - self._beta1, dtype=grad_dtype))
                subtrahend = new_m
                cond = _Upd(cond, 'new_m', new_m)
                # updates.append(_Wrap(tf.assign, m, new_m))

            # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
            #  for the case of bfloat16 activations, so as to avoid repeatedly
            #  rounding the slice value, which results in poor quality.
            cond = _Upd(cond, 'subtrahend', subtrahend)  # 5 (factored)

            # var_update = _Wrap(tf.assign_sub, var, subtrahend)
            # updates.append(var_update)

            return is_finite_checks, stats
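For factored variables, the code above never materializes the full second-moment matrix: it keeps row and column means of the squared gradient and rebuilds the rsqrt preconditioner from them (the r_factor / c_factor product). The following is a minimal standalone sketch of that reconstruction for a 2-D weight, with a hypothetical name (`factored_rsqrt_multiplier`) and the decay/epsilon bookkeeping omitted; the axis convention here is the usual row/column one and may differ from what `_factored_dims` picks in the actual optimizer.

import tensorflow as tf

def factored_rsqrt_multiplier(vr, vc):
    # Implied full matrix: V[i, j] ~ vr[i] * vc[j] / mean(vr), so rsqrt(V)
    # factors into a per-row and a per-column term, matching the
    # r_factor / c_factor construction in the code above.
    r_factor = tf.math.rsqrt(vr / tf.reduce_mean(vr))
    c_factor = tf.math.rsqrt(vc)
    return tf.expand_dims(r_factor, 1) * tf.expand_dims(c_factor, 0)

grad = tf.random.normal([16, 32])
grad_squared = tf.square(grad) + 1e-30
vr = tf.reduce_mean(grad_squared, axis=1)  # per-row mean, shape [16]
vc = tf.reduce_mean(grad_squared, axis=0)  # per-column mean, shape [32]
x = grad * factored_rsqrt_multiplier(vr, vc)  # preconditioned update, [16, 32]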
Example #3
    def _resource_apply_dense(self, grad, var):
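        """Applies one Adafactor step to `var` for a dense gradient `grad`.

        Updates the factored (`vr`/`vc`) or unfactored (`v`) second-moment
        slots, the optional momentum slot `m`, and the variable itself via
        `assign_sub`. If `self._cond_is_finite` is set, every assignment is
        gated on all intermediate tensors being finite.

        Returns:
          A grouped op containing the slot and variable updates.
        """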
        if grad is None:
            tf.logging.warning('Gradient is None for variable %s' % var.name)
            return []

        grad_dtype = var.dtype  # TODO(lepikhin): add to params
        grad = tf.cast(grad, grad_dtype)
        factored_dims = self._factored_dims(var.shape.as_list())
        if factored_dims:
            vr = self.get_slot(var, 'vr')
            vc = self.get_slot(var, 'vc')
        else:
            v = self.get_slot(var, 'v')
        if self._beta1:
            m = self.get_slot(var, 'm')

        cond = tf.constant(True)

        def _Upd(c, x):
            if not self._cond_is_finite:
                return c
            c = tf.math.logical_and(c, tf.reduce_all(tf.math.is_finite(x)))
            c = tf.math.logical_and(
                c, tf.reduce_all(tf.math.logical_not(tf.math.is_inf(x))))
            return c

        def _Wrap(fn, x, y):
            if not self._cond_is_finite:
                return fn(x, y)
            return tf.cond(cond, lambda: fn(x, y), lambda: x)

        with tf.variable_scope(var.name[:-2] + '/Adafactor'):
            grad_squared = tf.math.square(grad) + tf.cast(
                self._epsilon1, grad_dtype)
            cond = _Upd(cond, grad_squared)
            decay_rate = tf.cast(self._decay_rate, var.dtype)
            old_val = tf.identity(
                var)  # TODO(lepikhin): introduce gradient dtype
            lr = GetLrValue(self._learning_rate)
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(old_val) * tf.cast(
                    lr, grad_dtype)
            else:
                update_scale = lr
            mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
            update_scale = tf.cast(update_scale, grad_dtype)
            updates = []
            if factored_dims:
                d0, d1 = factored_dims
                vr_axis, vc_axis = d0, d1
                grad_squared_row_mean = tf.reduce_mean(grad_squared,
                                                       axis=vr_axis)
                grad_squared_col_mean = tf.reduce_mean(grad_squared,
                                                       axis=vc_axis)
                # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                cond = _Upd(cond, new_vr)
                cond = _Upd(cond, new_vc)
                vr_update = _Wrap(tf.assign, vr, new_vr)
                vc_update = _Wrap(tf.assign, vc, new_vc)
                updates.extend([vr_update, vc_update])
                long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
                r_factor = tf.math.rsqrt(new_vr / long_term_mean)
                c_factor = tf.math.rsqrt(new_vc)
                x = grad * tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
                    c_factor, vc_axis)
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                cond = _Upd(cond, new_v)
                v_update = _Wrap(tf.assign, v, new_v)
                updates.append(v_update)
                x = grad * tf.math.rsqrt(new_v)
            if self._clipping_threshold is not None:
                clipping_denom = tf.maximum(
                    tf.constant(1.0, grad_dtype),
                    py_utils.ReduceRms(x) /
                    tf.constant(self._clipping_threshold, grad_dtype))
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                new_m = (m * tf.constant(self._beta1, dtype=grad_dtype) +
                         subtrahend *
                         tf.constant(1.0 - self._beta1, dtype=grad_dtype))
                subtrahend = new_m
                cond = _Upd(cond, new_m)
                updates.append(_Wrap(tf.assign, m, new_m))
            # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
            #  for the case of bfloat16 activations, so as to avoid repeatedly
            #  rounding the slice value, which results in poor quality.
            cond = _Upd(cond, subtrahend)
            var_update = _Wrap(tf.assign_sub, var, subtrahend)
            updates.append(var_update)
            return tf.group(*updates)
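When `self._cond_is_finite` is set, `_Upd` folds a finiteness check over every intermediate tensor and `_Wrap` gates each assignment on that predicate, so a single NaN/Inf in the step skips the whole update instead of corrupting the slots and the variable. Below is a minimal standalone sketch of that gating pattern in TF2 eager style; `guarded_assign_sub` is a hypothetical name, and the code above uses `tf.assign_sub` under `tf.cond`, but the idea is the same.

import tensorflow as tf

def guarded_assign_sub(var, delta, cond):
    # Apply the update only if cond is True; otherwise return the
    # variable's current value and leave it untouched.
    return tf.cond(cond,
                   lambda: var.assign_sub(delta),
                   lambda: var.read_value())

var = tf.Variable([1.0, 2.0])
delta = tf.constant([float('nan'), 0.1])
cond = tf.reduce_all(tf.math.is_finite(delta))  # False: delta has a NaN
guarded_assign_sub(var, delta, cond)
print(var.numpy())  # unchanged: [1. 2.]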