def first_block_attention():
    """Run attention restricted to the leading block of q, k and v.

    Takes one block starting at offset 0 along the dimension named
    ``num_blocks.name`` from each of the enclosing-scope tensors ``q``,
    ``k`` and ``v``, then hands them to ``dot_product_attention`` with
    ``mask=None``.

    Returns:
      Whatever ``dot_product_attention`` returns for the sliced tensors.
    """
    block_axis = num_blocks.name
    # Slices are created in q, k, v order, one block each.
    head_q, head_k, head_v = [
        mtf.slice(tensor, 0, 1, block_axis) for tensor in (q, k, v)
    ]
    return dot_product_attention(head_q, head_k, head_v, mask=None)
def first_block_attention():
    """Attend within the leading block of q, k and v.

    Slices one block starting at offset 0 along the dimension named
    ``num_blocks.name`` from the enclosing-scope tensors ``q``, ``k`` and
    ``v``, computes logits with an einsum, normalizes them with a softmax
    over ``mlength`` and uses the result to combine the values. No mask or
    bias is added to the logits in this helper.

    Returns:
      A tensor of shape [batch, heads, <sliced block dim>, blength,
      kv_channels].
    """
    block_axis = num_blocks.name
    head_q = mtf.slice(q, 0, 1, block_axis)
    head_k = mtf.slice(k, 0, 1, block_axis)
    head_v = mtf.slice(v, 0, 1, block_axis)

    # Dimension left behind by the slice; reused in both output shapes.
    single_block = head_q.shape.dims[2]

    def _shape_ending_in(last_dim):
        return mtf.TensorShape(
            [batch, heads, single_block, blength, last_dim])

    scores = mtf.einsum([head_q, head_k], _shape_ending_in(mlength))
    probs = mtf.softmax(scores, mlength)
    return mtf.einsum([probs, head_v], _shape_ending_in(kv_channels))
