Code example #1
File: layers.py  Project: qixiuai/mesh
def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
  """Layer normalization over dimension dim.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension
    epsilon: a floating point number
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name + "/layer_norm"):
    scale = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "layer_norm_bias",
        mtf.Shape([dim]),
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)
    reduced_shape = x.shape - dim
    mean = mtf.reduce_mean(x, output_shape=reduced_shape)
    variance = mtf.reduce_mean(mtf.square(x - mean), output_shape=reduced_shape)
    norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
    return norm_x * scale + bias
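
A minimal usage sketch for layer_norm above. The mesh setup, the disable_v2_behavior() call, and the dimension names and sizes are illustrative assumptions, not part of the original example; lowering the graph for actual execution is omitted.

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

tf.disable_v2_behavior()  # the snippets above assume TF1-style graph mode

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

# Hypothetical dimensions; any names and sizes would do.
batch_dim = mtf.Dimension("batch", 4)
model_dim = mtf.Dimension("model", 16)

# Wrap an ordinary tf.Tensor as an mtf.Tensor on the mesh.
x = mtf.import_tf_tensor(
    mesh, tf.random.normal([4, 16]),
    shape=mtf.Shape([batch_dim, model_dim]))

# Normalize over the model dimension; the result has the same shape as x.
y = layer_norm(x, model_dim)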
Code example #2
File: layers.py  Project: qixiuai/mesh
def compress_mean(x, dim, compression_factor):
  """Compress by taking group means.

  Args:
    x: a Tensor
    dim: a dimension in x.shape
    compression_factor: an integer

  Returns:
    a Tensor
  """
  dims = x.shape.dims
  pos = dims.index(dim)
  compressed_dim = mtf.Dimension(dim.name, dim.size // compression_factor)
  compression_factor_dim = mtf.Dimension(
      "compression_factor", compression_factor)
  new_shape = (
      dims[:pos] + [compressed_dim, compression_factor_dim] + dims[pos + 1:])
  x = mtf.reshape(x, new_shape)
  x = mtf.reduce_mean(x, reduced_dim=compression_factor_dim)
  return x
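
For intuition: with compression_factor=4, a length-16 dimension is reshaped into an outer length-4 dimension (keeping the original name) and an inner "compression_factor" dimension of size 4, and the inner dimension is then averaged away. A short sketch reusing the graph and mesh from the layer_norm sketch above; the dimension names and sizes are assumptions.

length_dim = mtf.Dimension("length", 16)
hidden_dim = mtf.Dimension("hidden", 8)

x = mtf.import_tf_tensor(
    mesh, tf.random.normal([16, 8]),
    shape=mtf.Shape([length_dim, hidden_dim]))

# Result shape: [Dimension("length", 4), Dimension("hidden", 8)].
compressed = compress_mean(x, length_dim, compression_factor=4)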
Code example #3
def reduce_rms(x):
    return mtf.sqrt(mtf.reduce_mean(mtf.square(x)))
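
This reduce_rms reduces over every dimension and returns a scalar mtf.Tensor: the root-mean-square of all elements, which the Adafactor code in example #4 uses as its gradient-clipping denominator. A plain-NumPy analogue, for intuition only:

import numpy as np

def reduce_rms_np(x):
    # Same computation as above, but on a NumPy array.
    return np.sqrt(np.mean(np.square(x)))

print(reduce_rms_np(np.array([3.0, 4.0])))  # sqrt((9 + 16) / 2) ≈ 3.536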
Code example #4
    def apply_grad(self, grad, var):
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var)
            return []
        # create slots
        grad = mtf.to_float(grad)
        factored_dims = self._factored_dims(var.shape)
        if factored_dims:
            d0, d1 = factored_dims
            vr_shape = var.shape - d0
            vc_shape = var.shape - d1
            vr = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vr",
                                  vr_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
            vc = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vc",
                                  vc_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
        else:
            v = mtf.get_variable(var.mesh,
                                 var.name + "_slot_v",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)
        if self._beta1:
            m = mtf.get_variable(var.mesh,
                                 var.name + "_slot_m",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)

        with tf.variable_scope(var.name + "/adafactor"):
            grad_squared = mtf.square(grad) + self._epsilon1
            decay_rate = self._decay_rate
            old_val = mtf.to_float(var.value)
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(
                    old_val) * self._learning_rate
            else:
                update_scale = self._learning_rate
            mixing_rate = 1.0 - decay_rate
            updates = []
            if factored_dims:
                grad_squared_row_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vr_shape)
                grad_squared_col_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vc_shape)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                vr_update = mtf.assign(vr, new_vr)
                vc_update = mtf.assign(vc, new_vc)
                updates.extend([vr_update, vc_update])
                long_term_mean = mtf.reduce_mean(new_vr, reduced_dim=d1)
                r_factor = mtf.rsqrt(new_vr / long_term_mean)
                c_factor = mtf.rsqrt(new_vc)
                x = grad * r_factor * c_factor
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                v_update = mtf.assign(v, new_v)
                updates.append(v_update)
                x = grad * mtf.rsqrt(new_v)
            if self._clipping_threshold is not None:
                clipping_denom = mtf.maximum(
                    1.0,
                    reduce_rms(x) / self._clipping_threshold)
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                new_m = (m * tf.constant(self._beta1) +
                         subtrahend * tf.constant(1.0 - self._beta1))
                subtrahend = new_m
                updates.append(mtf.assign(m, new_m))
            # It is critical to use assign_sub instead of mtf.assign(var - subtrahend)
            #  for the case of bfloat16 activations, so as to avoid repeatedly
            #  rounding the slice value, which results in poor quality.
            var_update = mtf.assign_sub(var, subtrahend)
            updates.append(var_update)
            return updates
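
The factored branch above is the heart of Adafactor: instead of storing a full second-moment accumulator with the shape of the variable, it keeps only running means of grad^2 over each of the two factored dimensions (vr and vc) and rescales the gradient by the resulting row and column factors. Below is a plain-NumPy analogue of just that branch, for intuition; the axis conventions, decay value, and function name are assumptions, not the original API.

import numpy as np

def factored_update(grad, vr, vc, decay_rate=0.8, epsilon1=1e-30):
    # grad: (rows, cols); d0 is the rows axis and d1 the cols axis, as in the
    # mtf code above, so vr has shape (cols,) and vc has shape (rows,).
    grad_squared = np.square(grad) + epsilon1
    mixing_rate = 1.0 - decay_rate
    new_vr = vr * decay_rate + grad_squared.mean(axis=0) * mixing_rate
    new_vc = vc * decay_rate + grad_squared.mean(axis=1) * mixing_rate
    # Plays the role of mtf.reduce_mean(new_vr, reduced_dim=d1).
    long_term_mean = new_vr.mean()
    r_factor = 1.0 / np.sqrt(new_vr / long_term_mean)  # per-column scale
    c_factor = 1.0 / np.sqrt(new_vc)                    # per-row scale
    x = grad * r_factor[None, :] * c_factor[:, None]
    return x, new_vr, new_vc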
Code example #5
def reduce_rms(x, **kwargs):
  return mtf.sqrt(mtf.reduce_mean(mtf.square(x), **kwargs))
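
Unlike example #3, forwarding **kwargs lets the caller restrict the reduction instead of collapsing to a scalar. For example (tensor and dimension names are assumed):

# RMS taken only over hidden_dim; all other dimensions are kept.
rms_per_position = reduce_rms(activations, reduced_dim=hidden_dim)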
Code example #6
File: layers.py  Project: qixiuai/mesh
def batch_norm(x, is_training, momentum, epsilon=1e-9,
               init_zero=False, name=None):
  """Batch normalization.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number.
    init_zero: a boolean, whether to initialize scale with 0's or 1's.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name, default_name="batch_norm", values=[x]):
    if init_zero:
      gamma_initializer = tf.zeros_initializer()
    else:
      gamma_initializer = tf.ones_initializer()

    norm_dim = x.shape.dims[0:3]
    reduced_shape = x.shape - norm_dim

    scale = mtf.get_variable(
        x.mesh,
        "batch_norm_scale",
        reduced_shape,
        initializer=gamma_initializer,
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "batch_norm_bias",
        reduced_shape,
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)

    moving_mean = mtf.get_variable(
        x.mesh, "moving_mean", reduced_shape,
        initializer=tf.random_normal_initializer(stddev=1.0),
        activation_dtype=x.dtype,
        trainable=False)
    moving_variance = mtf.get_variable(
        x.mesh, "moving_variance",
        reduced_shape, initializer=tf.ones_initializer(),
        activation_dtype=x.dtype,
        trainable=False)

    # At training time, calculate mean and variance and normalize across batch
    # dim.
    if is_training:
      mean = mtf.reduce_mean(x, output_shape=reduced_shape)
      variance = mtf.reduce_mean(
          mtf.square(x - mean), output_shape=reduced_shape)

      norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

      # Update running mean and running variance.
      moving_mean = mtf.assign(
          moving_mean, momentum * moving_mean + (1-momentum) * mean)
      moving_variance = mtf.assign(
          moving_variance,
          momentum * moving_variance + (1 - momentum) * variance)
    else:
      # At eval and test time, use the running mean and variance.
      norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)

    return (norm_x * scale) + bias
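
A minimal usage sketch for batch_norm, again reusing the graph and mesh from the layer_norm sketch. The first three dimensions (x.shape.dims[0:3]) are treated as the normalization dimensions, so with NHWC-style input the scale, bias, and moving statistics all have shape [channels]. The dimension names, sizes, and momentum value here are assumptions.

batch_dim = mtf.Dimension("batch", 8)
height_dim = mtf.Dimension("height", 4)
width_dim = mtf.Dimension("width", 4)
channels_dim = mtf.Dimension("channels", 16)

images = mtf.import_tf_tensor(
    mesh, tf.random.normal([8, 4, 4, 16]),
    shape=mtf.Shape([batch_dim, height_dim, width_dim, channels_dim]))

# Training-time path: normalizes over (batch, height, width) and updates the
# moving mean/variance; the output has the same shape as `images`.
normed = batch_norm(images, is_training=True, momentum=0.999)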