def dot_product_highorder_attention(q,
                                    k,
                                    v,
                                    bias,
                                    dropout_rate=0.0,
                                    image_shapes=None,
                                    attention_order=2,
                                    name=None,
                                    make_image_summary=True):
  """High-order dot-product attention.

  Attention is applied repeatedly to generate query vectors. For example,
  2-order attention uses q, k, v to generate a new query vector q'. The final
  attention result is computed with q', k, v.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    attention_order (int): Attention order (number of steps).
    name: an optional string
    make_image_summary: True if you want an image summary.

  Returns:
    A Tensor.
  """
  if attention_order == 1:
    return common_attention.dot_product_attention(
        q,
        k,
        v,
        bias,
        dropout_rate,
        image_shapes,
        name=name,
        make_image_summary=make_image_summary)

  # Split q and k into attention_order pieces along the depth dimension.
  qs = tf.split(q, attention_order, axis=3)
  ks = tf.split(k, attention_order, axis=3)
  with tf.variable_scope(
      name, default_name="dot_product_highorder_attention", values=[q, k, v]):
    for idx in xrange(attention_order):
      # [batch, num_heads, query_length, memory_length]
      q = tf.matmul(weights, qs[idx]) if idx != 0 else qs[0]
      logits = tf.matmul(q, ks[idx], transpose_b=True)
      if bias is not None:
        logits += bias
      weights = tf.nn.softmax(logits, name="attention_weights")
      # Dropping out the attention links for each of the heads.
      weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
    if (not tf.get_variable_scope().reuse and
        # Summaries don't work well within tf.while_loop()
        "/while/" not in tf.contrib.framework.get_name_scope() and
        make_image_summary):
      common_attention.attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)

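# Illustrative usage sketch (not part of the original module): shows the shape
# contract of dot_product_highorder_attention on toy tensors. Assumes TF1-style
# graph mode and the module-level `tf` / `common_attention` imports used
# throughout this file; the shapes below are arbitrary.
def _example_dot_product_highorder_attention():
  # depth must be divisible by attention_order; length_q == length_kv here.
  batch, heads, length, depth = 2, 4, 6, 8
  q = tf.random_normal([batch, heads, length, depth])
  k = tf.random_normal([batch, heads, length, depth])
  v = tf.random_normal([batch, heads, length, depth])
  # 2-order attention: the weights from the first split of q/k are reused to
  # form the query for the second split. Output: [batch, heads, length, depth].
  return dot_product_highorder_attention(
      q, k, v, bias=None, attention_order=2, make_image_summary=False)
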
def dot_product_mpnn_attention(q, k, v, adjacency_matrix, num_edge_types,
                               ignore_zero=True, name=None):
  """Dot product attention with edge vectors.

  Args:
    q: [batch, length, key_depth] tensor
    k: [batch, num_edge_types, length, key_depth] tensor
    v: [batch, num_edge_types, length, depth] tensor
    adjacency_matrix: [batch, length, length] tensor of int edge types
    num_edge_types: an int, specifying number of edge types
    ignore_zero: A flag that says that edge type 0 should be ignored
    name: optional string

  Returns:
    A tensor of shape [batch, length, depth(q)]
  """
  with tf.variable_scope(
      name,
      default_name="dot_product_mpnn_attention",
      values=[q, k, v, adjacency_matrix, num_edge_types]):
    # Computing attention mask.
    # all_edge_logits will have shape [batch, num_edge_types, length, length].
    all_edge_logits = tf.matmul(
        tf.tile(tf.expand_dims(q, axis=1), [1, num_edge_types, 1, 1]),
        k,
        transpose_b=True)

    # adjacency_matrix_one_hot has shape [batch, length, length, num_edge_types].
    adjacency_matrix_one_hot = tf.one_hot(adjacency_matrix, num_edge_types)
    # Making adjacency_matrix_one_hot [batch, num_edge_types, length, length].
    adjacency_matrix_one_hot = tf.transpose(adjacency_matrix_one_hot,
                                            [0, 3, 1, 2])

    # Getting dot products for q_i, k_j, and e_{ij}. This assumes that for
    # edge type 0, the dot products are 0.
    all_edge_logits *= adjacency_matrix_one_hot
    # logits will be [batch, length, length] after reducing along
    # axis 1, which has dimension num_edge_types.
    logits = tf.reduce_sum(all_edge_logits, axis=1)

    # Ignoring edges if needed.
    bias = 0
    if ignore_zero:
      bias = tf.to_float(tf.equal(adjacency_matrix, 0)) * -1e9
    logits += bias

    # Getting compatibilities.
    compatibility = tf.nn.softmax(logits)
    common_attention.attention_image_summary(
        tf.expand_dims(compatibility, axis=1), None)

    # Getting edge compatibilities ready to compute values.
    # After tiling, edge_compatibility will be
    # [batch, num_edge_types, length, length].
    edge_compatibility = tf.tile(
        tf.expand_dims(compatibility, axis=1), [1, num_edge_types, 1, 1])

    # Computing values.
    edge_compatibility *= adjacency_matrix_one_hot
    # all_edge_values will be [batch, num_edge_types, length, depth].
    # We also assume that the linear transformations for edge type 0 are all
    # zeros; that is, [batch, 0] is a length*depth tensor of zeros.
    all_edge_values = tf.matmul(edge_compatibility, v)
    # Reducing along the num_edge_types dimension.
    output = tf.reduce_sum(all_edge_values, axis=1)
    return output

def dot_product_mpnn_attention(q,
                               k,
                               v,
                               adjacency_matrix,
                               num_edge_types,
                               num_transforms=None,
                               use_weighted_sum=False,
                               name=None):
  """Dot product attention with edge vectors.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let K be the size of the attention keys/queries.
  Let V be the size of the attention values.
  Let T be the total number of transforms (num_transforms).

  Args:
    q: The query Tensor of shape [B, N, K].
    k: The key Tensor of shape [B, T, N, K].
    v: The value Tensor of shape [B, T, N, V].
    adjacency_matrix: A Tensor of shape [B, N, N, T]. An entry at
      indices b, i, j, k is the indicator of the edge
      from node j to node i in batch b. A standard adjacency matrix will only
      have one edge type while a multigraph will have multiple edge types.
    num_edge_types: An integer specifying number of edge types.
    num_transforms: An integer indicating number of transforms (T). If None,
      then num_transforms will be equal to num_edge_types.
    use_weighted_sum: If False, will only use a single transform per edge type.
      Otherwise, use a learned weighted sum of transforms per edge type.
    name: A string.

  Returns:
    A Tensor of shape [B, N, V] storing the result of computing attention
    weights using the queries and keys and combining the values according to
    those weights.

  Raises:
    ValueError: if num_transforms doesn't equal num_edge_types and not using
      weighted sum.
  """
  with tf.variable_scope(
      name,
      default_name="dot_product_mpnn_attention",
      values=[q, k, v, adjacency_matrix, num_edge_types]):
    # If not explicitly set, use num_transforms set to num_edge_types.
    num_transforms = (
        num_edge_types if num_transforms is None else num_transforms)

    if not use_weighted_sum and num_transforms != num_edge_types:
      raise ValueError("num_transforms must equal num_edge_types unless "
                       "use_weighted_sum is True")

    # Computes the raw dot-product attention values between each query and
    # the corresponding keys it needs to consider.
    #
    # This operation takes the dot product of (the query for
    # each node) and (the key for each node for each possible edge type),
    # creating an N x N matrix for each edge type. The entry at index (i, j)
    # is the dot-product for the edge from node i to node j of the appropriate
    # type. These dot products will eventually become attention weights
    # specifying how much node i weights an edge of that type coming from node
    # j.
    all_edge_logits = tf.matmul(
        tf.tile(tf.expand_dims(q, axis=1), [1, num_edge_types, 1, 1]),
        k,
        transpose_b=True)

    # The adjacency matrix assumes there is only one directed edge (i <- j) for
    # each pair of nodes. If such an edge exists, it contains the integer
    # type of that edge at position (i, j) of the adjacency matrix.
    #
    # Construct edge_vectors of shape [B, N, N, T].
    if use_weighted_sum:
      # Use dense representation for edge vectors.
      edge_vectors = make_edge_vectors(adjacency_matrix, num_edge_types,
                                       num_transforms)
    else:
      # Generate one-hot vectors based on edge types.
      # If there is an edge from node j to node i of type t, then index t of
      # the last dimension is 1 for entry (i, j) of the second and third
      # dimensions.
      edge_vectors = tf.one_hot(adjacency_matrix, num_transforms)

    # Rearranging the dimensions to match the shape of all_edge_logits.
    edge_vectors = tf.transpose(edge_vectors, [0, 3, 1, 2])

    # Element-wise multiplies all_edge_logits and edge_vectors.
    #
    # In other words: all_edge_logits contains N x N matrices of query-key
    # products. This element-wise multiplication zeroes out entries that do
    # not correspond to actual edges in the graph of the appropriate edge
    # type. all_edge_logits retains shape [B, T, N, N].
    all_edge_logits *= edge_vectors

    # Since there can only be one edge from node A to node B, we can collapse
    # the T different adjacency matrices containing key-query pairs into one
    # adjacency matrix. logits is [B, N, N].
    # TODO(dbieber): Use a reshape instead of reduce sum to attend over all
    # edges instead of over all neighboring nodes to handle the multigraph
    # case.
    logits = tf.reduce_sum(all_edge_logits, axis=1)

    # For pairs of nodes with no edges between them, add a large negative bias
    # to each location without an edge so that the softmax of entries with the
    # value 0 become a small negative number instead.
    bias = tf.to_float(
        tf.equal(tf.reduce_sum(adjacency_matrix, axis=-1), 0)) * -1e9
    logits += bias

    # Turn the raw key-query products into a probability distribution (or,
    # in terms of attention, weights). The softmax is computed across the
    # last dimension of logits.
    compatibility = tf.nn.softmax(logits)  # Shape [B, N, N].

    # Computes a summary showing the attention matrix as an image. Does not do
    # any work toward actually performing attention.
    common_attention.attention_image_summary(
        tf.expand_dims(compatibility, axis=1), None)

    # Repeats the attention matrix T times for each batch, producing
    # a tensor with shape [B, T, N, N] where the [N, N] component is T
    # repeats of the values found in compatibility.
    edge_compatibility = tf.tile(
        tf.expand_dims(compatibility, axis=1), [1, num_edge_types, 1, 1])

    # Zeroes out the entries in edge_compatibility that do not correspond to
    # actual edges.
    edge_compatibility *= edge_vectors  # Shape [B, T, N, N].

    output = compute_values(edge_compatibility, v)
    return output

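# Illustrative, self-contained sketch (not part of the original module) of the
# computation performed above, written against an integer-typed adjacency
# matrix for concreteness: per-edge-type query-key logits are masked by edge
# type, collapsed across types, softmaxed, and then used to combine the
# per-edge-type values. All names here are local to the sketch and the shapes
# are arbitrary; only standard TF1 ops are used.
def _sketch_edge_typed_attention(q, k, v, adjacency_types, num_edge_types):
  """q: [B, N, K]; k: [B, T, N, K]; v: [B, T, N, V]; adjacency_types: [B, N, N] ints."""
  # One-hot edge indicators rearranged to [B, T, N, N].
  edge_vectors = tf.transpose(
      tf.one_hot(adjacency_types, num_edge_types), [0, 3, 1, 2])
  # Query-key logits per edge type, masked by type, collapsed to [B, N, N].
  logits = tf.reduce_sum(
      tf.matmul(tf.tile(tf.expand_dims(q, 1), [1, num_edge_types, 1, 1]),
                k, transpose_b=True) * edge_vectors,
      axis=1)
  # Mask out node pairs with no edge of any type.
  logits += tf.to_float(
      tf.equal(tf.reduce_sum(edge_vectors, axis=1), 0)) * -1e9
  weights = tf.nn.softmax(logits)  # [B, N, N]
  # Re-split the weights by edge type and combine the per-type values.
  per_type = tf.tile(tf.expand_dims(weights, 1),
                     [1, num_edge_types, 1, 1]) * edge_vectors  # [B, T, N, N]
  return tf.reduce_sum(tf.matmul(per_type, v), axis=1)  # [B, N, V]
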
def graph_attention(q,
                    k,
                    v,
                    bias,
                    dropout_rate=0.0,
                    image_shapes=None,
                    name=None,
                    make_image_summary=True,
                    save_weights_to=None,
                    dropout_broadcast_dims=None,
                    adjacency_matrix=None,
                    num_edge_types=5):
  """Graph attention.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    adjacency_matrix: optional matrix of [batch, length, length] ids indicating
      edge type
    num_edge_types: an int indicating number of edge types

  Returns:
    A Tensor of shape [batch, heads, length_q, depth_v]
  """
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
    # [batch, num_heads, query_length, memory_length]
    logits = tf.matmul(q, k, transpose_b=True)
    if adjacency_matrix is not None:
      key_head_depth = common_layers.shape_list(q)[-1]
      adjacency_vectors = make_edge_vectors(
          adjacency_matrix,
          num_edge_types,
          key_head_depth,
          name=name)
      # Transposing q to be [batch, length_q, heads, depth_k]
      # to allow for matmul with [batch, length_q, length_q, depth_k].
      q_t = tf.transpose(q, [0, 2, 1, 3])
      adj_logits = tf.matmul(q_t, adjacency_vectors, transpose_b=True)
      logits += tf.transpose(adj_logits, [0, 2, 1, 3])
      # [batch, depth, num_nodes, num_nodes]
    if bias is not None:
      logits += bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
    # Dropping out the attention links for each of the heads.
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      common_attention.attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)

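# Illustrative usage sketch (not part of the original module): toy call of
# graph_attention showing the multi-head shape contract and the optional
# edge-type adjacency. Assumes TF1 graph mode and that make_edge_vectors (as
# referenced above) returns learned per-edge-type vectors of shape
# [batch, length, length, depth]; shapes and values below are arbitrary.
def _example_graph_attention():
  batch, heads, length, depth = 2, 4, 5, 8
  q = tf.random_normal([batch, heads, length, depth])
  k = tf.random_normal([batch, heads, length, depth])
  v = tf.random_normal([batch, heads, length, depth])
  # Integer edge-type ids for each (node i, node j) pair.
  adjacency = tf.random_uniform(
      [batch, length, length], maxval=5, dtype=tf.int32)
  # Output: [batch, heads, length, depth]; bias=None disables masking.
  return graph_attention(
      q, k, v, bias=None, adjacency_matrix=adjacency, num_edge_types=5,
      make_image_summary=False)
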
def dot_product_attention_mtsa(
    q,
    k,
    v,
    bias,
    dropout_rate=0.0,
    image_shapes=None,
    name=None,
    make_image_summary=True,
    save_weights_to=None,
    dropout_broadcast_dims=None,
    use_k_mtsa=True,
    afn_extra='none',
    afn_dot='exp',
    afn_multi='exp',
    bias_start=0.,
    bi_direction=False,
):
  """Dot-product attention with multi-dimensional (MTSA) refinement.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    use_k_mtsa: if True, compute the multi-dimensional (source2token) logits
      from k; otherwise compute them from v.
    afn_extra: name of the activation for the optional extra multi-dimensional
      layer ('none' disables the extra layer).
    afn_dot: name of the activation applied to the token2token logits.
    afn_multi: name of the activation applied to the multi-dimensional logits.
    bias_start: initial bias value for the multi-head dense layers.
    bi_direction: if True, split the heads into a forward-masked half and a
      backward-masked half.

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
    # Get dimensions.
    dim_q = q.get_shape().as_list()[-1]
    dim_k = k.get_shape().as_list()[-1]
    dim_v = v.get_shape().as_list()[-1]
    # Prepare the scaling factor and activation functions.
    multi_logits_scale_factor = 1. / math.sqrt(
        dim_v) if afn_multi.startswith('scaled') else 1.
    afn_extra, afn_dot, afn_multi = afn_name2fn(afn_extra), afn_name2fn(
        afn_dot), afn_name2fn(afn_multi)
    # if bias is not None:
    #   inp_mask_1d = tf.to_float(tf.equal(bias, 0.))  # bs,1,1,vl
    #   inp_mask_1d = tf.transpose(inp_mask_1d, [0, 1, 3, 2])  # bs,1,vl,1
    # else:
    #   inp_mask_1d = None

    # token2token self-attention
    dot_logits = tf.matmul(q, k, transpose_b=True)  # bs,hd,ql,vl
    if bias is not None:
      bias = common_layers.cast_like(bias, dot_logits)  # 1/bs,1,ql/1,vl
      dot_logits += bias
    e_dot_logits = afn_dot(dot_logits)  # bs,hd,ql,vl
    if bi_direction:
      head_num = v.get_shape().as_list()[1]
      ql, vl = tf.shape(q)[-2], tf.shape(v)[-2]
      assert head_num is not None
      assert head_num % 2 == 0
      ones_mat = tf.ones([ql, vl], tf.float32)
      mul_mask_fw = tf.matrix_band_part(ones_mat, -1, 0)  # Lower triangular.
      mul_mask_bw = tf.matrix_band_part(ones_mat, 0, -1)  # Upper triangular.
      mul_mask_fw_tile = tf.tile(
          tf.expand_dims(mul_mask_fw, 0), [head_num // 2, 1, 1])
      mul_mask_bw_tile = tf.tile(
          tf.expand_dims(mul_mask_bw, 0), [head_num // 2, 1, 1])
      mul_mask = tf.expand_dims(
          tf.concat([mul_mask_fw_tile, mul_mask_bw_tile], axis=0), axis=0)
      e_dot_logits *= mul_mask

    # source2token self-attention
    multi_logits = multi_head_dense_layer(
        k if use_k_mtsa else v, dim_v, True,
        bias_start if afn_extra is None else 0., 'multi_logits1')
    if afn_extra is not None:  # use one extra layer for multi-dim
      multi_logits = multi_head_dense_layer(
          afn_extra(multi_logits), dim_v, True, bias_start, 'multi_logits2')
    e_multi_logits = afn_multi(
        multi_logits * multi_logits_scale_factor)  # bs,hd,vl,vd
    # if inp_mask_1d is not None:  # use mask for exp_logits
    #   e_multi_logits *= inp_mask_1d

    # mtsa
    accum_z_deno = tf.matmul(e_dot_logits, e_multi_logits)  # bs,hd,ql,vd
    accum_z_deno = tf.where(  # in case of NaN and Inf
        tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
        accum_z_deno,
        tf.ones_like(accum_z_deno))

    # attention dropout
    e_dot_logits = common_layers.dropout_with_broadcast_dims(
        e_dot_logits, math.sqrt(1. - dropout_rate),
        broadcast_dims=dropout_broadcast_dims)
    e_multi_logits = common_layers.dropout_with_broadcast_dims(
        e_multi_logits, math.sqrt(1. - dropout_rate),
        broadcast_dims=dropout_broadcast_dims)
    rep_mul_score = v * e_multi_logits  # bs,hd,vl,vd
    accum_rep_mul_score = tf.matmul(e_dot_logits, rep_mul_score)  # bs,hd,ql,vd
    # calculate the final attention results
    attn_res = accum_rep_mul_score / accum_z_deno
    # if inp_mask_1d is not None:  # use mask for output
    #   attn_res *= inp_mask_1d

    # ============ for visualization ============
    weights = e_dot_logits / (tf.reduce_sum(
        e_dot_logits, axis=-1, keepdims=True,
        name="attention_weights") + 0.00001)
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = dot_logits
    if common_layers.should_generate_summaries() and make_image_summary:
      common_attention.attention_image_summary(weights, image_shapes)
    return attn_res

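# Illustrative usage sketch (not part of the original module): toy call of
# dot_product_attention_mtsa with per-head inputs. Assumes TF1 graph mode and
# that afn_name2fn / multi_head_dense_layer are the helpers referenced above;
# shapes and hyperparameters below are arbitrary.
def _example_dot_product_attention_mtsa():
  # heads must be even when bi_direction=True.
  batch, heads, length, depth = 2, 4, 6, 8
  q = tf.random_normal([batch, heads, length, depth])
  k = tf.random_normal([batch, heads, length, depth])
  v = tf.random_normal([batch, heads, length, depth])
  # Output: [batch, heads, length, depth].
  return dot_product_attention_mtsa(
      q, k, v, bias=None, dropout_rate=0.1, make_image_summary=False,
      bi_direction=True)
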
def dot_product_mpnn_attention(q,
                               k,
                               v,
                               adjacency_matrix,
                               num_edge_types,
                               ignore_zero=True,
                               name=None):
  """Dot product attention with edge vectors.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let K be the size of the attention keys/queries.
  Let V be the size of the attention values.
  Let T be the total number of edge types (num_edge_types).

  Args:
    q: The query Tensor of shape [B, N, K].
    k: The key Tensor of shape [B, T, N, K].
    v: The value Tensor of shape [B, T, N, V].
    adjacency_matrix: A Tensor of shape [B, N, N]. An entry at indices b, i, j
      is the integer edge type of the edge from node j to node i in batch b.
    num_edge_types: An integer specifying number of edge types (T).
    ignore_zero: A flag that says that edge type 0 should be ignored.
    name: A string.

  Returns:
    A Tensor of shape [B, N, V] storing the result of computing attention
    weights using the queries and keys and combining the values according to
    those weights.
  """
  # TODO(jfrankle): Consider ways to handle graphs that have multiple edges
  # between the same nodes (with only one edge of each type). adjacency_matrix
  # will need to be converted to shape [B, T, N, N].
  with tf.variable_scope(
      name,
      default_name="dot_product_mpnn_attention",
      values=[q, k, v, adjacency_matrix, num_edge_types]):
    # Computes the raw dot-product attention values between each query and
    # the corresponding keys it needs to consider.
    #
    # This operation takes the dot product of (the query for
    # each node) and (the key for each node for each possible edge type),
    # creating an N x N matrix for each edge type. The entry at index (i, j)
    # is the dot-product for the edge from node i to node j of the appropriate
    # type. These dot products will eventually become attention weights
    # specifying how much node i weights an edge of that type coming from node
    # j.
    all_edge_logits = tf.matmul(
        tf.tile(tf.expand_dims(q, axis=1), [1, num_edge_types, 1, 1]),
        k,
        transpose_b=True)

    # The adjacency matrix assumes there is only one directed edge (i <- j) for
    # each pair of nodes. If such an edge exists, it contains the integer
    # type of that edge at position (i, j) of the adjacency matrix.
    #
    # adjacency_matrix_one_hot has shape [B, N, N, T]. If there is an edge
    # from node j to node i of type t, then index t of the last dimension is
    # 1 for entry (i, j) of the second and third dimensions.
    adjacency_matrix_one_hot = tf.one_hot(adjacency_matrix, num_edge_types)

    # Rearranging the dimensions to match the shape of all_edge_logits.
    adjacency_matrix_one_hot = tf.transpose(adjacency_matrix_one_hot,
                                            [0, 3, 1, 2])

    # Element-wise multiplies all_edge_logits and adjacency_matrix_one_hot.
    #
    # In other words: all_edge_logits contains N x N matrices of query-key
    # products. This element-wise multiplication zeroes out entries that do not
    # correspond to actual edges in the graph of the appropriate edge type.
    # all_edge_logits retains shape [B, T, N, N].
    all_edge_logits *= adjacency_matrix_one_hot

    # Since there can only be one edge from node A to node B, we can collapse
    # the T different adjacency matrices containing key-query pairs into one
    # adjacency matrix. logits is [B, N, N].
    logits = tf.reduce_sum(all_edge_logits, axis=1)

    # If we do not have any special treatment for edge type 0, add a large,
    # negative bias to each location without an edge so that the softmax of
    # entries with the value 0 become a small negative number instead.
    #
    # TODO(avaswani): Better explanation of the rationale behind ignore_zero
    # here and throughout.
    bias = 0
    if ignore_zero:
      bias = tf.to_float(tf.equal(adjacency_matrix, 0)) * -1e9
    logits += bias

    # Turn the raw key-query products into a probability distribution (or,
    # in terms of attention, weights). The softmax is computed across the
    # last dimension of logits.
    compatibility = tf.nn.softmax(logits)  # Shape [B, N, N].

    # Computes a summary showing the attention matrix as an image. Does not do
    # any work toward actually performing attention.
    common_attention.attention_image_summary(
        tf.expand_dims(compatibility, axis=1), None)

    # Repeats the attention matrix T times for each batch, producing
    # a tensor with shape [B, T, N, N] where the [N, N] component is T
    # repeats of the values found in compatibility.
    edge_compatibility = tf.tile(
        tf.expand_dims(compatibility, axis=1), [1, num_edge_types, 1, 1])

    # Zeroes out the entries in edge_compatibility that do not correspond to
    # actual edges.
    edge_compatibility *= adjacency_matrix_one_hot  # Shape [B, T, N, N].

    # Computes the incoming value vectors for each node by weighting them
    # according to the attention weights. These values are still segregated by
    # edge type.
    all_edge_values = tf.matmul(edge_compatibility, v)  # Shape = [B, T, N, V].

    # Combines the weighted value vectors together across edge types into a
    # single N x V matrix for each batch.
    output = tf.reduce_sum(all_edge_values, axis=1)  # Shape [B, N, V].
    return output

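# Illustrative usage sketch (not part of the original module): toy call of the
# ignore_zero variant of dot_product_mpnn_attention defined above. Assumes TF1
# graph mode; shapes and values are arbitrary. Edge type 0 is treated as
# "no edge" when ignore_zero=True.
def _example_dot_product_mpnn_attention_ignore_zero():
  batch, nodes, key_depth, val_depth, num_edge_types = 2, 4, 8, 8, 3
  q = tf.random_normal([batch, nodes, key_depth])
  k = tf.random_normal([batch, num_edge_types, nodes, key_depth])
  v = tf.random_normal([batch, num_edge_types, nodes, val_depth])
  # Integer edge types for each (receiver i, sender j) pair.
  adjacency_matrix = tf.random_uniform(
      [batch, nodes, nodes], maxval=num_edge_types, dtype=tf.int32)
  # Output: [B, N, V].
  return dot_product_mpnn_attention(
      q, k, v, adjacency_matrix, num_edge_types, ignore_zero=True)
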