Example No. 1
def multihead_attention_vars(
    mesh, heads, io_channels, kv_channels, activation_dtype):
  """Create Parameters for Multihead Attention.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    activation_dtype: a tf.dtype

  Returns:
    q_var: a Tensor with shape [heads, io_channels, kv_channels]
    k_var: a Tensor with shape [heads, io_channels, kv_channels]
    v_var: a Tensor with shape [heads, io_channels, kv_channels]
    o_var: a Tensor with shape [heads, io_channels, kv_channels]
  """
  qkvo = mtf.Dimension("qkvo", 4)
  qk_stddev = (io_channels.size ** -0.5) * (kv_channels.size ** -0.25)
  v_stddev = io_channels.size ** -0.5
  o_stddev = (io_channels.size * heads.size) ** -0.5
  def qkvo_initializer(shape,
                       dtype=None,
                       partition_info=None,
                       verify_shape=None):
    del partition_info, verify_shape
    return tf.random_normal(shape, dtype=dtype) * tf.reshape(
        [qk_stddev, qk_stddev, v_stddev, o_stddev], [4, 1, 1, 1])
  var = mtf.get_variable(
      mesh, "qkvo", mtf.Shape([qkvo, heads, io_channels, kv_channels]),
      initializer=qkvo_initializer, activation_dtype=activation_dtype)
  q_var, k_var, v_var, o_var = mtf.unstack(var, qkvo)
  return q_var, k_var, v_var, o_var
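
A minimal usage sketch (not part of the original source): it assumes mesh_tensorflow is importable as mtf and TensorFlow 1.x as tf, and the dimension names and sizes below are purely illustrative.

import mesh_tensorflow as mtf   # assumed import path
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
# Illustrative dimensions; names and sizes are arbitrary.
heads = mtf.Dimension("heads", 8)
io_channels = mtf.Dimension("d_model", 512)
kv_channels = mtf.Dimension("d_kv", 64)
q_var, k_var, v_var, o_var = multihead_attention_vars(
    mesh, heads, io_channels, kv_channels, activation_dtype=tf.float32)
# Each returned Tensor has shape [heads, d_model, d_kv].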
Example No. 2
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float - fraction of hidden units to drop
    dropout_broadcast_dims: an optional list of mtf.Dimension along which the
      dropout noise is broadcast (these dims are removed from the noise shape)
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    stddev = (hidden_channels.size * io_channels.size) ** -0.25
    io = mtf.Dimension("io", 2)
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    wi, wo = mtf.unstack(w, io)
    h = mtf.relu(mtf.einsum([x, wi]))
    if dropout != 0.0:
      h = mtf.dropout(h, 1.0 - dropout,
                      noise_shape=h.shape - dropout_broadcast_dims)
    return mtf.einsum([h, wo])
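
A hedged usage sketch, again with illustrative dimension names and sizes; mtf.zeros stands in for a real activation tensor from a previous layer.

import mesh_tensorflow as mtf   # assumed import path
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 4)
length = mtf.Dimension("length", 128)
model = mtf.Dimension("d_model", 512)
hidden = mtf.Dimension("d_ff", 2048)
# Placeholder activations; a real model would pass the previous layer's output.
x = mtf.zeros(mesh, mtf.Shape([batch, length, model]), dtype=tf.float32)
y = dense_relu_dense(x, hidden_channels=hidden, dropout=0.1,
                     dropout_broadcast_dims=[length])
# y has the same shape as x: [batch, length, d_model].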
Example No. 3
  def _decoder_layer_stack_incremental(self,
                                       x,
                                       step_num,
                                       encdec_tensors,
                                       self_attention_k,
                                       self_attention_v,
                                       encdec_attention_mask=None):
    """Decoder layer stack during inference.

    We are processing only one position at a time.

    The self-attention keys and values have already been computed for
    previous positions.  In addition to the decoder output, we need to
    produce the updated self-attention keys and values.

    If there is an encoder, then additional Tensors are supplied in
    encdec_tensors, which give us the keys and values for encoder-decoder
    attention as well as the weight matrices q_var and o_var.

    Args:
      x: a mtf.Tensor with shape [batch_dim, model_dim]
      step_num: an mtf integer Scalar
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v)
      self_attention_k: an optional list of num_layers Tensors each with shape
        [batch, heads, memory_length, kv_channels]
      self_attention_v: an optional list of num_layers Tensors each with shape
        [batch, heads, memory_length, kv_channels]
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.

    Returns:
      y: a mtf.Tensor with shape [batch_dim, model_dim]
      new_self_attention_k: a list of num_layers mtf.Tensors, with the same
        shapes as the elements of self_attention_k
      new_self_attention_v: a list of num_layers mtf.Tensors, with the same
        shapes as the elements of self_attention_v

    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    num_layers = hparams.num_decoder_layers
    num_layer_norms = num_layers * (2 if encdec_tensors is None else 3) + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    new_self_attention_k = []
    new_self_attention_v = []
    for layer in range(num_layers):
      with tf.variable_scope("layer_%d" % layer):
        # Self attention layer
        y, new_k, new_v = mtf_layers.multihead_self_attention_incremental(
            normalize(x),
            prev_k=self_attention_k[layer],
            prev_v=self_attention_v[layer],
            step_num=step_num,
            name="self_attention")
        new_self_attention_k.append(new_k)
        new_self_attention_v.append(new_v)
        x += y
        if encdec_tensors is not None:
          # Encoder-Decoder attention layer
          q_var, o_var, k, v = encdec_tensors[layer]
          x += mtf_layers.multihead_encdec_attention_incremental(
              normalize(x),
              q_var, o_var, k, v,
              encdec_attention_mask,
              name="encdec_attention")
        # ffn layer
        x += self._feedforward_layer(normalize(x), hparams)
    x = normalize(x)
    assert not layer_norm_vars
    return x, new_self_attention_k, new_self_attention_v
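
The normalize() closure above is an RMS-style layer norm: no mean subtraction, just x * rsqrt(mean(x^2) + epsilon) times a learned per-channel scale, reduced over model_dim. A small self-contained sketch of the same computation in plain TensorFlow 1.x; the shapes and epsilon are illustrative, not taken from the original hparams.

import tensorflow as tf

def rms_normalize(x, scale, epsilon=1e-6):
  # Same computation as normalize() above, with model_dim as the last axis.
  variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
  return x * tf.rsqrt(variance + epsilon) * scale

x = tf.random_normal([2, 512])    # [batch, model_dim], illustrative
scale = tf.ones([512])            # learned scale, one entry per channel
y = rms_normalize(x, scale)       # same shape as x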
Example No. 4
  def _layer_stack(self,
                   x,
                   num_layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None):
    """Encoder or decoder stack.

    Args:
      x: a mtf.Tensor with shape [batch_dim, length_dim, model_dim]
      num_layers: an integer
      encoder_output: an optional mtf.Tensor with shape
        [batch_dim, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: an optional list to which extra losses will be appended

    Returns:
      a mtf.Tensor with shape [batch_dim, length_dim, model_dim]

    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))
    num_layer_norms = num_layers * (2 if encoder_output is None else 3) + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    for layer in range(num_layers):
      with tf.variable_scope("layer_%d" % layer):
        # Self attention layer
        x += layer_prepostprocess_dropout(
            mtf_layers.multihead_attention(
                normalize(x), None,
                self_attention_mask, self.kv_dim, self.heads_dim,
                dropout=hparams.attention_dropout,
                dropout_broadcast_dims=[self.length_dim],
                name="self_attention"))
        if encoder_output is not None:
          # Encoder-Decoder attention layer
          x += layer_prepostprocess_dropout(
              mtf_layers.multihead_attention(
                  normalize(x), encoder_output,
                  encdec_attention_mask, self.kv_dim, self.heads_dim,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  name="encdec_attention"))
        # ffn layer
        x += layer_prepostprocess_dropout(
            self._feedforward_layer(normalize(x), losses=losses))
    x = layer_prepostprocess_dropout(normalize(x))
    assert not layer_norm_vars
    return x
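
Both stacks pre-allocate every layer-norm scale in a single [num_layer_norms, model_dim] variable, unstack it into a list, and have normalize() pop one slice per call; the closing assert verifies that exactly num_layer_norms slices were consumed (2 per layer, 3 when there is encoder-decoder attention, plus 1 for the final norm). A short standalone illustration of that bookkeeping in plain TensorFlow 1.x, with illustrative sizes:

import tensorflow as tf

num_layers = 2
has_encoder_output = True
model_dim_size = 512   # illustrative
# 2 norms per layer (3 with encoder-decoder attention) plus one final norm.
num_layer_norms = num_layers * (3 if has_encoder_output else 2) + 1

combined = tf.get_variable(
    "layer_norm_scale", [num_layer_norms, model_dim_size],
    initializer=tf.ones_initializer())
layer_norm_vars = list(tf.unstack(combined, axis=0))

def normalize(x, epsilon=1e-6):
  scale = layer_norm_vars.pop(0)   # consume the next pre-allocated scale
  variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
  return x * tf.rsqrt(variance + epsilon) * scale

# After running the full stack, every scale must have been used exactly once,
# which is what "assert not layer_norm_vars" checks in the methods above.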