Example #1
def attention_bias_local_2d_block(mesh,
                                  h_dim,
                                  w_dim,
                                  memory_h_dim,
                                  memory_w_dim,
                                  dtype=tf.int32):
    """Bias for attention for local blocks where attention to right is disallowed.

    Create the bias matrix using two separate masks: one for the memory part,
    which doesn't overlap with the query, and a second one which interacts with
    the query and must not look to the right of the current query position.

    Args:
      mesh: a MeshTensorflow object
      h_dim: a mtf.Dimension
      w_dim: a mtf.Dimension
      memory_h_dim: a mtf.Dimension
      memory_w_dim: a mtf.Dimension
      dtype: a tf.dtype

    Returns:
      a mtf.Tensor bias mask (large negative values where attention to the
      right is disallowed) to be added to the attention logits.
    """
    memory_height = mtf.Dimension(memory_h_dim.name, h_dim.size)
    memory_width = mtf.Dimension(memory_w_dim.name, w_dim.size)
    mask_top_visible = mtf.zeros(mesh, [h_dim, memory_height], dtype=dtype)
    mask_left_visible = mtf.zeros(mesh, [w_dim, memory_width], dtype=dtype)
    mask_query = mtf.greater(mtf.range(mesh, memory_height, dtype=tf.int32),
                             mtf.range(mesh, memory_width, dtype=dtype))
    width_mask = mtf.concat([mask_left_visible, mask_query], memory_width.name)
    mask = mtf.cast(mtf.concat([mask_top_visible, width_mask],
                               memory_height.name),
                    dtype=tf.float32) * -1e9
    return mask
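The bias recipe described in the docstring is easier to see in plain NumPy. The sketch below is an illustration only, not part of the library: it builds the one-dimensional analogue of the mask, with an all-zero (fully visible) block for the memory part and a strictly-upper-triangular block for the positions to the right of each query, scaled by -1e9; the 2-D function above concatenates masks of the same two kinds along both the height and the width axes.

import numpy as np

block = 4  # hypothetical block size, for illustration only

# Memory half: every position is visible, so the bias is 0 everywhere.
memory_part = np.zeros((block, block))

# Query half: key j is masked for query i whenever j lies to the right of i.
query_part = (np.arange(block)[:, None] < np.arange(block)[None, :]).astype(float)

# Concatenate the two halves and scale, exactly as the mtf code does.
bias = np.concatenate([memory_part, query_part], axis=1) * -1e9
# bias[i, j] == -1e9 only where query i would peek at a future key.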
Example #2
def attention_bias_local_block(mesh,
                               block_length,
                               memory_length,
                               dtype=tf.int32):
    """Bias for attention for local blocks where attention to right is disallowed.

    Create the bias matrix using two separate masks: one for the memory part,
    which doesn't overlap with the query, and a second one which interacts with
    the query and must not look to the right of the current query position.

    Args:
      mesh: a MeshTensorflow object
      block_length: a mtf.Dimension
      memory_length: a mtf.Dimension
      dtype: a tf.dtype

    Returns:
      a mtf.Tensor with shape [block_length, memory_length]
    """
    memory_length = mtf.Dimension(memory_length.name, block_length.size)
    memory_mask = mtf.zeros(mesh, [block_length, memory_length], dtype=dtype)

    mask = mtf.cast(mtf.less(mtf.range(mesh, block_length, dtype=dtype),
                             mtf.range(mesh, memory_length, dtype=dtype)),
                    dtype=dtype)
    mask = mtf.cast(mtf.concat([memory_mask, mask], memory_length.name),
                    dtype=tf.float32) * -1e9
    return mask
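A minimal usage sketch for this 1-D bias. The imports, mesh name, and dimension sizes below are assumptions chosen for illustration; the function itself is the one defined in the example above.

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")           # hypothetical mesh name
block_length = mtf.Dimension("block_length", 4)  # hypothetical sizes
memory_length = mtf.Dimension("mem_block_length", 4)

bias = attention_bias_local_block(mesh, block_length, memory_length)
# bias has shape [block_length=4, mem_block_length=8]: zeros over the
# previous (memory) block, and -1e9 wherever a query would look to the
# right of its own position inside the current block.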
Example #3
def _my_concat(a, b):
    a = mtf.rename_dimension(a, "beam", "triple_beam")
    b = mtf.rename_dimension(b, "double_beam", "triple_beam")
    return mtf.concat([a, b], "triple_beam")
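mtf.concat only joins tensors along a dimension they share by name, which is why both inputs are renamed to "triple_beam" before the call. A small sketch with hypothetical dimension sizes (assuming import mesh_tensorflow as mtf):

import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")        # hypothetical mesh name
batch = mtf.Dimension("batch", 1)
beam = mtf.Dimension("beam", 2)               # hypothetical beam sizes
double_beam = mtf.Dimension("double_beam", 4)

a = mtf.zeros(mesh, [batch, beam])
b = mtf.zeros(mesh, [batch, double_beam])

c = _my_concat(a, b)
# c has shape [batch=1, triple_beam=6]: the renamed "beam" and
# "double_beam" axes are stacked along the shared "triple_beam" name.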
Example #4
def local(x):
    """Helper function to get memory blocks."""
    prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
    cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
    local_block = mtf.concat([prev_block, cur_block], mlength.name)
    return local_block
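The slicing and concatenation above pairs every block with the block before it. A plain-NumPy sketch of the same pairing, with toy sizes for illustration only:

import numpy as np

x = np.arange(12).reshape(4, 3)   # 4 blocks of length 3: [num_blocks, length]

prev_block = x[:-1]               # blocks 0, 1, 2
cur_block = x[1:]                 # blocks 1, 2, 3
local_block = np.concatenate([prev_block, cur_block], axis=1)
# Row i of local_block holds block i followed by block i+1, so each query
# block (other than the first) attends over its own keys plus the keys of
# the block immediately before it.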
Example #5
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
    """Attention to the source position and a neighborhood to the left of it.

    The sequence is divided into blocks of length block_length.
    Attention for a given query position can only see memory positions
    less than or equal to the query position, in the corresponding block
    and the previous block.

    Args:
      query_antecedent: a mtf.Tensor with shape
        [batch, query_length, io_channels]
      memory_antecedent: a mtf.Tensor with shape
        [batch, memory_length, io_channels] (optional). Currently, memory_length
        must have the same size as query_length, but a different name.
      kv_channels: a mtf.Dimension (the size of the key and value vectors)
      heads: a mtf.Dimension (the number of heads)
      block_length: an integer, the length of each attention block
        (i.e. the receptive field of the attention).
      name: an optional string.

    Returns:
      a mtf.Tensor of shape [batch, query_length, io_channels]

    Raises:
      ValueError: if channels or depth don't match.
    """
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        batch, query_length, io_channels = query_antecedent.shape.dims
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)

        if memory_antecedent is None:
            memory_antecedent = rename_length_to_memory_length(
                query_antecedent, query_length.name)
        memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
        if memory_batch != batch:
            raise ValueError("memory batch must equal query batch")
        if memory_channels != io_channels:
            raise ValueError("memory channels must equal query channels")

        # Get query q, keys k and values v.
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.Shape([batch, heads, query_length, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.Shape([batch, heads, memory_length, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.Shape([batch, heads, memory_length, kv_channels]))

        # Let's assume for now we don't have padding and the block length equally
        # divides the memory length.
        block_length = (query_length.size if
                        query_length.size < block_length * 2 else block_length)
        blength = mtf.Dimension("block_length", block_length)
        mlength = mtf.Dimension("mem_block_length", block_length)
        num_blocks = mtf.Dimension("num_blocks",
                                   query_length.size // block_length)

        q = mtf.reshape(
            q, mtf.Shape([batch, heads, num_blocks, blength, kv_channels]))
        k = mtf.reshape(
            k, mtf.Shape([batch, heads, num_blocks, mlength, kv_channels]))
        v = mtf.reshape(
            v, mtf.Shape([batch, heads, num_blocks, mlength, kv_channels]))

        # compute attention for the first query block.
        def first_block_attention():
            """Compute attention for the first block."""
            first_q = mtf.slice(q, 0, 1, num_blocks.name)
            first_k = mtf.slice(k, 0, 1, num_blocks.name)
            first_v = mtf.slice(v, 0, 1, num_blocks.name)
            first_output = dot_product_attention(first_q,
                                                 first_k,
                                                 first_v,
                                                 mask=None)
            return first_output

        # Attention for first block, since query_length = key_length.
        first_output = first_block_attention()

        # Concatenate two adjacent blocks to compute the overlapping memory block.
        def local(x):
            """Helper function to get memory blocks."""
            prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
            cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
            local_block = mtf.concat([prev_block, cur_block], mlength.name)
            return local_block

        local_k = local(k)
        local_v = local(v)
        # Calculate the causal mask to avoid peeking into the future. We compute
        # this once and reuse it for all blocks since the block_length is known.
        mlength = local_k.shape.dims[3]
        mask = attention_bias_local_block(query_antecedent.mesh, blength,
                                          mlength)

        # Remove the first block from q since we already computed that.
        tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)

        tail_output = dot_product_attention(tail_q,
                                            local_k,
                                            local_v,
                                            mask=mask)

        # Now concatenate the first and rest of the blocks.
        final_output = mtf.concat([first_output, tail_output], num_blocks.name)
        final_output = mtf.reshape(
            final_output, mtf.Shape([batch, heads, query_length, kv_channels]))
        return mtf.einsum([final_output, o_var],
                          mtf.Shape([batch, query_length, io_channels]))
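A minimal end-to-end sketch of calling this layer. The dimension names and sizes below are hypothetical; it assumes import mesh_tensorflow as mtf, import tensorflow.compat.v1 as tf, and that the companion helpers used above (multihead_attention_vars, dot_product_attention, rename_length_to_memory_length, attention_bias_local_block) are in scope from the same module.

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")              # hypothetical mesh name
batch = mtf.Dimension("batch", 2)                   # hypothetical sizes
length = mtf.Dimension("length", 256)
io_channels = mtf.Dimension("io_channels", 32)
kv_channels = mtf.Dimension("kv_channels", 16)
heads = mtf.Dimension("heads", 4)

x = mtf.zeros(mesh, [batch, length, io_channels])   # stands in for real inputs
y = masked_local_attention_1d(x, None, kv_channels, heads, block_length=64)
# y has shape [batch, length, io_channels]; each position attends to itself,
# to earlier positions in its 64-long block, and to the previous block.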