def attention_bias_local_2d_block(mesh,
                                  h_dim,
                                  w_dim,
                                  memory_h_dim,
                                  memory_w_dim,
                                  dtype=tf.int32):
  """Bias for attention for local blocks where attention to right is disallowed.

  Create the bias matrix by using two separate masks: one for the memory part
  which doesn't overlap with the query, and a second which interacts with the
  query and is disallowed from looking to the right of the current query
  position.

  Args:
    mesh: a MeshTensorflow object
    h_dim: a mtf.Dimension
    w_dim: a mtf.Dimension
    memory_h_dim: a mtf.Dimension
    memory_w_dim: a mtf.Dimension
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [block_length, memory_length]
  """
  memory_height = mtf.Dimension(memory_h_dim.name, h_dim.size)
  memory_width = mtf.Dimension(memory_w_dim.name, w_dim.size)
  mask_top_visible = mtf.zeros(mesh, [h_dim, memory_height], dtype=dtype)
  mask_left_visible = mtf.zeros(mesh, [w_dim, memory_width], dtype=dtype)
  mask_query = mtf.greater(
      mtf.range(mesh, memory_height, dtype=tf.int32),
      mtf.range(mesh, memory_width, dtype=dtype))
  width_mask = mtf.concat([mask_left_visible, mask_query], memory_width.name)
  mask = mtf.cast(
      mtf.concat([mask_top_visible, width_mask], memory_height.name),
      dtype=tf.float32) * -1e9
  return mask
def attention_bias_local_block(mesh, block_length, memory_length,
                               dtype=tf.int32):
  """Bias for attention for local blocks where attention to right is disallowed.

  Create the bias matrix by using two separate masks: one for the memory part
  which doesn't overlap with the query, and a second which interacts with the
  query and is disallowed from looking to the right of the current query
  position.

  Args:
    mesh: a MeshTensorflow object
    block_length: a mtf.Dimension
    memory_length: a mtf.Dimension
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [block_length, memory_length]
  """
  memory_length = mtf.Dimension(memory_length.name, block_length.size)
  memory_mask = mtf.zeros(mesh, [block_length, memory_length], dtype=dtype)
  mask = mtf.cast(
      mtf.less(mtf.range(mesh, block_length, dtype=dtype),
               mtf.range(mesh, memory_length, dtype=dtype)),
      dtype=dtype)
  mask = mtf.cast(
      mtf.concat([memory_mask, mask], memory_length.name),
      dtype=tf.float32) * -1e9
  return mask
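# --- Illustrative sketch (not part of the original module) -----------------
# A minimal example of how attention_bias_local_block might be used on its
# own: a block of 4 query positions attends over the previous block plus the
# current block, and positions to the right of the query inside the current
# block receive a -1e9 bias. The mesh name and dimension sizes below are
# arbitrary choices for the sketch, not values from the original code.
def _example_attention_bias_local_block():
  import mesh_tensorflow as mtf  # assumed to match the module-level import
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "bias_example_mesh")
  block_length = mtf.Dimension("block_length", 4)
  # Only the *name* of memory_length matters; its size is rebuilt from
  # block_length inside attention_bias_local_block. The size 8 mirrors the
  # concatenated (previous + current) memory block used by the caller.
  memory_length = mtf.Dimension("mem_block_length", 8)
  bias = attention_bias_local_block(mesh, block_length, memory_length)
  return bias  # mtf.Tensor of shape [block_length=4, mem_block_length=8]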
def _my_concat(a, b):
  """Concatenate a "beam" and a "double_beam" tensor along "triple_beam"."""
  a = mtf.rename_dimension(a, "beam", "triple_beam")
  b = mtf.rename_dimension(b, "double_beam", "triple_beam")
  return mtf.concat([a, b], "triple_beam")
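# --- Illustrative sketch (not part of the original module) -----------------
# _my_concat glues a beam-sized tensor onto a double_beam-sized tensor by
# renaming both dimensions to a shared "triple_beam" name and concatenating
# along it. The batch/beam sizes below are arbitrary choices for the sketch.
def _example_my_concat():
  import mesh_tensorflow as mtf  # assumed to match the module-level import
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "concat_example_mesh")
  batch = mtf.Dimension("batch", 2)
  beam = mtf.Dimension("beam", 4)
  double_beam = mtf.Dimension("double_beam", 8)
  a = mtf.zeros(mesh, mtf.Shape([batch, beam]))
  b = mtf.zeros(mesh, mtf.Shape([batch, double_beam]))
  # Result has a "triple_beam" dimension of size 4 + 8 = 12.
  return _my_concat(a, b)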
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
  """Attention to the source position and a neighborhood to the left of it.

  The sequence is divided into blocks of length block_size.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional). Currently, memory_length
      must have the same size as query_length, but a different name.
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    block_length: an integer, representing receptive fields for attention.
    name: an optional string.

  Returns:
    a Tensor of shape [batch, query_length, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
  with tf.variable_scope(
      name, default_name="multihead_attention",
      values=[query_antecedent, memory_antecedent]):

    batch, query_length, io_channels = query_antecedent.shape.dims
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)

    if memory_antecedent is None:
      memory_antecedent = rename_length_to_memory_length(
          query_antecedent, query_length.name)
    memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
    if memory_batch != batch:
      raise ValueError("memory batch must equal query batch")
    if memory_channels != io_channels:
      raise ValueError("memory channels must equal query channels")

    # Get query q, keys k and values v.
    q = mtf.einsum(
        [query_antecedent, q_var],
        mtf.Shape([batch, heads, query_length, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, k_var],
        mtf.Shape([batch, heads, memory_length, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, v_var],
        mtf.Shape([batch, heads, memory_length, kv_channels]))

    # Let's assume for now we don't have padding and the block length equally
    # divides the memory length.
    block_length = (query_length.size
                    if query_length.size < block_length * 2 else block_length)
    blength = mtf.Dimension("block_length", block_length)
    mlength = mtf.Dimension("mem_block_length", block_length)
    num_blocks = mtf.Dimension("num_blocks", query_length.size // block_length)

    q = mtf.reshape(
        q, mtf.Shape([batch, heads, num_blocks, blength, kv_channels]))
    k = mtf.reshape(
        k, mtf.Shape([batch, heads, num_blocks, mlength, kv_channels]))
    v = mtf.reshape(
        v, mtf.Shape([batch, heads, num_blocks, mlength, kv_channels]))

    # Compute attention for the first query block.
    def first_block_attention():
      """Compute attention for the first block."""
      first_q = mtf.slice(q, 0, 1, num_blocks.name)
      first_k = mtf.slice(k, 0, 1, num_blocks.name)
      first_v = mtf.slice(v, 0, 1, num_blocks.name)
      first_output = dot_product_attention(first_q, first_k, first_v,
                                           mask=None)
      return first_output

    # Attention for first block, since query_length = key_length.
    first_output = first_block_attention()

    # Concatenate two adjacent blocks to compute the overlapping memory block.
    def local(x):
      """Helper function to get memory blocks."""
      prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
      cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
      local_block = mtf.concat([prev_block, cur_block], mlength.name)
      return local_block

    local_k = local(k)
    local_v = local(v)

    # Calculate the causal mask to avoid peeking into the future. We compute
    # this once and reuse it for all blocks since the block_size is known.
    mlength = local_k.shape.dims[3]
    mask = attention_bias_local_block(query_antecedent.mesh, blength, mlength)

    # Remove the first block from q since we already computed that.
    tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)
    tail_output = dot_product_attention(tail_q, local_k, local_v, mask=mask)

    # Now concatenate the first and rest of the blocks.
    final_output = mtf.concat([first_output, tail_output], num_blocks.name)
    final_output = mtf.reshape(
        final_output, mtf.Shape([batch, heads, query_length, kv_channels]))
    return mtf.einsum([final_output, o_var],
                      mtf.Shape([batch, query_length, io_channels]))
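# --- Illustrative sketch (not part of the original module) -----------------
# End-to-end wiring of masked_local_attention_1d on a dummy activation
# tensor. This only builds the mtf graph; lowering and execution are omitted.
# All dimension names and sizes are assumptions chosen so that the default
# block_length of 128 evenly divides the query length.
def _example_masked_local_attention_1d():
  import mesh_tensorflow as mtf  # assumed to match the module-level import
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "attention_example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 256)
  io_channels = mtf.Dimension("io_channels", 64)
  kv_channels = mtf.Dimension("kv_channels", 32)
  heads = mtf.Dimension("heads", 4)
  x = mtf.zeros(mesh, mtf.Shape([batch, length, io_channels]))
  # With memory_antecedent=None the layer attends over (a renamed copy of)
  # the query itself, restricted to the current and previous block.
  return masked_local_attention_1d(
      x, None, kv_channels=kv_channels, heads=heads, block_length=128)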