Example 1
def multihead_encdec_attention_incremental(query_antecedent,
                                           q_var,
                                           o_var,
                                           k,
                                           v,
                                           mask,
                                           name="multihead_attention"):
    """Incremental attention over encoder (one decode step).

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  memory_dims is a subset of query_dims

  Args:
    query_antecedent: a mtf.Tensor with shape query_dims + [io_channels]
    q_var: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    o_var: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    k: memory_dims + [heads, memory_length, kv_channels]
    v: memory_dims + [heads, memory_length, kv_channels]
    mask: mask Tensor (see attention_mask())
    name: an optional string.

  Returns:
    A mtf.Tensor with shape query_dims + [io_channels]
  """
    heads, _, kv_channels = k.shape.dims[-3:]
    query_dims = query_antecedent.shape.dims[:-1]
    with tf.variable_scope(name, default_name="multihead_attention"):
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(query_dims + [heads, kv_channels]))
        o = dot_product_attention(q, k, v, mask)
        return mtf.einsum([o, o_var], query_antecedent.shape)
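
For reference, a plain-NumPy sketch of the shapes this one-step computation moves through (hypothetical sizes; a single batch dimension stands in for query_dims and memory_dims, and the mask is omitted):

import numpy as np

batch, heads, io, kv, mem = 2, 4, 8, 6, 5          # hypothetical sizes
query = np.random.randn(batch, io)                 # query_antecedent
q_var = np.random.randn(heads, io, kv)
o_var = np.random.randn(heads, io, kv)
k = np.random.randn(batch, heads, mem, kv)         # cached encoder keys
v = np.random.randn(batch, heads, mem, kv)         # cached encoder values

q = np.einsum("bi,hik->bhk", query, q_var)         # project the query
logits = np.einsum("bhk,bhmk->bhm", q, k)          # score each memory position
w = np.exp(logits - logits.max(-1, keepdims=True))
w /= w.sum(-1, keepdims=True)                      # softmax over memory
o = np.einsum("bhm,bhmk->bhk", w, v)               # weighted sum of values
y = np.einsum("bhk,hik->bi", o, o_var)             # back to io_channels
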
Example 2
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    stddev = (hidden_channels.size * io_channels.size) ** -0.25
    io = mtf.Dimension("io", 2)
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    wi, wo = mtf.unstack(w, io)
    h = mtf.relu(mtf.einsum([x, wi]))
    if dropout != 0.0:
      h = mtf.dropout(h, 1.0 - dropout,
                      noise_shape=h.shape - dropout_broadcast_dims)
    return mtf.einsum([h, wo])
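
The same computation in plain NumPy (hypothetical sizes, dropout omitted), to make the two packed einsums explicit:

import numpy as np

batch, length, io, hidden = 2, 3, 8, 32            # hypothetical sizes
x = np.random.randn(batch, length, io)
wi = np.random.randn(io, hidden)                   # w[0] after unstacking
wo = np.random.randn(hidden, io)                   # w[1] after unstacking

h = np.maximum(np.einsum("bli,ih->blh", x, wi), 0.0)   # hidden layer + ReLU
y = np.einsum("blh,hi->bli", h, wo)                    # project back to io
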
Example 3
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        mask,
                        kv_channels,
                        heads,
                        dropout=0.0,
                        dropout_broadcast_dims=None,
                        name="multihead_attention"):
    """Multihead scaled-dot-product attention with input/output transformations.

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  Args:
    query_antecedent: a mtf.Tensor with shape [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional)
    mask: mask Tensor (see attention_mask())
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    dropout: a floating point value
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string.

  Returns:
    A mtf.Tensor with shape [batch, query_length, io_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
    batch, query_length, io_channels = query_antecedent.shape.dims
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)
        if memory_antecedent is None:
            memory_antecedent = rename_length_to_memory_length(
                query_antecedent, query_length.name)
        memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
        if memory_batch != batch:
            raise ValueError("memory batch must equal query batch")
        if memory_channels != io_channels:
            raise ValueError("memory channels must equal query channels")
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(
                           [batch, heads, query_length, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))
        o = dot_product_attention(q, k, v, mask, dropout,
                                  dropout_broadcast_dims)
        return mtf.einsum([o, o_var],
                          mtf.TensorShape([batch, query_length, io_channels]))
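
A minimal construction-time sketch of a call, assuming mtf and this function are importable as in the surrounding code; dimension names and sizes are hypothetical, and a Lowering step is still required to obtain concrete tf.Tensors:

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")

batch = mtf.Dimension("batch", 8)
length = mtf.Dimension("length", 256)
io_channels = mtf.Dimension("io", 512)
kv_channels = mtf.Dimension("kv", 64)
heads = mtf.Dimension("heads", 8)

x = mtf.import_tf_tensor(
    mesh, tf.zeros([8, 256, 512]),
    mtf.Shape([batch, length, io_channels]))
# Self-attention: memory_antecedent=None reuses the query tensor as memory.
y = multihead_attention(
    x, None, mask=None, kv_channels=kv_channels, heads=heads)
# y has shape [batch, length, io_channels].
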
Example 4
def multihead_self_attention_incremental(query_antecedent,
                                         prev_k,
                                         prev_v,
                                         step_num,
                                         name="multihead_attention"):
  """Incremental self-attention (one decode step).

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  Args:
    query_antecedent: a mtf.Tensor with shape [batch..., io_channels]
    prev_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    prev_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    step_num: mtf Scalar with dtype tf.int32
    name: an optional string.

  Returns:
    y: A mtf.Tensor with shape [batch..., io_channels]
    new_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    new_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = query_antecedent.shape.dims[:-1]
  io_channels = query_antecedent.shape.dims[-1]
  heads, memory_length, kv_channels = prev_k.shape.dims[-3:]
  with tf.variable_scope(name, default_name="multihead_attention"):
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)
    memory_antecedent = query_antecedent
    q = mtf.einsum(
        [query_antecedent, q_var],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, k_var],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, v_var],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    k = prev_k + mtf.multiply(
        k, mtf.one_hot(step_num, memory_length), output_shape=prev_k.shape)
    v = prev_v + mtf.multiply(
        v, mtf.one_hot(step_num, memory_length), output_shape=prev_v.shape)

    mask = mtf.to_float(
        mtf.greater(
            mtf.range(query_antecedent.mesh, memory_length, dtype=tf.int32),
            step_num)) * -1e9
    o = dot_product_attention(q, k, v, mask)
    y = mtf.einsum([o, o_var], query_antecedent.shape)
    return y, k, v
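
A NumPy sketch (hypothetical sizes) of the two tricks used here: the one-hot update that writes this step's key into slot step_num of the cache, and the additive mask that hides the not-yet-written slots:

import numpy as np

batch, heads, mem, kv = 2, 4, 6, 8                 # hypothetical sizes
step_num = 3
prev_k = np.zeros((batch, heads, mem, kv))         # cache from earlier steps
k_step = np.random.randn(batch, heads, kv)         # this step's key

one_hot = np.eye(mem)[step_num]                               # [mem]
new_k = prev_k + k_step[:, :, None, :] * one_hot[:, None]     # write slot 3
mask = (np.arange(mem) > step_num) * -1e9          # block future cache slots
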
Example 5
        def first_block_attention():
            """Compute attention for the first block."""
            first_q = mtf.slice(q, 0, 1, num_blocks.name)
            first_k = mtf.slice(k, 0, 1, num_blocks.name)
            first_v = mtf.slice(v, 0, 1, num_blocks.name)
            block = first_q.shape.dims[2]

            first_logits = mtf.einsum(
                [first_q, first_k],
                mtf.TensorShape([batch, heads, block, blength, mlength]))
            weights = mtf.softmax(first_logits, mlength)
            first_output = mtf.einsum(
                [weights, first_v],
                mtf.TensorShape([batch, heads, block, blength, kv_channels]))
            return first_output
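
The two block einsums in plain NumPy (hypothetical sizes), with the singleton num_blocks axis carried along:

import numpy as np

batch, heads, blength, mlength, kv = 2, 4, 8, 8, 16    # hypothetical sizes
first_q = np.random.randn(batch, heads, 1, blength, kv)
first_k = np.random.randn(batch, heads, 1, mlength, kv)
first_v = np.random.randn(batch, heads, 1, mlength, kv)

logits = np.einsum("bhnqk,bhnmk->bhnqm", first_q, first_k)
w = np.exp(logits - logits.max(-1, keepdims=True))
w /= w.sum(-1, keepdims=True)                          # softmax over mlength
out = np.einsum("bhnqm,bhnmk->bhnqk", w, first_v)      # [b, h, 1, blength, kv]
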
Example 6
def dense(x,
          output_dim,
          reduced_dims=None,
          expert_dims=None,
          use_bias=True,
          activation=None,
          name=None):
    """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    name: a string. variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
    if expert_dims is None:
        expert_dims = []
    if reduced_dims is None:
        reduced_dims = x.shape.dims[-1:]
    w_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
    output_shape = mtf.Shape(
        [d for d in x.shape.dims if d not in reduced_dims] + [output_dim])

    with tf.variable_scope(name, default_name="dense"):
        stddev = mtf.list_product(d.size for d in reduced_dims)**-0.5
        w = mtf.get_variable(
            x.mesh,
            "kernel",
            w_shape,
            initializer=tf.random_normal_initializer(stddev=stddev),
            activation_dtype=x.dtype)
        y = mtf.einsum([x, w], output_shape)
        if use_bias:
            b = mtf.get_variable(x.mesh,
                                 "bias",
                                 mtf.Shape(expert_dims + [output_dim]),
                                 initializer=tf.zeros_initializer(),
                                 activation_dtype=x.dtype)
            y += b
        if activation is not None:
            y = activation(y)
        return y
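
A NumPy picture of the expert_dims case (hypothetical sizes): every expert carries its own kernel and bias, and only reduced_dims are contracted away:

import numpy as np

experts, batch, cap, d_in, d_out = 4, 2, 3, 8, 16      # hypothetical sizes
x = np.random.randn(experts, batch, cap, d_in)
w = np.random.randn(experts, d_in, d_out)              # per-expert kernel
b = np.zeros((experts, d_out))                         # per-expert bias

y = np.einsum("ebci,eio->ebco", x, w) + b[:, None, None, :]
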
Example 7
def dot_product_attention(q,
                          k,
                          v,
                          mask,
                          dropout=0.0,
                          dropout_broadcast_dims=None):
    """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k]. Typically leading dimensions
      are [batch, heads].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match those of q.
    v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
      match those of q.
    mask: mask Tensor (see attention_mask())
    dropout: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
    length_kv = k.shape.dims[-2]
    logits_shape = mtf.TensorShape(q.shape.dims[:-1] + [length_kv])
    logits = mtf.einsum([q, k], logits_shape)
    if mask is not None:
        logits += mask
    weights = mtf.softmax(logits, length_kv)
    if dropout != 0.0:
        weights = mtf.dropout(weights,
                              1.0 - dropout,
                              noise_shape=weights.shape -
                              dropout_broadcast_dims)
    depth_v = v.shape.dims[-1]
    outputs_shape = mtf.TensorShape(q.shape.dims[:-1] + [depth_v])
    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
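
The same computation in plain NumPy with a causal additive mask in the usual 0 / -1e9 convention (hypothetical sizes; like the code above, no explicit 1/sqrt(depth_k) scaling is applied here):

import numpy as np

batch, heads, length, dk, dv = 2, 4, 6, 16, 16         # hypothetical sizes
q = np.random.randn(batch, heads, length, dk)
k = np.random.randn(batch, heads, length, dk)
v = np.random.randn(batch, heads, length, dv)
mask = np.where(np.tril(np.ones((length, length))) > 0, 0.0, -1e9)

logits = np.einsum("bhqd,bhmd->bhqm", q, k) + mask     # additive mask
weights = np.exp(logits - logits.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)              # softmax over memory
out = np.einsum("bhqm,bhmd->bhqd", weights, v)         # [b, heads, length, dv]
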
Example 8
def moe_v0(inputs,
           hidden_dim,
           output_dim,
           experts_dim,
           loss_coef=1e-3,
           overhead=1.0):
    """Local mixture of experts that works well on TPU.

  See https://arxiv.org/abs/1701.06538

  There are num_experts expert networks, each containing a relu-activated
  hidden layer of size hidden_size, followed by an output projection.

  The number of parameters is thus:
    num_experts * (input_size * hidden_size + hidden_size * output_size)

  The input is 3d: [batch, length, depth], consisting of the representations
  of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    as opposed to on individual sequences.  This would allow more freedom
    for individual sequences to be unbalanced.  Unfortunately, that would
    slow down our hacked-up gather-by-matmul implementation.

    TODO(noam): There is no real reason for a single sequence to be the unit
      of equal allocation.  Reshaping the inputs would allow us to pick a
      different unit of equal allocation.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.  We also want to integrate this
  gating/dispatching logic into multi-device mixtures-of-experts.

  Args:
    inputs: a mtf.Tensor with shape [batch_dim, length_dim, input_dim]
    hidden_dim: a mtf.Dimension
    output_dim: a mtf.Dimension
    experts_dim: a mtf.Dimension
    loss_coef: a float scalar
    overhead: multiplicative factor of how much spare capacity to assign

  Returns:
    outputs: a Tensor with shape [batch_dim, length_dim, output_dim]
    loss: a mtf scalar
  """
    batch_dim, length_dim, input_dim = inputs.shape.dims

    # Each sequence sends expert_capacity positions to each expert.
    expert_capacity = min(
        length_dim.size,
        int((length_dim.size * 2 * overhead) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

    # This is the learned gating function.
    # shape = [batch_dim, length_dim, experts_dim_unsplit]
    gates = mtf.softmax(dense(inputs, experts_dim_unsplit),
                        experts_dim_unsplit)

    assignment_shape = mtf.TensorShape(
        [batch_dim, length_dim, experts_dim_unsplit, expert_capacity_dim])

    backward_assignment = mtf.slicewise(
        functools.partial(_truncated_top_2_gating,
                          expert_capacity=expert_capacity),
        [gates],
        output_shape=assignment_shape,
        splittable_dims=[batch_dim],
        name="backward_assignment")

    forward_assignment = mtf.cast(mtf.cast(backward_assignment, tf.bool),
                                  inputs.dtype)

    # put num_experts dimension first to make split easier in alltoall
    expert_inputs = mtf.einsum([inputs, forward_assignment],
                               mtf.TensorShape([
                                   experts_dim_unsplit, batch_dim,
                                   expert_capacity_dim, input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.TensorShape(
            [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = dense(expert_inputs,
              hidden_dim,
              expert_dims=[experts_dim],
              activation=mtf.relu,
              name="x0")
    expert_output = dense(h, output_dim, expert_dims=[experts_dim], name="x1")

    expert_output = mtf.reshape(
        expert_output,
        mtf.TensorShape(
            [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

    output = mtf.einsum([expert_output, backward_assignment],
                        mtf.TensorShape([batch_dim, length_dim, output_dim]))

    importance = mtf.reduce_sum(backward_assignment,
                                output_shape=mtf.TensorShape(
                                    [batch_dim, experts_dim_unsplit]))

    loss = cv_squared(importance) * loss_coef
    return output, loss
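
A worked example of the capacity arithmetic (hypothetical sizes):

length, num_experts, overhead = 1024, 16, 1.0          # hypothetical sizes
expert_capacity = min(length, int((length * 2 * overhead) / num_experts))
# expert_capacity == 128: each sequence may route at most 128 positions to each
# expert, i.e. 16 * 128 = 2048 slots for 1024 positions under top-2 gating.
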
Example 9
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
    """Attention to the source position and a neighborhood to the left of it.

  The sequence is divided into blocks of length block_length.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional). Currently, memory_length
      must have the same size as query_length, but a different name.
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    block_length: an integer, the length of each attention block.
    name: an optional string.

  Returns:
    a Tensor of shape [batch, query_length, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        batch, query_length, io_channels = query_antecedent.shape.dims
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)

        if memory_antecedent is None:
            memory_antecedent = rename_length_to_memory_length(
                query_antecedent, query_length.name)
        memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
        if memory_batch != batch:
            raise ValueError("memory batch must equal query batch")
        if memory_channels != io_channels:
            raise ValueError("memory channels must equal query channels")

        # Get query q, keys k and values v.
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(
                           [batch, heads, query_length, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))

        # Let's assume for now we don't have padding and the block length equally
        # divides the memory length.
        block_length = (query_length.size
                        if query_length.size < block_length * 2
                        else block_length)
        blength = mtf.Dimension("block_length", block_length)
        mlength = mtf.Dimension("mem_block_length", block_length)
        num_blocks = mtf.Dimension("num_blocks",
                                   query_length.size // block_length)

        q = mtf.reshape(
            q,
            mtf.TensorShape([batch, heads, num_blocks, blength, kv_channels]))
        k = mtf.reshape(
            k,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))
        v = mtf.reshape(
            v,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))

        # compute attention for the first query block.
        def first_block_attention():
            """Compute attention for the first block."""
            first_q = mtf.slice(q, 0, 1, num_blocks.name)
            first_k = mtf.slice(k, 0, 1, num_blocks.name)
            first_v = mtf.slice(v, 0, 1, num_blocks.name)
            block = first_q.shape.dims[2]

            first_logits = mtf.einsum(
                [first_q, first_k],
                mtf.TensorShape([batch, heads, block, blength, mlength]))
            weights = mtf.softmax(first_logits, mlength)
            first_output = mtf.einsum(
                [weights, first_v],
                mtf.TensorShape([batch, heads, block, blength, kv_channels]))
            return first_output

        # Attention for first block, since query_length = key_length.
        first_output = first_block_attention()

        # Concatenate two adjacent blocks to compute the overlapping memory block.
        def local(x):
            """Helper function to get memory blocks."""
            prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
            cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
            local_block = mtf.concat([prev_block, cur_block], mlength.name)
            return local_block

        local_k = local(k)
        local_v = local(v)
        mblocks = local_k.shape.dims[2]
        mlength = local_k.shape.dims[3]
        # Calculate the causal mask to avoid peeking into the future. We compute
        # this once and reuse it for all blocks since the block_size is known.
        mask = attention_bias_local_block(query_antecedent.mesh, blength,
                                          mlength)

        # Remove the first block from q since we already computed that.
        tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)

        # Compatibility between q and k for rest of the blocks.
        # Shape [batch, heads, num_blocks - 1, block_length, local_length]
        attention = mtf.einsum([tail_q, local_k],
                               mtf.TensorShape(
                                   [batch, heads, mblocks, blength, mlength]))
        attention += mask
        attention = mtf.softmax(attention, mlength)

        # Run attention for rest of the blocks.
        # Shape [batch, heads, num_blocks-1, block_length, kv_channels]
        output = mtf.einsum([attention, local_v],
                            mtf.TensorShape(
                                [batch, heads, mblocks, blength, kv_channels]))
        # Now concatenate the first and rest of the blocks.
        final_output = mtf.concat([first_output, output], num_blocks.name)
        final_output = mtf.reshape(
            final_output,
            mtf.TensorShape([batch, heads, query_length, kv_channels]))
        return mtf.einsum([final_output, o_var],
                          mtf.TensorShape([batch, query_length, io_channels]))
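
A NumPy picture of the local() helper (hypothetical sizes): the memory for block i is block i-1 concatenated with block i along the memory axis, which is why the first block is handled separately above:

import numpy as np

batch, heads, num_blocks, blen, kv = 1, 2, 4, 3, 8     # hypothetical sizes
k = np.random.randn(batch, heads, num_blocks, blen, kv)

prev_block = k[:, :, :-1]          # blocks 0 .. n-2
cur_block = k[:, :, 1:]            # blocks 1 .. n-1
local_k = np.concatenate([prev_block, cur_block], axis=3)
# local_k: [batch, heads, num_blocks - 1, 2 * blen, kv]
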
Example 10
def local_self_attention_spatial_blocks(query_antecedent,
                                        kv_channels,
                                        heads,
                                        memory_w_dim=None,
                                        mask_right=False,
                                        name=None):
    """Attention to the source position and a neighborhood to the left or right.

  The input is already divided into blocks along the width dimension, and each
  query position attends within its own block and a neighboring memory block
  (via a halo exchange).  If mask_right is True, attention is additionally
  restricted to memory positions less than or equal to the query position.

  Args:
    query_antecedent: a mtf.Tensor with shape
      [batch, num_w_blocks, w_dim, io_channels]
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    memory_w_dim: mtf Dimension, for the memory width block.
    mask_right: bool, flag specifying whether we mask out attention to the right
      for the decoder.
    name: an optional string.

  Returns:
    a Tensor of shape [batch, num_w_blocks, w_dim, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent]):

        w_dim, io_channels = query_antecedent.shape.dims[-2:]
        batch, num_w_blocks = query_antecedent.shape.dims[:2]
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)

        # Rename dimensions for the memory height and width.
        memory_antecedent = mtf.rename_dimension(query_antecedent, w_dim.name,
                                                 memory_w_dim.name)

        # Call einsum over the query and memory to get query q, keys k and values v.
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.Shape(
                           [batch, heads, num_w_blocks, w_dim, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.Shape(
                           [batch, heads, num_w_blocks, w_dim, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.Shape(
                           [batch, heads, num_w_blocks, w_dim, kv_channels]))

        # Halo exchange for memory blocks.
        if memory_w_dim is not None:
            k, v = local_1d_halo_exchange(k, v, num_w_blocks, w_dim,
                                          memory_w_dim, mask_right)

        # Calculate the causal mask to avoid peeking into the future. We compute
        # this once and reuse it for all blocks since the block_size is known.
        mask = None
        if mask_right:
            mask = attention_bias_local_block(query_antecedent.mesh, w_dim,
                                              memory_w_dim)

        output = dot_product_attention(q, k, v, mask=mask)

        return mtf.einsum([output, o_var],
                          mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
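
A conceptual NumPy sketch (hypothetical sizes) of what a left halo exchange gives each width block when mask_right is True; the actual local_1d_halo_exchange runs over the mesh and may differ in padding and masking details:

import numpy as np

batch, heads, blocks, w, kv = 1, 2, 4, 3, 8            # hypothetical sizes
k = np.random.randn(batch, heads, blocks, w, kv)

left = np.concatenate([np.zeros_like(k[:, :, :1]), k[:, :, :-1]], axis=2)
local_k = np.concatenate([left, k], axis=3)            # [.., blocks, 2*w, kv]
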
Example 11
  def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
      inputs = features["inputs"]
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf_layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.num_encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      encdec_tensors = []
      for layer_num in xrange(hparams.num_decoder_layers):
        with tf.variable_scope("decoder/layer_%d/encdec_attention" % layer_num):
          q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      partial_targets = None
    else:
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)

    if hparams.beam_size == 1:
      ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
    else:
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_kv_states = (
        [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)]
        * (2 * hparams.num_decoder_layers))
    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      self_attention_k = states[:hparams.num_decoder_layers]
      self_attention_v = states[hparams.num_decoder_layers:]
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_self_attention_k, new_self_attention_v = (
            self._decoder_layer_stack_incremental(
                x,
                step_num,
                encdec_tensors,
                self_attention_k,
                self_attention_v,
                encdec_attention_mask=encoder_attention_mask))
      logits = mtf.matmul(x, softmax_var)
      return logits, new_self_attention_k + new_self_attention_v

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf_beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_kv_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if self.has_input:
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf_beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_kv_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
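
A framework-free toy (hypothetical sizes) of the contract logits_fn must satisfy for the greedy decode loop: given the step number, the ids decoded so far, and the list of cached states, return logits for the next token plus the updated state list:

import numpy as np

vocab, length = 5, 4                                   # hypothetical sizes
ids = np.zeros(length, dtype=np.int32)
states = [np.zeros(length)]                            # stand-in for cached k/v

def toy_logits_fn(step_num, ids, states):
  new_states = [s.copy() for s in states]
  new_states[0][step_num] = ids[step_num - 1]          # "write" into the cache
  return np.random.randn(vocab), new_states

for step in range(1, length):
  logits, states = toy_logits_fn(step, ids, states)
  ids[step] = int(np.argmax(logits))                   # greedy pick
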
Example 12
def transformer_moe_layer_v1(inputs, output_dim, hparams, train):
    """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py.
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  The number of parameters in the gating network is:
    (input_dim.size * hparams.moe_num_experts)

  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims...>, length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean

  Returns:
    outputs: a Tensor with shape [<batch_dims...>, length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    orig_inputs = inputs
    input_dim = inputs.shape.dims[-1]
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)
    group_size_dim = mtf.Dimension("group", hparams.moe_group_size)
    batch_dim = mtf.Dimension(
        orig_inputs.shape[0].name,
        orig_inputs.shape.size // (group_size_dim.size * input_dim.size))
    inputs = mtf.reshape(inputs, [batch_dim, group_size_dim, input_dim])

    # Each sequence sends expert_capacity positions to each expert.
    capacity_factor = (hparams.moe_capacity_factor_train
                       if train else hparams.moe_capacity_factor_eval)
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

    if hparams.moe_gating == "top_2":
        dispatch_tensor, combine_tensor, loss = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # put num_experts dimension first to make split easier in alltoall
    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([
                                   experts_dim_unsplit, batch_dim,
                                   expert_capacity_dim, input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape(
            [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = mtf_layers.dense(expert_inputs,
                         hidden_dim,
                         expert_dims=[experts_dim],
                         activation=mtf.relu,
                         use_bias=False,
                         name="x0")
    expert_output = mtf_layers.dense(h,
                                     output_dim,
                                     expert_dims=[experts_dim],
                                     use_bias=False,
                                     name="x1")

    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape(
            [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape([batch_dim, group_size_dim, output_dim]))

    output = mtf.reshape(output, orig_inputs.shape.dims[:-1] + [output_dim])

    return output, loss * hparams.moe_loss_coef
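
A worked example of the grouping and capacity arithmetic (hypothetical sizes; in the code, _split_into_groups may further adjust num_groups and group_size to respect the mesh layout):

positions = 8 * 1024                  # batch * length (hypothetical)
group_size = 1024                     # hparams.moe_group_size
num_groups = positions // group_size  # 8 groups
capacity_factor, num_experts = 1.25, 16
expert_capacity = min(group_size,
                      int(group_size * capacity_factor / num_experts))
# expert_capacity == 80: 16 * 80 = 1280 slots per group for 1024 positions,
# leaving headroom for the top-2 gating's second choices.
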
Example 13
def transformer_moe_layer_v2(inputs, output_dim, hparams, train):
    """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py.
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  There is one set of gating parameters for the first level, and a separate set
  of gating parameters per first-level expert for the second level.
  The number of parameters in the gating network is:
    (input_dim.size * hparams.moe_num_experts[0]) +
    (input_dim.size * hparams.moe_num_experts[0] * hparams.moe_num_experts[1])


  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts[0] * hparams.moe_num_experts[1]
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  a, b: batch size
  l: original sequence length
  m: input depth
  n: output depth
  g, h: number of groups
  s, t: group size
  x, y: number of experts
  c, d: expert capacity

  input: [a0, b1, l, m]
  input: [a0, g1, s, m]
  dispatch_tensor_x: [a0, g1, s, x, c]
  expert_input: [a0, g1, x, c, m]
  alltoall: [a0, g, x1, c, m]
  alltoall: [a0, g, x1, c, m]
  transpose: [x1, a0, g, c, m]
  reshape: [x1, h0, s, m]
  assignment2: [x1, h0, t, y, d]
  expert_input2: [x1, h0, y, d, m]
  alltoall: [x1, h, y0, d, m]
  ...
  reverse of that

  gating params 0: [m, x]
  gating params 1: [x1, m, y]

  expert params:
     [x1, y0, m, hidden]
     [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean

  Returns:
    outputs: a Tensor with shape [a, b, l, n]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
    if insert_outer_batch_dim:
        inputs = mtf.reshape(inputs, [mtf.Dimension("outer_batch", 1)] +
                             inputs.shape.dims)

    assert len(hparams.moe_num_experts) == 2
    a0, b1, l, m = inputs.shape.dims
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
    y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
    x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
    y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
    n = output_dim

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (g.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        b1.size * l.size, hparams.moe_group_size,
        _tensor_dim_to_mesh_dim_size(hparams, b1))
    g1 = mtf.Dimension(b1.name, num_groups)
    g = mtf.Dimension(b1.name + "_unsplit", g1.size)
    s = mtf.Dimension("group_size_x", group_size)

    # Each group sends at most expert_capacity positions to each expert.
    # Static expert_capacity dimension is needed for expert batch sizes
    capacity_factor = (hparams.moe_capacity_factor_train
                       if train else hparams.moe_capacity_factor_eval)
    expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
    c = mtf.Dimension("expert_capacity_x", expert_capacity)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (h.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        a0.size * g.size * c.size, hparams.moe_group_size,
        _tensor_dim_to_mesh_dim_size(hparams, a0))
    t = mtf.Dimension("group_size_y", group_size)
    h0 = mtf.Dimension(a0.name, num_groups)
    h = mtf.Dimension(a0.name + "_unsplit", h0.size)

    expert_capacity = min(
        t.size,
        int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
    d = mtf.Dimension("expert_capacity_y", expert_capacity)

    # First level of expert routing
    # Reshape the inner batch size to a multiple of group_dim g1 and
    # group_size_dim s.
    inputs = mtf.reshape(inputs, [a0, g1, s, m])

    # Get the assignments for the first level.
    # dispatch_tensor_x has shape [a0, g1, s, x, c]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=x,
            expert_capacity_dim=c,
            hparams=hparams,
            train=train)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x],
                                 [x, a0, g1, c, m])

    # we construct an "importance" Tensor for the inputs to the second-level
    # gating.  The importance of an input is 1.0 if it represents the
    # first-choice expert-group and 0.5 if it represents the second-choice expert
    # group.  This is used by the second-level gating.
    importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
    importance = 0.5 * (mtf.to_float(mtf.greater(importance, 0.5)) +
                        mtf.to_float(mtf.greater(importance, 0.0)))

    # First level, all to all. Here we change the split dimension from g1 to x1.
    expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape([x1, a0, g, c,
                                                              m]))
    importance = mtf.reshape(importance, [x1, a0, g, c])

    # Second level of expert routing
    # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
    # and group_size_dim t.
    inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
    importance = mtf.reshape(importance, [x1, h0, t])

    # Get the assignments for the second level.
    # dispatch_tensor_y has shape [x1, h0, t, y, d]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
            inputs=inputs_y,
            outer_expert_dims=[x1],
            experts_dim=y,
            expert_capacity_dim=d,
            hparams=hparams,
            train=train,
            importance=importance)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                                 [y, x1, h0, d, m])

    # Second level, all to all. Here we change the split dimension from h0 to y0.
    expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape([y0, x1, h, d,
                                                              m]))

    # Now feed the expert inputs through the experts.
    hidden_output = mtf_layers.dense(expert_inputs_y,
                                     hidden_dim,
                                     expert_dims=[y0, x1],
                                     activation=mtf.relu,
                                     use_bias=False,
                                     name="expert0")
    expert_output = mtf_layers.dense(hidden_output,
                                     output_dim,
                                     expert_dims=[y0, x1],
                                     use_bias=False,
                                     name="expert1")

    # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
    # expert_output has shape [y0, x1, h, d, n]

    # alltoall
    expert_output = mtf.reshape(expert_output, mtf.Shape([y, x1, h0, d, n]))

    # combine results from inner level
    output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

    # Reshape the combined tensor from inner level to now contain outer_batch_dim
    # a0 and group_dim g
    output = mtf.reshape(output_y, [x1, a0, g, c, n])

    # alltoall from expert_dim x to group_dim g1
    expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

    # combine results from outer level
    output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

    # Reshape the combined tensor to now contain inner_batch_dim
    # b1 and the original sequence length
    output = mtf.reshape(output_x, [a0, b1, l, n])
    if insert_outer_batch_dim:
        output = mtf.reshape(output, [b1, l, n])
    return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
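
A worked example of the two-level sizes, following the cheat sheet above (hypothetical values; the capacity factors 1.25 and 1.0 stand in for the hparams, and in the code _split_into_groups may adjust the group counts to respect the mesh layout):

a0, b1, l, m = 4, 8, 1024, 512        # outer batch, inner batch, length, depth
x, y = 8, 4                           # hparams.moe_num_experts = [8, 4]
group_size = 1024                     # hparams.moe_group_size

s = group_size                        # group_size_x
g = (b1 * l) // s                     # first-level groups: 8
c = min(s, int(s * 1.25 / x))         # expert_capacity_x: 160

t = group_size                        # group_size_y
h = (a0 * g * c) // t                 # second-level groups: 5
d = min(t, int(t * 1.0 / y))          # expert_capacity_y: 256
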