def multihead_attention(self, memory):
  seq_len = common_layers.shape_list(memory)[1]
  q = tf.layers.dense(memory, self._num_units, name="query")
  k = tf.layers.dense(memory, self._num_units, name="key")
  v = tf.layers.dense(memory, self._num_units, name="value")
  bias = None
  # bias = common_attention.attention_bias_lower_triangle(seq_len)
  q = common_attention.split_heads(
      q, self._num_heads)  # [batch_size, heads, q_len, hidden_size/heads]
  k = common_attention.split_heads(k, self._num_heads)
  v = common_attention.split_heads(v, self._num_heads)
  context = common_attention.dot_product_attention(q, k, v, bias)
  memory = common_attention.combine_heads(
      context)  # [batch_size, seq_len, hidden_size]
  return memory
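

# Illustrative NumPy sketch (not the library code) of the shape flow above:
# split into heads, run scaled dot-product attention, then recombine heads.
# The helper names, sizes, and the explicit scaling are assumptions made for
# this example only.
import numpy as np


def _np_split_heads(x, num_heads):
  # [batch, length, depth] -> [batch, heads, length, depth // heads]
  b, l, d = x.shape
  return x.reshape(b, l, num_heads, d // num_heads).transpose(0, 2, 1, 3)


def _np_combine_heads(x):
  # [batch, heads, length, depth_per_head] -> [batch, length, depth]
  b, h, l, dh = x.shape
  return x.transpose(0, 2, 1, 3).reshape(b, l, h * dh)


def _np_dot_product_attention(q, k, v):
  # softmax(q k^T / sqrt(depth_per_head)) v, batched over [batch, heads].
  logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
  weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)
  return weights @ v


memory = np.random.randn(2, 7, 64)  # [batch, seq_len, hidden]
q, k, v = (_np_split_heads(memory, 4) for _ in range(3))  # each [2, 4, 7, 16]
out = _np_combine_heads(_np_dot_product_attention(q, k, v))
assert out.shape == (2, 7, 64)  # back to [batch, seq_len, hidden]
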
def attn(image_feat,
         query,
         hparams,
         name="attn",
         save_weights_to=None,
         make_image_summary=True):
  """Attention on image feature with question as query."""
  with tf.variable_scope(name, "attn", values=[image_feat, query]):
    total_key_depth = hparams.attention_key_channels or hparams.hidden_size
    total_value_depth = hparams.attention_value_channels or hparams.hidden_size
    num_heads = hparams.num_heads
    query = tf.expand_dims(query, 1)
    q, k, v = common_attention.compute_qkv(query, image_feat, total_key_depth,
                                           total_value_depth)
    q = common_attention.split_heads(q, num_heads)
    k = common_attention.split_heads(k, num_heads)
    v = common_attention.split_heads(v, num_heads)

    if hparams.scale_dotproduct:
      key_depth_per_head = total_key_depth // num_heads
      q *= key_depth_per_head**-0.5

    # image_feat is input as v
    x = common_attention.dot_product_attention(
        q, k, v, None,
        dropout_rate=hparams.attention_dropout,
        image_shapes=None,
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary)
    x = common_attention.combine_heads(x)
    return tf.squeeze(x, axis=1)
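

# Hedged NumPy sketch (assumed shapes, single head) of the expand_dims ->
# attend -> squeeze pattern used above: one question vector per example
# attends over the image regions that act as keys/values.
import numpy as np

batch, regions, depth = 2, 49, 64
image_feat = np.random.randn(batch, regions, depth)  # keys/values
question = np.random.randn(batch, depth)             # single query per example

q = question[:, None, :]                              # expand_dims -> [batch, 1, depth]
logits = (q @ image_feat.transpose(0, 2, 1)) * depth**-0.5  # [batch, 1, regions]
weights = np.exp(logits - logits.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
context = (weights @ image_feat).squeeze(1)           # squeeze -> [batch, depth]
assert context.shape == (batch, depth)
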
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        shared_rel=False,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        max_length=None,
                        vars_3d=False,
                        scale_dotproduct=True,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    shared_rel: boolean to share relative embeddings
    max_relative_position: Maximum distance between inputs to generate
      unique relation embeddings for. Only relevant when using
      "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
      "local_mask_right", "local_unmasked", "masked_dilated_1d",
      "unmasked_dilated_1d", graph, or any attention function with the
      signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and
      values to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no
      padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no
      padding.
    cache: dict containing Tensors which are the results of previous
      attentions, used for fast decoding. Expects the dict to contain two
      keys ('k' and 'v'); for the initial call the values for these keys
      should be empty Tensors of the appropriate shape.
      'k' [batch_size, 0, key_channels]
      'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing
      between memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to
      look at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a
      string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      Saves memory.
    max_length: an integer - needed by relative attention
    vars_3d: use 3-dimensional variables for input/output transformations
    scale_dotproduct: whether to normalize the attention product.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of
    the shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is
    [batch_size, 1, hidden_dim].
    Optionally returns an additional loss parameter (ex: load balance loss
    for the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  vars_3d_num_heads = num_heads if vars_3d else 0
  with tf.variable_scope(name, default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):

    if cache is None or memory_antecedent is None:
      q, k, v = common_attention.compute_qkv(
          query_antecedent, memory_antecedent,
          total_key_depth, total_value_depth,
          q_filter_width, kv_filter_width,
          q_padding, kv_padding,
          vars_3d_num_heads=vars_3d_num_heads)
    if cache is not None:
      if attention_type != "dot_product":
        # TODO(petershaw): Support caching when using relative position
        # representations, i.e. "dot_product_relative" attention.
        raise NotImplementedError(
            "Caching is not guaranteed to work with attention types other "
            "than dot_product.")
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        q = common_attention.compute_attention_component(
            query_antecedent, total_key_depth,
            q_filter_width, q_padding, "q",
            vars_3d_num_heads=vars_3d_num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of the current implementation is better than
          # updating the tensor by adding the result of
          # matmul(one_hot, update_in_current_step).
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = common_attention.split_heads(q, num_heads)
    if cache is None:
      k = common_attention.split_heads(k, num_heads)
      v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      if scale_dotproduct:
        q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "dot_product":
      x = common_attention.dot_product_attention(
          q, k, v, bias, dropout_rate, image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "dot_product_relative":
      x = common_attention.dot_product_attention_relative(
          q, k, v, bias, max_relative_position, dropout_rate, image_shapes,
          make_image_summary=make_image_summary)
    elif attention_type == "dot_product_relative_v2":
      x = common_attention.dot_product_self_attention_relative_v2(
          q, k, v, bias, max_length, dropout_rate, image_shapes,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "local_within_block_mask_right":
      x = common_attention.masked_within_block_local_attention_1d(
          q, k, v, block_length=block_length)
    elif attention_type == "rel_local_mask_right":
      x = common_attention.masked_rel_local_attention_1d(
          q, k, v, block_length=block_length,
          make_image_summary=make_image_summary,
          dropout_rate=dropout_rate,
          share_rel_embed=shared_rel)
    elif attention_type == "local_mask_right":
      x = common_attention.masked_local_attention_1d(
          q, k, v, block_length=block_length,
          make_image_summary=make_image_summary)
    elif attention_type == "local_unmasked":
      x = common_attention.local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
    elif attention_type == "masked_dilated_1d":
      x = common_attention.masked_dilated_self_attention_1d(
          q, k, v, block_length, block_width, gap_size, num_memory_blocks)
    else:
      assert attention_type == "unmasked_dilated_1d"
      x = common_attention.dilated_self_attention_1d(
          q, k, v, block_length, block_width, gap_size, num_memory_blocks)

    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.cast(o_var, x.dtype)
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
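

# Hedged NumPy sketch (assumed shapes, not the TF code above) of the decoder
# self-attention cache: keys/values accumulate along the length axis, so only
# the newest query position has to be fed in at each decode step.
import numpy as np

batch, heads, depth_per_head = 1, 2, 4
cache = {"k": np.zeros((batch, heads, 0, depth_per_head)),
         "v": np.zeros((batch, heads, 0, depth_per_head))}


def _decode_step(q_step, k_step, v_step, cache):
  # q_step/k_step/v_step: [batch, heads, 1, depth_per_head] for the new token.
  cache["k"] = np.concatenate([cache["k"], k_step], axis=2)
  cache["v"] = np.concatenate([cache["v"], v_step], axis=2)
  logits = q_step @ cache["k"].transpose(0, 1, 3, 2) * depth_per_head**-0.5
  w = np.exp(logits - logits.max(-1, keepdims=True))
  w /= w.sum(-1, keepdims=True)
  return w @ cache["v"]            # [batch, heads, 1, depth_per_head]


for _ in range(3):                  # three decode steps
  new = np.random.randn(batch, heads, 1, depth_per_head)
  out = _decode_step(new, new, new, cache)
assert cache["k"].shape[2] == 3     # cache now holds all previous positions
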
def multihead_mpnn_attention(node_states,
                             total_key_depth,
                             total_value_depth,
                             output_depth,
                             num_heads,
                             adjacency_matrix=None,
                             num_edge_types=5,
                             ignore_zero=True,
                             name="mpnn_attention"):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    node_states: A tensor of shape [batch, length, depth]
    total_key_depth: An integer for the key dimension
    total_value_depth: An integer for the value dimension
    output_depth: An integer for the output dimension
    num_heads: An integer
    adjacency_matrix: A tensor of ints of shape [batch, length, length]
    num_edge_types: An integer indicating the number of edge bins
    ignore_zero: A flag that says that edge type 0 should be ignored
    name: A string

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, output_depth]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is
    [batch_size, 1, hidden_dim].
    Optionally returns an additional loss parameter (ex: load balance loss
    for the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(name, default_name="multihead_mpnn_attention",
                         values=[node_states]):
    q, k, v = compute_mpnn_qkv(node_states, total_key_depth,
                               total_value_depth, num_edge_types,
                               ignore_zero=ignore_zero)

    # Reshaping k and v for head splitting.
    q_shape = tf.shape(q)
    q = common_attention.split_heads(q, num_heads)
    k = common_attention.split_heads(k, num_heads)
    v = common_attention.split_heads(v, num_heads)
    key_depth_per_head = total_key_depth // num_heads
    q *= key_depth_per_head**-0.5

    # Make the heads dimension leading. We will loop over heads.
    q = tf.transpose(q, [1, 0, 2, 3])
    k = tf.transpose(k, [1, 0, 2, 3])
    v = tf.transpose(v, [1, 0, 2, 3])

    # Put edge type as the dimension after batch for k and v.
    # k and v will be [heads, batch, num_edge_types, length, depth].
    k = tf.reshape(k, [
        num_heads, q_shape[0], q_shape[1], num_edge_types,
        total_key_depth // num_heads
    ])
    k = tf.transpose(k, [0, 1, 3, 2, 4])
    v = tf.reshape(v, [
        num_heads, q_shape[0], q_shape[1], num_edge_types,
        total_value_depth // num_heads
    ])
    v = tf.transpose(v, [0, 1, 3, 2, 4])

    # Do attention separately for each head.
    head_outputs = []
    for head_id in range(num_heads):
      output = dot_product_mpnn_attention(q[head_id], k[head_id], v[head_id],
                                          adjacency_matrix, num_edge_types)
      head_outputs.append(tf.expand_dims(output, axis=0))

    # Make x = [heads, batch, length, total_value_depth // num_heads].
    x = tf.concat(head_outputs, axis=0)
    x = tf.transpose(x, [1, 0, 2, 3])

    # Make x [batch, length, depth].
    x = common_attention.combine_heads(x)
    x = common_layers.dense(
        x, output_depth, use_bias=False, name="output_transform")
    return x
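

# Hedged NumPy sketch (illustrative, single head, single example) of
# per-edge-type attention: the adjacency matrix picks which edge-type key and
# value to use for each (receiver i, sender j) pair, and pairs without an edge
# (type 0 with ignore_zero) are masked out. Sizes are arbitrary assumptions.
import numpy as np

N, T, d = 4, 3, 8                       # nodes, edge types (0 = no edge), depth
q = np.random.randn(N, d)               # one query per node
k = np.random.randn(T, N, d)            # keys per edge type per node
v = np.random.randn(T, N, d)            # values per edge type per node
adj = np.random.randint(0, T, (N, N))   # adj[i, j] = type of edge j -> i

logits = np.full((N, N), -1e9)
for i in range(N):
  for j in range(N):
    if adj[i, j] != 0:                  # edge type 0 is ignored (ignore_zero)
      logits[i, j] = q[i] @ k[adj[i, j], j] * d**-0.5
w = np.exp(logits - logits.max(-1, keepdims=True))
w /= w.sum(-1, keepdims=True)
out = np.stack(
    [sum(w[i, j] * v[adj[i, j], j] for j in range(N)) for i in range(N)])
assert out.shape == (N, d)
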
def multihead_attention_qkv(query_antecedent,
                            key_antecedent,
                            value_antecedent,
                            bias,
                            total_key_depth,
                            total_value_depth,
                            output_depth,
                            num_heads,
                            dropout_rate,
                            max_relative_position=None,
                            image_shapes=None,
                            attention_type="dot_product",
                            block_length=128,
                            block_width=128,
                            q_filter_width=1,
                            kv_filter_width=1,
                            q_padding="VALID",
                            kv_padding="VALID",
                            cache=None,
                            gap_size=0,
                            num_memory_blocks=2,
                            attention_order=1,
                            name=None,
                            **kwargs):
  """Multihead scaled-dot-product attention with separate key/value inputs.

  Takes separate key and value antecedents rather than a single memory input,
  plus input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    key_antecedent: a Tensor with shape [batch, length_m, channels]
    value_antecedent: a Tensor with shape [batch, length_m, channels] or None
    ...
    attention_order (int): For high order attention like
      dot_product_highorder.
    (rest: see common_attention.multihead_attention)
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(
      name,
      default_name="multihead_attention",
      values=[query_antecedent, key_antecedent, value_antecedent]):
    if value_antecedent is None:
      q, k, v = common_attention.compute_qkv(
          query_antecedent, key_antecedent, total_key_depth, total_value_depth,
          q_filter_width, kv_filter_width, q_padding, kv_padding)
    else:
      q, k, v = transform_qkv(query_antecedent, key_antecedent,
                              value_antecedent, total_key_depth,
                              total_value_depth, q_filter_width,
                              kv_filter_width, q_padding, kv_padding)
    if cache is not None:
      if attention_type != "dot_product":
        raise NotImplementedError(
            "Caching is not guaranteed to work with attention types other "
            "than dot_product.")
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")
      k = cache["k"] = tf.concat([cache["k"], k], axis=1)
      v = cache["v"] = tf.concat([cache["v"], v], axis=1)

    q = common_attention.split_heads(q, num_heads)
    k = common_attention.split_heads(k, num_heads)
    v = common_attention.split_heads(v, num_heads)

    if "," in attention_type:
      num_types = attention_type.count(",") + 1
      qs = tf.split(q, num_types, axis=1)
      ks = tf.split(k, num_types, axis=1)
      vs = tf.split(v, num_types, axis=1)
      key_depth_per_head = total_key_depth // num_heads // num_types
    else:
      qs = [q]
      ks = [k]
      vs = [v]
      key_depth_per_head = total_key_depth // num_heads

    additional_returned_value = None
    xs = []
    for q, k, v, att_type in zip(qs, ks, vs, attention_type.split(",")):
      # Scale each query group once, using the per-group key depth.
      q *= key_depth_per_head**-0.5
      if callable(att_type):  # Generic way to extend multihead_attention
        x = att_type(q, k, v, **kwargs)
        if isinstance(x, tuple):
          x, additional_returned_value = x  # Unpack
      elif att_type == "dot_product":
        x = common_attention.dot_product_attention(q, k, v, bias, dropout_rate,
                                                   image_shapes)
      elif att_type == "dot_product_highorder":
        x = dot_product_highorder_attention(
            q, k, v, bias, dropout_rate, image_shapes,
            attention_order=attention_order)
      elif att_type == "dot_product_highorder_shared":
        x = dot_product_highorder_shared_attention(
            q, k, v, bias, dropout_rate, image_shapes,
            attention_order=attention_order)
      elif att_type == "dot_product_relative":
        x = common_attention.dot_product_attention_relative(
            q, k, v, bias, max_relative_position, dropout_rate, image_shapes)
      elif att_type == "local_mask_right":
        x = common_attention.masked_local_attention_1d(
            q, k, v, block_length=block_length)
      elif att_type == "local_unmasked":
        x = common_attention.local_attention_1d(
            q, k, v, block_length=block_length, filter_width=block_width)
      elif att_type == "masked_dilated_1d":
        x = common_attention.masked_dilated_self_attention_1d(
            q, k, v, block_length, block_width, gap_size, num_memory_blocks)
      else:
        assert att_type == "unmasked_dilated_1d"
        x = common_attention.dilated_self_attention_1d(
            q, k, v, block_length, block_width, gap_size, num_memory_blocks)
      xs.append(x)

    x = xs[0] if len(xs) == 1 else tf.concat(xs, axis=1)
    x = common_attention.combine_heads(x)
    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
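

# Hedged NumPy sketch of the comma-separated attention_type dispatch above:
# the heads are split into one group per requested attention type, each group
# runs its own attention, and the groups are concatenated back together. The
# toy attention stands in for whichever per-type function is selected.
import numpy as np


def _toy_attention(q, k, v):
  w = np.exp(q @ k.transpose(0, 1, 3, 2))
  w /= w.sum(-1, keepdims=True)
  return w @ v


attention_type = "dot_product,local_unmasked"    # two types -> two head groups
num_types = attention_type.count(",") + 1
q = k = v = np.random.randn(2, 8, 10, 16)        # [batch, heads=8, length, depth/head]
qs = np.split(q, num_types, axis=1)              # two groups of 4 heads each
ks = np.split(k, num_types, axis=1)
vs = np.split(v, num_types, axis=1)
xs = [_toy_attention(qi, ki, vi) for qi, ki, vi in zip(qs, ks, vs)]
x = np.concatenate(xs, axis=1)                   # back to [2, 8, 10, 16]
assert x.shape == q.shape
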
def multihead_mpnn_attention(node_states,
                             total_key_depth,
                             total_value_depth,
                             output_depth,
                             num_heads,
                             adjacency_matrix=None,
                             num_edge_types=5,
                             num_transforms=None,
                             use_weighted_sum=False,
                             name="mpnn_attention"):
  """Multihead scaled-dot-product attention with input/output transformations.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let O be the size of the attention output (output_depth).
  Let H be the number of heads (num_heads).
  Let T be the total number of transforms (num_transforms).

  The key and value depths are split across all of the heads. For example, if
  the key depth is 6 and there are three heads, then the key for each head has
  depth 2.

  Args:
    node_states: A Tensor with shape [B, N, D].
    total_key_depth: An integer (K).
    total_value_depth: An integer (V).
    output_depth: An integer (O).
    num_heads: An integer (H).
    adjacency_matrix: A Tensor of ints with shape [B, T, N, N]. If there is an
      edge from node j to node i in batch b, then adjacency_matrix[b, i, j]
      contains the type of that edge as an integer. Otherwise, it contains 0.
    num_edge_types: An integer indicating the number of edge types.
    num_transforms: An integer indicating the number of transforms (T). If
      None, then num_transforms will be equal to num_edge_types.
    use_weighted_sum: If False, will only use a single transform per edge
      type. Otherwise, use a learned weighted sum of transforms per edge type.
    name: A string.

  Returns:
    The result of the attention transformation. The output shape is [B, N, O].

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(name, default_name="multihead_mpnn_attention",
                         values=[node_states]):
    # If not explicitly set, use num_transforms set to num_edge_types.
    num_transforms = (
        num_edge_types if num_transforms is None else num_transforms)

    # Create the query for each node's incoming edges.
    # Create the keys/values for each node for each possible outgoing edge
    # type.
    q, k, v = compute_mpnn_qkv(node_states, total_key_depth,
                               total_value_depth, num_transforms)

    q_shape = tf.shape(q)  # As above, q_shape is [B, N, K].

    # Divide each query/key/value into separate heads. Specifically, the
    # query/key/value for each (batch, node) pair (i.e., the third dimension
    # of q, k, and v) is broken into H separate pieces. These pieces are used
    # as the separate attention heads. The resulting tensors have shape
    # [B, H, N, ?/H], where ? = K, K*T or V*T as appropriate.
    q = common_attention.split_heads(q, num_heads)  # Shape [B, H, N, K/H].
    k = common_attention.split_heads(k, num_heads)  # Shape [B, H, N, K*T/H].
    v = common_attention.split_heads(v, num_heads)  # Shape [B, H, N, V*T/H].
    key_depth_per_head = total_key_depth // num_heads

    # Ensures that the logits don't have too large of a magnitude.
    q *= key_depth_per_head**-0.5

    # Rearrange the dimensions so that the head is first. This will make
    # subsequent steps easier (we loop over the head).
    q = tf.transpose(q, [1, 0, 2, 3])  # Shape [H, B, N, K/H].
    k = tf.transpose(k, [1, 0, 2, 3])  # Shape [H, B, N, K*T/H].
    v = tf.transpose(v, [1, 0, 2, 3])  # Shape [H, B, N, V*T/H].

    # Split the keys and values into separate per-edge-type keys and values.
    k = tf.reshape(k, [
        num_heads, q_shape[0], q_shape[1], num_transforms,
        total_key_depth // num_heads
    ])  # Shape [H, B, N, T, K/H].
    k = tf.transpose(k, [0, 1, 3, 2, 4])  # Shape [H, B, T, N, K/H].
    v = tf.reshape(v, [
        num_heads, q_shape[0], q_shape[1], num_transforms,
        total_value_depth // num_heads
    ])  # Shape [H, B, N, T, V/H].
    v = tf.transpose(v, [0, 1, 3, 2, 4])  # Shape [H, B, T, N, V/H].

    # Perform attention for each head and combine the results into a list.
    # head_outputs stores a list of tensors, each with shape [1, B, N, V/H].
    # The last dimension contains the values computed for each attention head.
    # Each value was determined by computing attention over all of the
    # incoming edges for node n, weighting the incoming values accordingly,
    # and adding those weighted values together.
    head_outputs = []
    for head_id in range(num_heads):
      output = dot_product_mpnn_attention(
          q[head_id],
          k[head_id],
          v[head_id],
          adjacency_matrix,
          num_edge_types,
          num_transforms=num_transforms,
          use_weighted_sum=use_weighted_sum)

      # Store this result in the list of attention results for each head.
      # The call to expand_dims gives output shape [1, B, N, V/H], which will
      # come in handy when we combine the heads together.
      head_outputs.append(tf.expand_dims(output, axis=0))

    # Combine the heads together into one tensor and rearrange the dimensions.
    x = tf.concat(head_outputs, axis=0)  # Shape [H, B, N, V/H].
    x = tf.transpose(x, [1, 0, 2, 3])  # Shape [B, H, N, V/H].

    # Concatenate the values produced by each head together into one vector.
    x = common_attention.combine_heads(x)  # Shape [B, N, V].

    # A fully-connected linear layer to convert from the value vectors of
    # size V to output vectors of length O (the appropriate output length).
    x = common_layers.dense(
        x, output_depth, use_bias=False, name="output_transform")
    return x
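

# Hedged NumPy sketch of just the shape bookkeeping above, with made-up toy
# sizes: split heads, move heads first, then unfold the per-transform keys
# and values into their own axis.
import numpy as np

B, N, H, T = 2, 5, 4, 3
K, V = 8 * H, 6 * H                       # total key/value depths (divisible by H)
q = np.random.randn(B, N, K)
k = np.random.randn(B, N, K * T)
v = np.random.randn(B, N, V * T)


def _np_split_heads(x, h):                # [B, N, D] -> [B, H, N, D/H]
  b, n, d = x.shape
  return x.reshape(b, n, h, d // h).transpose(0, 2, 1, 3)


q = _np_split_heads(q, H).transpose(1, 0, 2, 3)               # [H, B, N, K/H]
k = _np_split_heads(k, H).transpose(1, 0, 2, 3)               # [H, B, N, K*T/H]
v = _np_split_heads(v, H).transpose(1, 0, 2, 3)               # [H, B, N, V*T/H]
k = k.reshape(H, B, N, T, K // H).transpose(0, 1, 3, 2, 4)    # [H, B, T, N, K/H]
v = v.reshape(H, B, N, T, V // H).transpose(0, 1, 3, 2, 4)    # [H, B, T, N, V/H]
assert q.shape == (H, B, N, K // H)
assert k.shape == (H, B, T, N, K // H) and v.shape == (H, B, T, N, V // H)
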
def multihead_graph_attention(query_antecedent,
                              memory_antecedent,
                              bias,
                              total_key_depth,
                              total_value_depth,
                              output_depth,
                              num_heads,
                              dropout_rate,
                              image_shapes=None,
                              attention_type="edge_vector",
                              name="multihead_graph_attention",
                              save_weights_to=None,
                              make_image_summary=True,
                              dropout_broadcast_dims=None,
                              adjacency_matrix=None,
                              num_edge_types=5,
                              vars_3d=False,
                              **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
      "local_mask_right", "local_unmasked", "masked_dilated_1d",
      "unmasked_dilated_1d", graph, or any attention function with the
      signature (query, key, value, **kwargs)
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a
      string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      Saves memory.
    adjacency_matrix: an optional tensor of shape [batch, len_q, len_q]
      containing edge vectors for attention
    num_edge_types: number of edge types, an int
    vars_3d: use 3-dimensional variables for input/output transformations
    **kwargs (dict): Parameters for the attention function

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, output_depth]

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  vars_3d_num_heads = num_heads if vars_3d else None
  with tf.variable_scope(name, default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):
    q, k, v = common_attention.compute_qkv(
        query_antecedent,
        memory_antecedent,
        total_key_depth,
        total_value_depth,
        vars_3d_num_heads=vars_3d_num_heads)
    q = common_attention.split_heads(q, num_heads)
    k = common_attention.split_heads(k, num_heads)
    v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "edge_vector":
      x = graph_attention(
          q, k, v, bias, dropout_rate, image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims,
          adjacency_matrix=adjacency_matrix,
          num_edge_types=num_edge_types)

    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        attention_type="dot_product",
                        image_shapes=None,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        vars_3d=False,
                        sparsity_technique=None,
                        threshold=3.0,
                        training=True,
                        clip_alpha=None,
                        initial_sparsity=None,
                        split_heads=False,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    attention_type: a string, either "dot_product", "dot_product_relative",
      "local_mask_right", "local_unmasked", "masked_dilated_1d",
      "unmasked_dilated_1d", graph, or any attention function with the
      signature (query, key, value, **kwargs)
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and
      values to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no
      padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no
      padding.
    cache: dict containing Tensors which are the results of previous
      attentions, used for fast decoding. Expects the dict to contain two
      keys ('k' and 'v'); for the initial call the values for these keys
      should be empty Tensors of the appropriate shape.
      'k' [batch_size, 0, key_channels]
      'v' [batch_size, 0, value_channels]
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a
      string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      Saves memory.
    vars_3d: use 3-dimensional variables for input/output transformations
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational
      dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket & scratch
      experiments.
    split_heads: Whether to prune each head separately.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of
    the shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is
    [batch_size, 1, hidden_dim].
    Optionally returns an additional loss parameter (ex: load balance loss
    for the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  if vars_3d:
    raise ValueError("3d attention variables not supported.")
  if attention_type != "dot_product":
    raise ValueError(
        "Sparse multihead attention only supports dot_product attention.")

  vars_3d_num_heads = 0
  with tf.variable_scope(name, default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):
    if cache is None or memory_antecedent is None:
      q, k, v = compute_qkv(
          query_antecedent,
          memory_antecedent,
          total_key_depth,
          total_value_depth,
          q_filter_width,
          kv_filter_width,
          q_padding,
          kv_padding,
          vars_3d_num_heads=vars_3d_num_heads,
          sparsity_technique=sparsity_technique,
          threshold=threshold,
          training=training,
          clip_alpha=clip_alpha,
          initial_sparsity=initial_sparsity,
          split_heads=split_heads,
          num_heads=num_heads)
    if cache is not None:
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        q = compute_attention_component(
            query_antecedent,
            total_key_depth,
            q_filter_width,
            q_padding,
            "q",
            vars_3d_num_heads=vars_3d_num_heads,
            sparsity_technique=sparsity_technique,
            threshold=threshold,
            training=training,
            clip_alpha=clip_alpha,
            initial_sparsity=initial_sparsity,
            split_heads=split_heads,
            num_heads=num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of the current implementation is better than
          # updating the tensor by adding the result of
          # matmul(one_hot, update_in_current_step).
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = common_attention.split_heads(q, num_heads)
    if cache is None:
      k = common_attention.split_heads(k, num_heads)
      v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      q *= key_depth_per_head**-0.5

    # Compute the attention.
    x = common_attention.dot_product_attention(
        q, k, v, bias, dropout_rate, image_shapes,
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary,
        dropout_broadcast_dims=dropout_broadcast_dims)
    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if sparsity_technique:
      x = common_sparse.dense(
          x,
          output_depth,
          use_bias=False,
          sparsity_technique=sparsity_technique,
          threshold=threshold,
          training=training,
          clip_alpha=clip_alpha,
          name="output_transform",
          initial_sparsity=initial_sparsity)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    return x
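

# Hedged NumPy illustration (not common_sparse itself) of what a sparsified
# output projection does conceptually: weights whose magnitude falls below a
# cutoff are masked to zero before the matmul. The magnitude rule here is an
# assumption for illustration; the real sparsity_technique options (e.g.
# variational dropout) use different criteria for which weights survive.
import numpy as np


def _sparse_dense(x, w, keep_fraction=0.5):
  # Keep only the largest-magnitude weights; zero out the rest.
  cutoff = np.quantile(np.abs(w), 1.0 - keep_fraction)
  mask = (np.abs(w) >= cutoff).astype(w.dtype)
  return x @ (w * mask)


x = np.random.randn(2, 10, 64)          # [batch, length, total_value_depth]
w = np.random.randn(64, 32)             # output_transform weights (no bias)
y = _sparse_dense(x, w, keep_fraction=0.25)
assert y.shape == (2, 10, 32)
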
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name=None,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d" or any attention function with the
                    signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and
                     values to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no
               padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no
                padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'); for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to
                       look at.
    name: an optional string
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of
    the shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided, in which case only the last memory
    position is calculated and the output shape is
    [batch_size, 1, hidden_dim].
    Optionally returns an additional loss parameter (ex: load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(name,
                         default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):
    if cache is None:
      q, k, v = common_attention.compute_qkv(
          query_antecedent, memory_antecedent, total_key_depth,
          total_value_depth, q_filter_width, kv_filter_width, q_padding,
          kv_padding)
    else:
      q = compute_q(query_antecedent, total_key_depth, q_filter_width,
                    q_padding)
      k, v = cache['k_encdec'], cache['v_encdec']

    q = common_attention.split_heads(q, num_heads)
    k = common_attention.split_heads(k, num_heads)
    v = common_attention.split_heads(v, num_heads)
    key_depth_per_head = total_key_depth // num_heads
    q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "dot_product":
      x = common_attention.dot_product_attention(q, k, v, bias, dropout_rate,
                                                 image_shapes)
    elif attention_type == "dot_product_relative":
      x = common_attention.dot_product_attention_relative(
          q, k, v, bias, max_relative_position, dropout_rate, image_shapes)
    elif attention_type == "local_mask_right":
      x = common_attention.masked_local_attention_1d(
          q, k, v, block_length=block_length)
    elif attention_type == "local_unmasked":
      x = common_attention.local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
    elif attention_type == "masked_dilated_1d":
      x = common_attention.masked_dilated_self_attention_1d(
          q, k, v, block_length, block_width, gap_size, num_memory_blocks)
    else:
      assert attention_type == "unmasked_dilated_1d"
      x = common_attention.dilated_self_attention_1d(
          q, k, v, block_length, block_width, gap_size, num_memory_blocks)
    x = common_attention.combine_heads(x)
    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
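# Illustrative usage sketch (not part of the original module): encoder-style
# self-attention over a hypothetical [batch, length, hidden] input in TF 1.x
# graph mode. Passing memory_antecedent=None makes the layer attend over
# query_antecedent itself; the sizes and the helper name are example
# assumptions, not part of the API above.
def _example_self_attention(encoder_input, hidden_size=512, num_heads=8,
                            dropout_rate=0.1):
  return multihead_attention(
      query_antecedent=encoder_input,
      memory_antecedent=None,  # self-attention
      bias=None,               # no masking for a plain encoder
      total_key_depth=hidden_size,
      total_value_depth=hidden_size,
      output_depth=hidden_size,
      num_heads=num_heads,
      dropout_rate=dropout_rate)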
def multihead_attention_pos(query_antecedent,
                            memory_antecedent,
                            bias,
                            total_key_depth,
                            total_value_depth,
                            output_depth,
                            num_heads,
                            dropout_rate,
                            max_relative_position=None,
                            image_shapes=None,
                            attention_type="dot_product",
                            block_length=128,
                            block_width=128,
                            qkv_padding="VALID",
                            cache=None,
                            gap_size=0,
                            num_memory_blocks=2,
                            name=None,
                            **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of
    the shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided, in which case only the last memory
    position is calculated and the output shape is
    [batch_size, 1, hidden_dim].
    Optionally returns an additional loss parameter (ex: load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(name,
                         default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):
    q, k, v = compute_qkv_pos(query_antecedent, memory_antecedent,
                              total_key_depth, total_value_depth, qkv_padding)

    if cache is not None:
      if attention_type != "dot_product":
        raise NotImplementedError(
            "Caching is not guaranteed to work with attention types other "
            "than dot_product.")
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")
      k = cache["k"] = tf.concat([cache["k"], k], axis=1)
      v = cache["v"] = tf.concat([cache["v"], v], axis=1)

    q = common_attention.split_heads(q, num_heads)
    k = common_attention.split_heads(k, num_heads)
    v = common_attention.split_heads(v, num_heads)
    key_depth_per_head = total_key_depth // num_heads
    q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "dot_product":
      x = common_attention.dot_product_attention(q, k, v, bias, dropout_rate,
                                                 image_shapes)
    elif attention_type == "dot_product_relative":
      x = common_attention.dot_product_attention_relative(
          q, k, v, bias, max_relative_position, dropout_rate, image_shapes)
    elif attention_type == "local_mask_right":
      x = common_attention.masked_local_attention_1d(
          q, k, v, block_length=block_length)
    elif attention_type == "local_unmasked":
      x = common_attention.local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
    elif attention_type == "masked_dilated_1d":
      x = common_attention.masked_dilated_self_attention_1d(
          q, k, v, block_length, block_width, gap_size, num_memory_blocks)
    else:
      assert attention_type == "unmasked_dilated_1d"
      x = common_attention.dilated_self_attention_1d(
          q, k, v, block_length, block_width, gap_size, num_memory_blocks)
    x = common_attention.combine_heads(x)
    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
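# Illustrative sketch (not part of the original module): one way the decoding
# cache consumed by multihead_attention_pos could be initialized, following
# the shapes given in its docstring. The depths are example values; at every
# decoding step the function above concatenates the key and value for the
# single current position onto these tensors along axis 1.
def _example_init_decode_cache(batch_size, key_depth=512, value_depth=512):
  return {
      "k": tf.zeros([batch_size, 0, key_depth]),
      "v": tf.zeros([batch_size, 0, value_depth]),
  }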
def multihead_attention_osm(query_antecedent,
                            bias,
                            total_key_depth,
                            total_value_depth,
                            output_depth,
                            num_heads,
                            dropout_rate,
                            max_relative_position=None,
                            attention_type="dot_product",
                            block_length=128,
                            block_width=128,
                            q_filter_width=1,
                            kv_filter_width=1,
                            q_padding="VALID",
                            kv_padding="VALID",
                            cache=None,
                            gap_size=0,
                            num_memory_blocks=2,
                            name=None,
                            query_antecedent_raw=None,
                            **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations,
  using separate key and value inputs rather than a single memory input.

  Args:
    query_antecedent: a Tensor with shape [batch, length, channels]
    bias: [1, 1, length, length] bias Tensor (see attention_bias())
    ...
    query_antecedent_raw: an int32 Tensor with shape [batch, length]
    (rest: see common_attention.multihead_attention)
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(name,
                         default_name="multihead_attention",
                         values=[query_antecedent]):
    q, k, v = compute_qkv_osm(query_antecedent, query_antecedent_raw,
                              total_key_depth, total_value_depth,
                              q_filter_width, kv_filter_width, q_padding,
                              kv_padding)
    q = common_attention.split_heads(q, num_heads)
    k = split_heads_5d(k, num_heads)
    v = common_attention.split_heads(v, num_heads)
    # k has shape [batch, heads, length (time), length (annotation),
    #   total_key_depth // num_heads].
    # q and v have shape [batch, heads, length,
    #   total_[key|value]_depth // num_heads].
    key_depth_per_head = total_key_depth // num_heads
    q *= key_depth_per_head**-0.5
    x = dot_product_osm_attention(q, k, v, bias, dropout_rate)
    x = common_attention.combine_heads(x)
    x = tf.layers.dense(x, output_depth, use_bias=False,
                        name="output_transform")
    return x
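# Illustrative usage sketch (not part of the original module): calling the
# OSM-style attention on a hypothetical decoder state together with its raw
# int32 token ids, using example sizes. The assumption here is that
# compute_qkv_osm builds the 5-D key tensor from query_antecedent_raw as
# documented in the comments above; the helper name and defaults below are
# made up for the example.
def _example_osm_attention(decoder_input, decoder_input_raw,
                           hidden_size=512, num_heads=8, dropout_rate=0.1):
  seq_len = common_layers.shape_list(decoder_input)[1]
  bias = common_attention.attention_bias_lower_triangle(seq_len)
  return multihead_attention_osm(
      query_antecedent=decoder_input,
      bias=bias,
      total_key_depth=hidden_size,
      total_value_depth=hidden_size,
      output_depth=hidden_size,
      num_heads=num_heads,
      dropout_rate=dropout_rate,
      query_antecedent_raw=decoder_input_raw)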