def local(x):
    """Build the memory blocks by joining each block with its predecessor.

    Two slices of ``x`` are taken along the dimension named
    ``num_blocks.name``, each ``num_blocks.size - 1`` long: one starting at
    offset 0 and one starting at offset 1. They are concatenated along the
    dimension named ``mlength.name``, the offset-0 slice first.

    Args:
      x: tensor to slice; it must have dimensions named ``num_blocks.name``
        and ``mlength.name``.

    Returns:
      The concatenation of the two slices.
    """
    span = num_blocks.size - 1
    block_axis = num_blocks.name
    # Offset 0 gives the "previous" blocks, offset 1 the "current" ones.
    shifted = [mtf.slice(x, offset, span, block_axis) for offset in (0, 1)]
    return mtf.concat(shifted, mlength.name)
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
    """Attention to the source position and a neighborhood to the left of it.

    The sequence is divided into blocks of length block_length. Attention for
    a given query position can only see memory positions less than or equal to
    the query position, in the corresponding block and the previous block.

    If the query length is smaller than twice block_length, the whole sequence
    is treated as a single block.

    Args:
      query_antecedent: a mtf.Tensor with shape
        [batch, query_length, io_channels]
      memory_antecedent: a mtf.Tensor with shape
        [batch, memory_length, io_channels], or None, in which case it is
        derived from query_antecedent with rename_length_to_memory_length.
        Currently, memory_length must have the same size as query_length, but
        a different name.
      kv_channels: a mtf.Dimension (the size of the key and value vectors)
      heads: a mtf.Dimension (the number of heads)
      block_length: a positive integer, representing receptive fields for
        attention.
      name: an optional string.

    Returns:
      a Tensor of shape [batch, query_length, io_channels]

    Raises:
      ValueError: if block_length is not positive, if the memory batch or
        channels dimension differs from the query's, if memory_length and
        query_length differ in size, or if the effective block length does
        not evenly divide the query length.
    """
    if block_length <= 0:
        raise ValueError(
            "block_length must be positive, got %s" % (block_length,))

    with tf.variable_scope(name, default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):
        batch, query_length, io_channels = query_antecedent.shape.dims
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)

        if memory_antecedent is None:
            memory_antecedent = rename_length_to_memory_length(
                query_antecedent, query_length.name)
        memory_batch, memory_length, memory_channels = (
            memory_antecedent.shape.dims)
        if memory_batch != batch:
            raise ValueError("memory batch must equal query batch")
        if memory_channels != io_channels:
            raise ValueError("memory channels must equal query channels")
        # k and v are reshaped below with a block count derived from the
        # query length, so the two lengths have to agree in size.
        if memory_length.size != query_length.size:
            raise ValueError(
                "memory length size (%d) must equal query length size (%d)"
                % (memory_length.size, query_length.size))

        # Get query q, keys k and values v.
        q = mtf.einsum(
            [query_antecedent, q_var],
            mtf.TensorShape([batch, heads, query_length, kv_channels]))
        k = mtf.einsum(
            [memory_antecedent, k_var],
            mtf.TensorShape([batch, heads, memory_length, kv_channels]))
        v = mtf.einsum(
            [memory_antecedent, v_var],
            mtf.TensorShape([batch, heads, memory_length, kv_channels]))

        # Sequences shorter than two blocks are handled as one single block.
        block_length = (query_length.size
                        if query_length.size < block_length * 2
                        else block_length)
        # No padding is applied, so the block length has to divide the
        # sequence length exactly; fail here rather than in the reshapes.
        if query_length.size % block_length != 0:
            raise ValueError(
                "query length size (%d) must be divisible by the block "
                "length (%d)" % (query_length.size, block_length))

        blength = mtf.Dimension("block_length", block_length)
        mlength = mtf.Dimension("mem_block_length", block_length)
        num_blocks = mtf.Dimension("num_blocks",
                                   query_length.size // block_length)

        q = mtf.reshape(
            q,
            mtf.TensorShape([batch, heads, num_blocks, blength, kv_channels]))
        k = mtf.reshape(
            k,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))
        v = mtf.reshape(
            v,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))

        # compute attention for the first query block.
        def first_block_attention():
            """Compute attention for the first block."""
            first_q = mtf.slice(q, 0, 1, num_blocks.name)
            first_k = mtf.slice(k, 0, 1, num_blocks.name)
            first_v = mtf.slice(v, 0, 1, num_blocks.name)
            block = first_q.shape.dims[2]

            # NOTE(review): no causal bias is added to these logits, unlike
            # the logits of the remaining blocks further down. Confirm this
            # is intended given the masking promised in the docstring.
            first_logits = mtf.einsum(
                [first_q, first_k],
                mtf.TensorShape([batch, heads, block, blength, mlength]))
            weights = mtf.softmax(first_logits, mlength)
            first_output = mtf.einsum(
                [weights, first_v],
                mtf.TensorShape(
                    [batch, heads, block, blength, kv_channels]))
            return first_output

        # Attention for first block, since query_length = key_length.
        first_output = first_block_attention()

        # Concatenate two adjacent blocks to compute the overlapping memory
        # block.
        def local(x):
            """Helper function to get memory blocks."""
            prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
            cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
            local_block = mtf.concat([prev_block, cur_block], mlength.name)
            return local_block

        local_k = local(k)
        local_v = local(v)
        # From here on mlength refers to the concatenated (previous +
        # current) memory dimension, not the per-block one defined above.
        mblocks = local_k.shape.dims[2]
        mlength = local_k.shape.dims[3]

        # Calculate the causal mask to avoid peeking into the future. We
        # compute this once and reuse it for all blocks since the block
        # length is known.
        mask = attention_bias_local_block(
            query_antecedent.mesh, blength, mlength)

        # Remove the first block from q since we already computed that.
        tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)

        # Compatibility between q and k for rest of the blocks.
        # Shape [batch, heads, num_blocks - 1, block_length, local_length]
        attention = mtf.einsum(
            [tail_q, local_k],
            mtf.TensorShape([batch, heads, mblocks, blength, mlength]))
        attention += mask
        attention = mtf.softmax(attention, mlength)

        # Run attention for rest of the blocks.
        # Shape [batch, heads, num_blocks-1, block_length, kv_channels]
        output = mtf.einsum(
            [attention, local_v],
            mtf.TensorShape([batch, heads, mblocks, blength, kv_channels]))

        # Now concatenate the first and rest of the blocks.
        final_output = mtf.concat([first_output, output], num_blocks.name)
        final_output = mtf.reshape(
            final_output,
            mtf.TensorShape([batch, heads, query_length, kv_channels]))
        return mtf.einsum(
            [final_output, o_var],
            mtf.TensorShape([batch, query_length, io_channels]))