Example #1
def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
    """Layer normalization over dimension dim.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension
    epsilon: a floating point number
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
    with tf.variable_scope(name + "/layer_norm"):
        scale = mtf.get_variable(x.mesh,
                                 "layer_norm_scale",
                                 mtf.Shape([dim]),
                                 initializer=tf.ones_initializer(),
                                 activation_dtype=x.dtype)
        bias = mtf.get_variable(x.mesh,
                                "layer_norm_bias",
                                mtf.Shape([dim]),
                                initializer=tf.zeros_initializer(),
                                activation_dtype=x.dtype)
        reduced_shape = x.shape - dim
        mean = mtf.reduce_mean(x, output_shape=reduced_shape)
        variance = mtf.reduce_mean(mtf.square(x - mean),
                                   output_shape=reduced_shape)
        norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
        return norm_x * scale + bias
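Usage sketch (not part of the original snippet): build a tiny Mesh TensorFlow graph and normalize a dummy [batch, model] activation over the model dimension. The imports and the batch/model dimension names are illustrative assumptions; TF1-style graph mode is required.

import tensorflow.compat.v1 as tf  # assumption: TF1-style APIs
import mesh_tensorflow as mtf

tf.disable_v2_behavior()  # mtf variables need TF1 graph mode

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")  # single mesh, no layout rules
batch = mtf.Dimension("batch", 8)       # illustrative dimension names
model = mtf.Dimension("model", 512)
x = mtf.zeros(mesh, mtf.Shape([batch, model]))  # dummy activations
y = layer_norm(x, model)  # normalize over the "model" dimension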
Example #2
def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
    """Batch normalization.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
    with tf.variable_scope(name, default_name="batch_norm", values=[x]):
        batch_dim = x.shape.dims[0]
        reduced_shape = x.shape - batch_dim
        scale = mtf.get_variable(x.mesh,
                                 "batch_norm_scale",
                                 mtf.Shape([batch_dim]),
                                 initializer=tf.ones_initializer(),
                                 activation_dtype=x.dtype)
        bias = mtf.get_variable(x.mesh,
                                "batch_norm_bias",
                                mtf.Shape([batch_dim]),
                                initializer=tf.zeros_initializer(),
                                activation_dtype=x.dtype)

        moving_mean = mtf.get_variable(
            x.mesh,
            "moving_mean",
            reduced_shape,
            initializer=tf.random_normal_initializer(stddev=1.0),
            activation_dtype=x.dtype,
            trainable=False)
        moving_variance = mtf.get_variable(x.mesh,
                                           "moving_variance",
                                           reduced_shape,
                                           initializer=tf.ones_initializer(),
                                           activation_dtype=x.dtype,
                                           trainable=False)

        # At training time, calculate mean and variance and normalize across batch
        # dim.
        if is_training:
            mean = mtf.reduce_mean(x, output_shape=reduced_shape)
            variance = mtf.reduce_mean(mtf.square(x - mean),
                                       output_shape=reduced_shape)
            norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

            # Update running mean and running variance.
            moving_mean = mtf.assign(
                moving_mean, momentum * moving_mean + (1 - momentum) * mean)
            moving_variance = mtf.assign(
                moving_variance,
                momentum * moving_variance + (1 - momentum) * variance)
        else:
            # At eval and test time, use the running mean and variance.
            norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)
        return norm_x * scale + bias
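Usage sketch (illustrative, mirroring the sketch after Example #1): batch-normalize a dummy [batch, hidden] activation at training time. Dimension names and the momentum value are assumptions.

import tensorflow.compat.v1 as tf  # assumption: TF1-style APIs
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 32)    # first dim is treated as batch_dim
hidden = mtf.Dimension("hidden", 256)
x = mtf.zeros(mesh, mtf.Shape([batch, hidden]))
# Training mode: statistics come from the batch, and the moving
# averages are updated with decay 0.999.
y = batch_norm(x, is_training=True, momentum=0.999)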
Example #3
def dense(x,
          output_dim,
          reduced_dims=None,
          expert_dims=None,
          use_bias=True,
          activation=None,
          master_dtype=tf.float32,
          slice_dtype=tf.float32,
          name=None):
    """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: a string. variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
    if expert_dims is None:
        expert_dims = []
    if reduced_dims is None:
        reduced_dims = x.shape.dims[-1:]
    w_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
    output_shape = mtf.Shape(
        [d for d in x.shape.dims if d not in reduced_dims] + [output_dim])

    with tf.variable_scope(name, default_name="dense"):
        stddev = mtf.list_product(d.size for d in reduced_dims)**-0.5
        w = mtf.get_variable(
            x.mesh,
            "kernel",
            w_shape,
            initializer=tf.random_normal_initializer(stddev=stddev),
            master_dtype=master_dtype,
            slice_dtype=slice_dtype,
            activation_dtype=x.dtype)
        y = mtf.einsum([x, w], output_shape)
        if use_bias:
            b = mtf.get_variable(x.mesh,
                                 "bias",
                                 mtf.Shape(expert_dims + [output_dim]),
                                 initializer=tf.zeros_initializer(),
                                 activation_dtype=x.dtype)
            y += b
        if activation is not None:
            y = activation(y)
        return y
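Usage sketch (illustrative): project a dummy [batch, d_in] tensor to d_out with a ReLU activation. The d_in dimension is reduced because it is the last dimension of x; all names are assumptions.

import tensorflow.compat.v1 as tf  # assumption: TF1-style APIs
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 16)
d_in = mtf.Dimension("d_in", 128)
d_out = mtf.Dimension("d_out", 64)
x = mtf.zeros(mesh, mtf.Shape([batch, d_in]))
# reduced_dims defaults to the last dimension of x (d_in here).
y = dense(x, d_out, activation=mtf.relu)  # result shape: [batch, d_out]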
Example #4
def multihead_attention_vars(mesh, heads, io_channels, kv_channels,
                             master_dtype, slice_dtype, activation_dtype):
    """Create Parameters for Multihead Attention.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    activation_dtype: a tf.dtype

  Returns:
    q_var: a Tensor with shape [heads, io_channels, kv_channels]
    k_var: a Tensor with shape [heads, io_channels, kv_channels]
    v_var: a Tensor with shape [heads, io_channels, kv_channels]
    o_var: a Tensor with shape [heads, io_channels, kv_channels]
  """
    qkvo = mtf.Dimension("qkvo", 4)
    qk_stddev = (io_channels.size**-0.5) * (kv_channels.size**-0.25)
    v_stddev = io_channels.size**-0.5
    o_stddev = (io_channels.size * heads.size)**-0.5

    def qkvo_initializer(shape,
                         dtype=None,
                         partition_info=None,
                         verify_shape=None):
        del partition_info, verify_shape
        return tf.random_normal(shape, dtype=dtype) * tf.reshape(
            tf.cast([qk_stddev, qk_stddev, v_stddev, o_stddev], dtype
                    or tf.float32), [4, 1, 1, 1])

    var = mtf.get_variable(mesh,
                           "qkvo",
                           mtf.Shape([qkvo, heads, io_channels, kv_channels]),
                           initializer=qkvo_initializer,
                           master_dtype=master_dtype,
                           slice_dtype=slice_dtype,
                           activation_dtype=activation_dtype)
    q_var, k_var, v_var, o_var = mtf.unstack(var, qkvo)
    return q_var, k_var, v_var, o_var
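Usage sketch (illustrative): create the fused q/k/v/o parameter tensors for 8 attention heads. The sizes are assumptions, and all three dtypes are float32 for simplicity.

import tensorflow.compat.v1 as tf  # assumption: TF1-style APIs
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
heads = mtf.Dimension("heads", 8)
io_channels = mtf.Dimension("io_channels", 512)
kv_channels = mtf.Dimension("kv_channels", 64)
# Each returned tensor has shape [heads, io_channels, kv_channels].
q_var, k_var, v_var, o_var = multihead_attention_vars(
    mesh, heads, io_channels, kv_channels,
    master_dtype=tf.float32,
    slice_dtype=tf.float32,
    activation_dtype=tf.float32)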
Example #5
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
    """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
    with tf.variable_scope(name, default_name="dense_relu_dense"):
        io_channels = x.shape.dims[-1]
        stddev = (hidden_channels.size * io_channels.size)**-0.25
        io = mtf.Dimension("io", 2)
        w = mtf.get_variable(
            x.mesh,
            "kernel",
            mtf.Shape([io, io_channels, hidden_channels]),
            initializer=tf.random_normal_initializer(stddev=stddev),
            activation_dtype=x.dtype)
        wi, wo = mtf.unstack(w, io)
        h = mtf.relu(mtf.einsum([x, wi]))
        if dropout != 0.0:
            h = mtf.dropout(h,
                            1.0 - dropout,
                            noise_shape=h.shape - dropout_broadcast_dims)
        return mtf.einsum([h, wo])
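Usage sketch (illustrative): a position-wise feed-forward block over a dummy [batch, length, d_model] activation with a 4x-wider hidden layer. Dropout is left at its 0.0 default here; it can be enabled via the dropout and dropout_broadcast_dims arguments.

import tensorflow.compat.v1 as tf  # assumption: TF1-style APIs
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 16)
length = mtf.Dimension("length", 32)
d_model = mtf.Dimension("d_model", 128)
hidden = mtf.Dimension("hidden", 512)  # 4x wider than d_model
x = mtf.zeros(mesh, mtf.Shape([batch, length, d_model]))
y = dense_relu_dense(x, hidden)  # same shape as x: [batch, length, d_model]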
Example #6
    def apply_grad(self, grad, var):
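        """Computes the update operations for one (gradient, variable) pair.

        Args:
          grad: a mtf.Tensor, the gradient of the loss w.r.t. var, or None.
          var: a mtf.Variable.

        Returns:
          a list of mtf update Operations (slot and variable assignments).
        """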
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var)
            return []
        # create slots
        grad = mtf.to_float(grad)
        factored_dims = self._factored_dims(var.shape)
        if factored_dims:
            d0, d1 = factored_dims
            vr_shape = var.shape - d0
            vc_shape = var.shape - d1
            vr = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vr",
                                  vr_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
            vc = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vc",
                                  vc_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
        else:
            v = mtf.get_variable(var.mesh,
                                 var.name + "_slot_v",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)
        if self._beta1:
            m = mtf.get_variable(var.mesh,
                                 var.name + "_slot_m",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)

        with tf.variable_scope(var.name + "/adafactor"):
            grad_squared = mtf.square(grad) + self._epsilon1
            decay_rate = self._decay_rate
            old_val = mtf.to_float(var.value)
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(
                    old_val) * self._learning_rate
            else:
                update_scale = self._learning_rate
            mixing_rate = 1.0 - decay_rate
            updates = []
            if factored_dims:
                grad_squared_row_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vr_shape)
                grad_squared_col_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vc_shape)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                vr_update = mtf.assign(vr, new_vr)
                vc_update = mtf.assign(vc, new_vc)
                updates.extend([vr_update, vc_update])
                long_term_mean = mtf.reduce_mean(new_vr, reduced_dim=d1)
                r_factor = mtf.rsqrt(new_vr / long_term_mean)
                c_factor = mtf.rsqrt(new_vc)
                x = grad * r_factor * c_factor
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                v_update = mtf.assign(v, new_v)
                updates.append(v_update)
                x = grad * mtf.rsqrt(new_v)
            if self._clipping_threshold is not None:
                clipping_denom = mtf.maximum(
                    1.0,
                    reduce_rms(x) / self._clipping_threshold)
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                new_m = (m * tf.constant(self._beta1) +
                         subtrahend * tf.constant(1.0 - self._beta1))
                subtrahend = new_m
                updates.append(mtf.assign(m, new_m))
            var_update = mtf.assign_sub(var, subtrahend)
            updates.append(var_update)
            return updates
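Usage sketch (an assumption: apply_grad above belongs to an Adafactor-style optimizer such as mtf.optimize.AdafactorOptimizer, as the "adafactor" variable scope suggests). It computes mtf gradients for a toy scalar loss and collects the update ops variable by variable.

import tensorflow.compat.v1 as tf  # assumption: TF1-style APIs
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
hidden = mtf.Dimension("hidden", 4)
w = mtf.get_variable(mesh, "w", mtf.Shape([hidden]),
                     initializer=tf.zeros_initializer())
loss = mtf.reduce_sum(mtf.square(w))  # toy scalar loss
# One gradient per trainable variable, in matching order.
var_grads = mtf.gradients(
    [loss], [v.outputs[0] for v in graph.trainable_variables])
optimizer = mtf.optimize.AdafactorOptimizer()
update_ops = []
for grad, var in zip(var_grads, graph.trainable_variables):
    update_ops.extend(optimizer.apply_grad(grad, var))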