Code example #1
File: layers.py Project: tspannhw/mesh
def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
    """Layer normalization over dimension dim.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension
    epsilon: a floating point number
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
    with tf.variable_scope(name + "/layer_norm"):
        scale = mtf.get_variable(x.mesh,
                                 "layer_norm_scale",
                                 mtf.Shape([dim]),
                                 initializer=tf.ones_initializer(),
                                 activation_dtype=x.dtype)
        bias = mtf.get_variable(x.mesh,
                                "layer_norm_bias",
                                mtf.Shape([dim]),
                                initializer=tf.zeros_initializer(),
                                activation_dtype=x.dtype)
        reduced_shape = x.shape - dim
        mean = mtf.reduce_mean(x, output_shape=reduced_shape)
        variance = mtf.reduce_mean(mtf.square(x - mean),
                                   output_shape=reduced_shape)
        norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
        return norm_x * scale + bias
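A minimal usage sketch for layer_norm, assuming mesh_tensorflow is installed and the function above is in scope; the single-device placement mesh and the dimension names and sizes are illustrative, not part of the original example:

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf
from mesh_tensorflow import placement_mesh_impl

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
batch = mtf.Dimension("batch", 4)
d_model = mtf.Dimension("d_model", 8)

x = mtf.random_uniform(mesh, mtf.Shape([batch, d_model]))
y = layer_norm(x, d_model)  # normalized over d_model; shape unchanged

# Lower the mtf graph onto a single unsplit device to get a tf.Tensor.
mesh_impl = placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_y = lowering.export_to_tf_tensor(y)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(lowering.copy_masters_to_slices())
    print(sess.run(tf_y).shape)  # (4, 8)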
Code example #2
File: layers.py Project: tspannhw/mesh
def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
    """Batch normalization.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
    with tf.variable_scope(name, default_name="batch_norm", values=[x]):
        batch_dim = x.shape.dims[0]
        reduced_shape = x.shape - batch_dim
        scale = mtf.get_variable(x.mesh,
                                 "batch_norm_scale",
                                 mtf.Shape([batch_dim]),
                                 initializer=tf.ones_initializer(),
                                 activation_dtype=x.dtype)
        bias = mtf.get_variable(x.mesh,
                                "batch_norm_bias",
                                mtf.Shape([batch_dim]),
                                initializer=tf.zeros_initializer(),
                                activation_dtype=x.dtype)

        moving_mean = mtf.get_variable(
            x.mesh,
            "moving_mean",
            reduced_shape,
            initializer=tf.random_normal_initializer(stddev=1.0),
            activation_dtype=x.dtype,
            trainable=False)
        moving_variance = mtf.get_variable(x.mesh,
                                           "moving_variance",
                                           reduced_shape,
                                           initializer=tf.ones_initializer(),
                                           activation_dtype=x.dtype,
                                           trainable=False)

        # At training time, calculate mean and variance and normalize across batch
        # dim.
        if is_training:
            mean = mtf.reduce_mean(x, output_shape=reduced_shape)
            variance = mtf.reduce_mean(mtf.square(x - mean),
                                       output_shape=reduced_shape)
            norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

            # Update running mean and running variance.
            moving_mean = mtf.assign(
                moving_mean, momentum * moving_mean + (1 - momentum) * mean)
            moving_variance = mtf.assign(
                moving_variance,
                momentum * moving_variance + (1 - momentum) * variance)
        else:
            # At eval and test time, use the running mean and variance.
            norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)
        return norm_x * scale + bias
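A hypothetical call, continuing the graph, mesh, and dimensions from the sketch above; the momentum value is illustrative. Note that the mtf.assign operations created inside the function are separate update operations, so how the moving statistics actually get refreshed depends on how the graph is lowered and executed:

h = mtf.random_uniform(mesh, mtf.Shape([batch, d_model]))
# Training mode: normalizes with batch statistics and creates assign ops
# that decay moving_mean and moving_variance with the given momentum.
h_bn = batch_norm(h, is_training=True, momentum=0.999)
# With is_training=False, the stored moving statistics are used instead.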
Code example #3
File: layers.py Project: trantorznh/mesh
def compress_mean(x, dim, compression_factor):
    """Compress by taking group means.

  Args:
    x: a Tensor
    dim: a dimension in x.shape
    compression_factor: an integer

  Returns:
    a Tensor
  """
    dims = x.shape.dims
    pos = dims.index(dim)
    compressed_dim = mtf.Dimension(dim.name, dim.size // compression_factor)
    compression_factor_dim = mtf.Dimension("compression_factor",
                                           compression_factor)
    new_shape = (dims[:pos] + [compressed_dim, compression_factor_dim] +
                 dims[pos + 1:])
    x = mtf.reshape(x, new_shape)
    x = mtf.reduce_mean(x, reduced_dim=compression_factor_dim)
    return x
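A hypothetical example, again reusing the mesh from the first sketch. dim.size must be divisible by compression_factor, since the reshape splits dim into (dim.size // compression_factor) groups of compression_factor elements and the reduce_mean then averages each group away:

length = mtf.Dimension("length", 8)
t = mtf.random_uniform(mesh, mtf.Shape([batch, length]))
# Averages adjacent pairs: output shape is [batch=4, length=4], and the
# compressed dimension keeps the original name "length".
t2 = compress_mean(t, length, 2)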
Code example #4
File: optimize.py Project: ml-lab/mesh
def reduce_rms(x):
    return mtf.sqrt(mtf.reduce_mean(mtf.square(x)))
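Since mtf.reduce_mean with neither output_shape nor reduced_dim collapses every dimension, reduce_rms returns a scalar mtf.Tensor; code example #5 below uses it to clip the update magnitude. A small illustrative call, reusing the earlier mesh:

g = mtf.random_uniform(mesh, mtf.Shape([batch, d_model]))
rms = reduce_rms(g)  # scalar mtf.Tensor: sqrt of the mean of g**2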
Code example #5
File: optimize.py Project: ml-lab/mesh
    def apply_grad(self, grad, var):
        """Compute the Adafactor update operations for one variable.

        Args:
          grad: a mtf.Tensor, the gradient for var, or None.
          var: a mtf.Variable.

        Returns:
          a list of mtf Operations (slot updates plus the variable update).
        """
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var.name)
            return []
        # create slots
        grad = mtf.to_float(grad)
        factored_dims = self._factored_dims(var.shape)
        if factored_dims:
            d0, d1 = factored_dims
            vr_shape = var.shape - d0
            vc_shape = var.shape - d1
            vr = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vr",
                                  vr_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
            vc = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vc",
                                  vc_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
        else:
            v = mtf.get_variable(var.mesh,
                                 var.name + "_slot_v",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)
        if self._beta1:
            m = mtf.get_variable(var.mesh,
                                 var.name + "_slot_m",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)

        with tf.variable_scope(var.name + "/adafactor"):
            grad_squared = mtf.square(grad) + self._epsilon1
            decay_rate = self._decay_rate
            old_val = mtf.to_float(var.value)
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(
                    old_val) * self._learning_rate
            else:
                update_scale = self._learning_rate
            mixing_rate = 1.0 - decay_rate
            updates = []
            if factored_dims:
                grad_squared_row_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vr_shape)
                grad_squared_col_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vc_shape)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                vr_update = mtf.assign(vr, new_vr)
                vc_update = mtf.assign(vc, new_vc)
                updates.extend([vr_update, vc_update])
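                # Reconstruct the per-element rsqrt(second moment) from the
                # two factored accumulators: new_vr normalized by its mean,
                # times new_vc (see the NumPy sketch after this example).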
                long_term_mean = mtf.reduce_mean(new_vr, reduced_dim=d1)
                r_factor = mtf.rsqrt(new_vr / long_term_mean)
                c_factor = mtf.rsqrt(new_vc)
                x = grad * r_factor * c_factor
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                v_update = mtf.assign(v, new_v)
                updates.append(v_update)
                x = grad * mtf.rsqrt(new_v)
            if self._clipping_threshold is not None:
                clipping_denom = mtf.maximum(
                    1.0,
                    reduce_rms(x) / self._clipping_threshold)
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                new_m = (m * tf.constant(self._beta1) +
                         subtrahend * tf.constant(1.0 - self._beta1))
                subtrahend = new_m
                updates.append(mtf.assign(m, new_m))
            var_update = mtf.assign_sub(var, subtrahend)
            updates.append(var_update)
            return updates
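The factored branch is the heart of Adafactor: instead of keeping a full second-moment accumulator with one entry per parameter, it keeps two marginal accumulators, vr (grad_squared with d0 averaged out) and vc (with d1 averaged out), and reconstructs the per-element scale from their outer product. A plain-NumPy sketch of that update for a single 2-D weight, with illustrative names and with axes 0 and 1 standing in for d0 and d1:

import numpy as np

def factored_rms_update(grad, vr, vc, decay_rate, epsilon1=1e-30):
    # vr has shape (n1,): grad_squared averaged over axis 0 (d0).
    # vc has shape (n0,): grad_squared averaged over axis 1 (d1).
    grad_squared = grad ** 2 + epsilon1  # keep accumulators positive
    vr = decay_rate * vr + (1 - decay_rate) * grad_squared.mean(axis=0)
    vc = decay_rate * vc + (1 - decay_rate) * grad_squared.mean(axis=1)
    # The per-element second moment is approximated by
    # vc[i] * vr[j] / mean(vr), so its rsqrt factors into:
    r_factor = 1.0 / np.sqrt(vr / vr.mean())   # indexed by axis 1
    c_factor = 1.0 / np.sqrt(vc)               # indexed by axis 0
    x = grad * r_factor[None, :] * c_factor[:, None]
    return x, vr, vc

As in the mtf code above, x would then be clipped by dividing by max(1, rms(x) / clipping_threshold) and scaled by the learning rate before being subtracted from the variable.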