def dot_product_highorder_attention(q,
                                    k,
                                    v,
                                    bias,
                                    dropout_rate=0.0,
                                    image_shapes=None,
                                    attention_order=2,
                                    name=None,
                                    make_image_summary=True):
    """High order dot-product attention. Attention is applied repeatedly
  to generate query vectors. For example, 2-order attention uses q,k,v
  to generate a new query vector q'. The final attention result is
  computed with q',k,v.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    attention_order: an integer, the number of attention steps (order).
    name: an optional string
    make_image_summary: True if you want an image summary.

  Returns:
    A Tensor.
  """
    if attention_order == 1:
        return common_attention.dot_product_attention(
            q,
            k,
            v,
            bias,
            dropout_rate,
            image_shapes,
            name=name,
            make_image_summary=make_image_summary)
    # Split q and k into attention_order pieces along the depth axis.
    qs = tf.split(q, attention_order, axis=3)
    ks = tf.split(k, attention_order, axis=3)
    with tf.variable_scope(name,
                           default_name="dot_product_highorder_attention",
                           values=[q, k, v]):
        for idx in range(attention_order):
            # After the first step, re-form the query by applying the previous
            # attention weights to the next query slice.
            q = qs[0] if idx == 0 else tf.matmul(weights, qs[idx])
            # [batch, num_heads, query_length, memory_length]
            logits = tf.matmul(q, ks[idx], transpose_b=True)
            if bias is not None:
                logits += bias
            weights = tf.nn.softmax(logits, name="attention_weights")
        # dropping out the attention links for each of the heads
        weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
        if (not tf.get_variable_scope().reuse and
                # Summaries don't work well within tf.while_loop()
                "/while/" not in tf.contrib.framework.get_name_scope()
                and make_image_summary):
            common_attention.attention_image_summary(weights, image_shapes)
        return tf.matmul(weights, v)
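
# A minimal usage sketch (added for illustration; not part of the original
# module). It builds toy self-attention inputs and calls
# dot_product_highorder_attention with attention_order=2. Assumes TF 1.x graph
# mode; depth_k must be divisible by attention_order, and length_q must equal
# length_kv because the intermediate weights are re-applied to the query slices.
def _example_highorder_attention():
  batch, heads, length, depth_k, depth_v = 2, 4, 8, 16, 32
  q = tf.random_normal([batch, heads, length, depth_k])
  k = tf.random_normal([batch, heads, length, depth_k])
  v = tf.random_normal([batch, heads, length, depth_v])
  # Returns a [batch, heads, length, depth_v] tensor.
  return dot_product_highorder_attention(
      q, k, v, bias=None, attention_order=2, make_image_summary=False)
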
def dot_product_mpnn_attention(q, k, v, adjacency_matrix, num_edge_types,
                               ignore_zero=True, name=None):
  """Dot product attention with edge vectors.

  Args:
    q: [batch, length, key_depth] tensor
    k: [batch, num_edge_types, length, key_depth]
    v: [batch, num_edge_types, length, depth]
    adjacency_matrix: [batch, length, length] tensor of int edge types
    num_edge_types: an int, specifying number of edge types
    ignore_zero: If True, treat edge type 0 as "no edge" and mask it out.
    name: optional string

  Returns:
    A tensor of shape [batch, length, depth(q)]
  """
  with tf.variable_scope(
      name, default_name="dot_product_mpnn_attention",
      values=[q, k, v, adjacency_matrix, num_edge_types]):
    # Compute per-edge-type attention logits.
    # all_edge_logits will have shape [batch, edge_types, len, len]
    all_edge_logits = tf.matmul(
        tf.tile(tf.expand_dims(q, axis=1), [1, num_edge_types, 1, 1]),
        k, transpose_b=True)
    # adjacency_matrix_one_hot has shape [batch, len, len, num_edge_types]
    adjacency_matrix_one_hot = tf.one_hot(adjacency_matrix, num_edge_types)
    # making adjacency_matrix_one_hot [batch, edge_types, len, len]
    adjacency_matrix_one_hot = tf.transpose(adjacency_matrix_one_hot,
                                            [0, 3, 1, 2])
    # getting dot products for q_i, k_j, and e_{ij}. This assumes that for
    # edge type 0, the dot products are 0
    all_edge_logits *= adjacency_matrix_one_hot
    # logits will be [batch, length, length] after reducing along
    # axis 1 which has dimension num_edge_types.
    logits = tf.reduce_sum(all_edge_logits, axis=1)
    # Mask out pairs whose edge type is 0 ("no edge") if requested.
    bias = 0
    if ignore_zero:
      bias = tf.to_float(tf.equal(adjacency_matrix, 0)) * -1e9
    logits += bias
    # getting compatibilities
    compatibility = tf.nn.softmax(logits)
    common_attention.attention_image_summary(
        tf.expand_dims(compatibility, axis=1), None)
    # getting edge compatibilities ready to compute values.
    # after tiling, edge_compatibility will be
    # [batch, num_edge_types, length, length]
    edge_compatibility = tf.tile(
        tf.expand_dims(compatibility, axis=1), [1, num_edge_types, 1, 1])
    # computing values
    edge_compatibility *= adjacency_matrix_one_hot
    # all_edge_values will be [batch, num_edge_types, length, depth].
    # We also assume that the linear transformation for edge type 0 is all
    # zeros, i.e. for each batch, v[:, 0] is a length x depth tensor of zeros.
    all_edge_values = tf.matmul(edge_compatibility, v)
    # reducing along the num_edge_types dimension
    output = tf.reduce_sum(all_edge_values, axis=1)
    return output
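
# A minimal usage sketch (illustration only, not part of the original module)
# for dot_product_mpnn_attention above. Shapes follow the docstring: q is
# [batch, length, key_depth], k and v carry one slice per edge type, and
# adjacency_matrix holds integer edge types with 0 meaning "no edge". Assumes
# TF 1.x and that common_attention.attention_image_summary is importable,
# since the function calls it unconditionally.
def _example_mpnn_attention():
  batch, num_nodes, depth, num_edge_types = 2, 5, 8, 3
  q = tf.random_normal([batch, num_nodes, depth])
  k = tf.random_normal([batch, num_edge_types, num_nodes, depth])
  v = tf.random_normal([batch, num_edge_types, num_nodes, depth])
  adjacency = tf.random_uniform([batch, num_nodes, num_nodes],
                                minval=0, maxval=num_edge_types, dtype=tf.int32)
  # Returns a [batch, num_nodes, depth] tensor.
  return dot_product_mpnn_attention(q, k, v, adjacency, num_edge_types)
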
def dot_product_mpnn_attention(q,
                               k,
                               v,
                               adjacency_matrix,
                               num_edge_types,
                               num_transforms=None,
                               use_weighted_sum=False,
                               name=None):
    """Dot product attention with edge vectors.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let K be the size of the attention keys/queries.
  Let V be the size of the attention values.
  Let T be the total number of transforms (num_transforms).

  Args:
    q: The query Tensor of shape [B, N, K].
    k: The key Tensor of shape [B, T, N, K].
    v: The value Tensor of shape [B, T, N, V].
    adjacency_matrix: A Tensor of shape [B, N, N, T]. An entry at
      indices b, i, j, k is the indicator of the edge
      from node j to node i in batch b. A standard adjacency matrix will only
      have one edge type while a multigraph will have multiple edge types.
    num_edge_types: An integer specifying number of edge types.
    num_transforms: An integer indicating number of transforms (T). If None,
      then num_transforms will be equal to num_edge_types.
    use_weighted_sum: If False, will only use a single transform per edge type.
      Otherwise, use a learned weighted sum of transforms per edge type.
    name: A string.

  Returns:
    A Tensor of shape [B, N, V] storing the result of computing attention
    weights using the queries and keys and combining the values according to
    those weights.

  Raises:
    ValueError: if num_transforms doesn't equal num_edge_types and not using
      weighted sum.
  """
    with tf.variable_scope(name,
                           default_name="dot_product_mpnn_attention",
                           values=[q, k, v, adjacency_matrix, num_edge_types]):
        # If not explicitly set, default num_transforms to num_edge_types.
        num_transforms = (num_edge_types
                          if num_transforms is None else num_transforms)

        if not use_weighted_sum and num_transforms != num_edge_types:
            raise ValueError("num_transforms must equal num_edge_types unless "
                             "use_weighted_sum is True")

        # Computes the raw dot-product attention values between each query and
        # the corresponding keys it needs to consider.
        #
        # This operation takes the dot product of (the query for
        # each node) and (the key for each node for each possible edge type),
        # creating an N x N matrix for each edge type. The entry at index (i, j)
        # is the dot-product for the edge from node i to node j of the appropriate
        # type. These dot products will eventually become attention weights
        # specifying how much node i weights an edge of that type coming from node
        # j.
        all_edge_logits = tf.matmul(tf.tile(tf.expand_dims(q, axis=1),
                                            [1, num_edge_types, 1, 1]),
                                    k,
                                    transpose_b=True)

        # The adjacency matrix assumes there is only one directed edge (i <- j) for
        # each pair of nodes. If such an edge exists, it contains the integer
        # type of that edge at position (i, j) of the adjacency matrix.
        #
        # Construct edge_vectors of shape [B, N, N, T].
        if use_weighted_sum:
            # Use dense representation for edge vectors.
            edge_vectors = make_edge_vectors(adjacency_matrix, num_edge_types,
                                             num_transforms)
        else:
            # Generate one-hot vectors based on edge types.
            # If there is an edge from node j to node i of type t, then index t of the
            # last dimension is 1 for entry (i, j) of the second and third dimensions.
            edge_vectors = tf.one_hot(adjacency_matrix, num_transforms)

        # Rearranging the dimensions to match the shape of all_edge_logits.
        edge_vectors = tf.transpose(edge_vectors, [0, 3, 1, 2])

        # Element-wise multiplies all_edge_logits and edge_vectors.
        #
        # In other words: all_edge_logits contains N x N matrices of query-key
        # products. This element-wise multiplication zeroes out entries that do not
        # correspond to actual edges in the graph of the appropriate edge type.
        # all_edge_logits retains shape [B, T, N, N].
        all_edge_logits *= edge_vectors

        # Since there can only be one edge from node A to node B, we can collapse
        # the T different adjacency matrices containing key-query pairs into one
        # adjacency matrix. logits is [B, N, N].
        # TODO(dbieber): Use a reshape instead of reduce sum to attend over all
        # edges instead of over all neighboring nodes to handle the multigraph case.
        logits = tf.reduce_sum(all_edge_logits, axis=1)

        # For pairs of nodes with no edge between them, add a large negative
        # bias so that the softmax assigns those positions negligible weight.
        bias = tf.to_float(
            tf.equal(tf.reduce_sum(adjacency_matrix, axis=-1), 0)) * -1e9
        logits += bias

        # Turn the raw key-query products into a probability distribution (or,
        # in terms of attention, weights). The softmax is computed across the
        # last dimension of logits.
        compatibility = tf.nn.softmax(logits)  # Shape [B, N, N].

        # Computes a summary showing the attention matrix as an image. Does not do
        # any work toward actually performing attention.
        common_attention.attention_image_summary(
            tf.expand_dims(compatibility, axis=1), None)

        # Repeats the attention matrix T times for each batch, producing
        # a tensor with shape [B, T, N, N] where the [N, N] component is T
        # repeats of the values found in compatibility.
        edge_compatibility = tf.tile(tf.expand_dims(compatibility, axis=1),
                                     [1, num_edge_types, 1, 1])

        # Zeroes out the entries in edge_compatibility that do not correspond to
        # actual edges.
        edge_compatibility *= edge_vectors  # Shape [B, T, N, N].

        output = compute_values(edge_compatibility, v)
        return output
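
# A self-contained sketch (illustration only) of the edge-type selection used
# in dot_product_mpnn_attention above: multiplying the [B, T, N, N] per-type
# logits by the transposed edge indicators and reducing over T keeps, for each
# pair (i, j), exactly the logit of the edge type that is present, while pairs
# with no edge at all receive a large negative bias before the softmax.
# The adjacency tensor below is a hypothetical hand-written example.
def _example_edge_type_selection():
  # One batch, three nodes, two edge types. adjacency[b, i, j] is the indicator
  # vector of the edge j -> i; an all-zero vector means "no edge".
  adjacency = tf.constant(
      [[[[0., 0.], [1., 0.], [0., 1.]],
        [[0., 1.], [0., 0.], [0., 0.]],
        [[1., 0.], [0., 0.], [0., 0.]]]])               # [B, N, N, T]
  per_type_logits = tf.random_normal([1, 2, 3, 3])      # [B, T, N, N]
  edge_vectors = tf.transpose(adjacency, [0, 3, 1, 2])  # [B, T, N, N]
  # Keep, for each (i, j), only the logit of the edge type that is present.
  logits = tf.reduce_sum(per_type_logits * edge_vectors, axis=1)  # [B, N, N]
  # Pairs with no edge at all get a large negative bias before the softmax.
  no_edge_bias = tf.to_float(
      tf.equal(tf.reduce_sum(adjacency, axis=-1), 0.)) * -1e9
  return tf.nn.softmax(logits + no_edge_bias)           # [B, N, N]
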
def graph_attention(q,
                    k,
                    v,
                    bias,
                    dropout_rate=0.0,
                    image_shapes=None,
                    name=None,
                    make_image_summary=True,
                    save_weights_to=None,
                    dropout_broadcast_dims=None,
                    adjacency_matrix=None,
                    num_edge_types=5):
    """graph attention.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    adjacency_matrix: optional matrix of [batch, length, length] ids indicating
      edge type
    num_edge_types: an int indicating number of edge types
  Returns:
    A Tensor of shape [batch, heads, length_q, depth_v].
  """
    with tf.variable_scope(name,
                           default_name="dot_product_attention",
                           values=[q, k, v]) as scope:
        # [batch, num_heads, query_length, memory_length]
        logits = tf.matmul(q, k, transpose_b=True)
        if adjacency_matrix is not None:
            key_head_depth = common_layers.shape_list(q)[-1]
            adjacency_vectors = make_edge_vectors(adjacency_matrix,
                                                  num_edge_types,
                                                  key_head_depth,
                                                  name=name)
            # Transpose q to [batch, length_q, heads, depth_k] so it can be
            # matmul'ed with adjacency_vectors of shape
            # [batch, length_q, length_q, depth_k].
            q_t = tf.transpose(q, [0, 2, 1, 3])
            adj_logits = tf.matmul(q_t, adjacency_vectors, transpose_b=True)
            # Transpose back to [batch, heads, length_q, length_q] and add.
            logits += tf.transpose(adj_logits, [0, 2, 1, 3])
        if bias is not None:
            logits += bias
        weights = tf.nn.softmax(logits, name="attention_weights")
        if save_weights_to is not None:
            save_weights_to[scope.name] = weights
        # dropping out the attention links for each of the heads
        weights = common_layers.dropout_with_broadcast_dims(
            weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
        if common_layers.should_generate_summaries() and make_image_summary:
            common_attention.attention_image_summary(weights, image_shapes)
        return tf.matmul(weights, v)
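
# A minimal usage sketch (illustration only, not part of the original module)
# for graph_attention above. Assumes TF 1.x and that the tensor2tensor helpers
# it relies on (common_layers, common_attention, make_edge_vectors) are
# importable. adjacency_matrix holds integer edge-type ids per node pair and is
# turned into learned edge vectors that add to the query-key logits.
def _example_graph_attention():
  batch, heads, length, depth, num_edge_types = 2, 4, 6, 16, 5
  q = tf.random_normal([batch, heads, length, depth])
  k = tf.random_normal([batch, heads, length, depth])
  v = tf.random_normal([batch, heads, length, depth])
  adjacency = tf.random_uniform([batch, length, length],
                                minval=0, maxval=num_edge_types, dtype=tf.int32)
  # Returns a [batch, heads, length, depth] tensor.
  return graph_attention(q, k, v, bias=None, adjacency_matrix=adjacency,
                         num_edge_types=num_edge_types,
                         make_image_summary=False)
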
def dot_product_attention_mtsa(
    q,
    k,
    v,
    bias,
    dropout_rate=0.0,
    image_shapes=None,
    name=None,
    make_image_summary=True,
    save_weights_to=None,
    dropout_broadcast_dims=None,
    use_k_mtsa=True,
    afn_extra='none',
    afn_dot='exp',
    afn_multi='exp',
    bias_start=0.,
    bi_direction=False,
):
    """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    use_k_mtsa: if True, compute the multi-dimensional (source2token) logits
      from k rather than from v.
    afn_extra: name of an optional extra activation for the multi-dimensional
      logits; when it resolves to None, the second dense layer is skipped.
    afn_dot: name of the activation applied to the token2token logits.
    afn_multi: name of the activation applied to the multi-dimensional logits.
    bias_start: initial bias for the multi-head dense layer(s).
    bi_direction: if True, half of the heads use a forward (lower-triangular)
      mask and the other half a backward (upper-triangular) mask.

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
    print("!!!!!dot_product_attention_mtsa!!!!!")
    with tf.variable_scope(name,
                           default_name="dot_product_attention",
                           values=[q, k, v]) as scope:
        # get dim
        dim_q = q.get_shape().as_list()[-1]
        dim_k = k.get_shape().as_list()[-1]
        dim_v = v.get_shape().as_list()[-1]
        # prepare
        multi_logits_scale_factor = 1. / math.sqrt(
            dim_v) if afn_multi.startswith('scaled') else 1.
        afn_extra, afn_dot, afn_multi = afn_name2fn(afn_extra), afn_name2fn(
            afn_dot), afn_name2fn(afn_multi)
        # if bias is not None:
        #   inp_mask_1d = tf.to_float(tf.equal(bias, 0.))  # bs,1,1,vl
        #   inp_mask_1d = tf.transpose(inp_mask_1d, [0, 1, 3, 2])   # bs,1,vl,1
        # else:
        #   inp_mask_1d = None

        # token2token self attention
        dot_logits = tf.matmul(q, k, transpose_b=True)  # bs,hd,ql,vl
        if bias is not None:
            bias = common_layers.cast_like(bias, dot_logits)  # 1/bs,1,ql/1,vl
            dot_logits += bias
        e_dot_logits = afn_dot(dot_logits)  # bs,hd,ql,vl
        if bi_direction:
            head_num = v.get_shape().as_list()[1]
            ql, vl = tf.shape(q)[-2], tf.shape(v)[-2]
            assert head_num is not None
            assert head_num % 2 == 0
            ones_mat = tf.ones([ql, vl], tf.float32)
            mul_mask_fw = tf.matrix_band_part(ones_mat, -1,
                                              0)  #  Lower triangular part.
            mul_mask_bw = tf.matrix_band_part(ones_mat, 0,
                                              -1)  #  Upper triangular part.
            mul_mask_fw_tile = tf.tile(tf.expand_dims(mul_mask_fw, 0),
                                       [head_num // 2, 1, 1])
            mul_mask_bw_tile = tf.tile(tf.expand_dims(mul_mask_bw, 0),
                                       [head_num // 2, 1, 1])
            mul_mask = tf.expand_dims(tf.concat(
                [mul_mask_fw_tile, mul_mask_bw_tile], axis=0),
                                      axis=0)
            e_dot_logits *= mul_mask

        # source2token self-attention
        multi_logits = multi_head_dense_layer(
            k if use_k_mtsa else v, dim_v, True,
            bias_start if afn_extra is None else 0., 'multi_logits1')
        if afn_extra is not None:  # use one extra layer for multi-dim
            multi_logits = multi_head_dense_layer(afn_extra(multi_logits),
                                                  dim_v, True, bias_start,
                                                  'multi_logits2')
        e_multi_logits = afn_multi(multi_logits *
                                   multi_logits_scale_factor)  # bs,hd,vl,vd
        # if inp_mask_1d is not None:  # use mask for exp_logits
        #   e_multi_logits *= inp_mask_1d

        # mtsa
        accum_z_deno = tf.matmul(e_dot_logits, e_multi_logits)  # bs,hd,ql,vd
        accum_z_deno = tf.where(  # in case of NaN and Inf
            tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
            accum_z_deno, tf.ones_like(accum_z_deno))

        # attention dropout
        e_dot_logits = common_layers.dropout_with_broadcast_dims(
            e_dot_logits,
            math.sqrt(1. - dropout_rate),
            broadcast_dims=dropout_broadcast_dims)
        e_multi_logits = common_layers.dropout_with_broadcast_dims(
            e_multi_logits,
            math.sqrt(1. - dropout_rate),
            broadcast_dims=dropout_broadcast_dims)
        rep_mul_score = v * e_multi_logits  # bs,hd,vl,vd
        accum_rep_mul_score = tf.matmul(e_dot_logits,
                                        rep_mul_score)  # bs,hd,ql,vd
        # calculate the final attention results
        attn_res = accum_rep_mul_score / accum_z_deno
        # if inp_mask_1d is not None:  # use mask for output
        #   attn_res *= inp_mask_1d

        # Normalized attention weights, kept only for visualization/summaries.
        weights = e_dot_logits / (tf.reduce_sum(
            e_dot_logits, axis=-1, keepdims=True, name="attention_weights") +
                                  0.00001)
        if save_weights_to is not None:
            save_weights_to[scope.name] = weights
            save_weights_to[scope.name + "/logits"] = dot_logits
        if common_layers.should_generate_summaries() and make_image_summary:
            common_attention.attention_image_summary(weights, image_shapes)
        return attn_res
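
# A self-contained sketch (illustration only, not the original helpers) of the
# core MTSA combination used above. With E = exp(token2token logits) and
# M = exp(multi-dimensional source2token logits), the output is
# (E @ (V * M)) / (E @ M): a per-value-dimension weighted average. The real
# function additionally applies sqrt-scaled dropout to both factors and guards
# the denominator against zeros with tf.where.
def _example_mtsa_core():
  batch, heads, length, depth_v = 2, 4, 6, 8
  dot_logits = tf.random_normal([batch, heads, length, length])
  multi_logits = tf.random_normal([batch, heads, length, depth_v])
  v = tf.random_normal([batch, heads, length, depth_v])
  e_dot = tf.exp(dot_logits)                  # bs,hd,ql,vl
  e_multi = tf.exp(multi_logits)              # bs,hd,vl,vd
  numerator = tf.matmul(e_dot, v * e_multi)   # bs,hd,ql,vd
  denominator = tf.matmul(e_dot, e_multi)     # bs,hd,ql,vd
  return numerator / denominator              # bs,hd,ql,vd
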
def dot_product_mpnn_attention(q,
                               k,
                               v,
                               adjacency_matrix,
                               num_edge_types,
                               ignore_zero=True,
                               name=None):
    """Dot product attention with edge vectors.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let K be the size of the attention keys/queries.
  Let V be the size of the attention values.
  Let T be the total number of edge types (num_edge_types).

  Args:
    q: The query Tensor of shape [B, N, K].
    k: The key Tensor of shape [B, T, N, K].
    v: The value Tensor of shape [B, T, N, V].
    adjacency_matrix: A Tensor of shape [B, N, N]. An entry at indices b, i, j
     is the integer edge type of the edge from node j to node i in batch b.
    num_edge_types: An integer specifying number of edge types (T).
    ignore_zero: If True, treat edge type 0 as "no edge" and mask it out.
    name: A string.

  Returns:
    A Tensor of shape [B, N, V] storing the result of computing attention
    weights using the queries and keys and combining the values according to
    those weights.
  """
    # TODO(jfrankle): Consider ways to handle graphs that have multiple edges
    # between the same nodes (with only one edge of each type). adjacency_matrix
    # will need to be converted to shape [B, T, N, N].
    with tf.variable_scope(name,
                           default_name="dot_product_mpnn_attention",
                           values=[q, k, v, adjacency_matrix, num_edge_types]):
        # Computes the raw dot-product attention values between each query and
        # the corresponding keys it needs to consider.
        #
        # This operation takes the dot product of (the query for
        # each node) and (the key for each node for each possible edge type),
        # creating an N x N matrix for each edge type. The entry at index (i, j)
        # is the dot-product for the edge from node i to node j of the appropriate
        # type. These dot products will eventually become attention weights
        # specifying how much node i weights an edge of that type coming from node
        # j.
        all_edge_logits = tf.matmul(tf.tile(tf.expand_dims(q, axis=1),
                                            [1, num_edge_types, 1, 1]),
                                    k,
                                    transpose_b=True)

        # The adjacency matrix assumes there is only one directed edge (i <- j) for
        # each pair of nodes. If such an edge exists, it contains the integer
        # type of that edge at position (i, j) of the adjacency matrix.
        #
        # adjacency_matrix_one_hot has shape [B, N, N, T]. If there is an edge
        # from node j to node i of type t, then index t of the last dimension is
        # 1 for entry (i, j) of the second and third dimensions.
        adjacency_matrix_one_hot = tf.one_hot(adjacency_matrix, num_edge_types)

        # Rearranging the dimensions to match the shape of all_edge_logits.
        adjacency_matrix_one_hot = tf.transpose(adjacency_matrix_one_hot,
                                                [0, 3, 1, 2])

        # Element-wise multiplies all_edge_logits and adjacency_matrix_one_hot.
        #
        # In other words: all_edge_logits contains N x N matrices of query-key
        # products. This element-wise multiplication zeroes out entries that do not
        # correspond to actual edges in the graph of the appropriate edge type.
        # all_edge_logits retains shape [B, T, N, N].
        all_edge_logits *= adjacency_matrix_one_hot

        # Since there can only be one edge from node A to node B, we can collapse
        # the T different adjacency matrices containing key-query pairs into one
        # adjacency matrix. logits is [B, N, N].
        logits = tf.reduce_sum(all_edge_logits, axis=1)

        # When ignore_zero is set, edge type 0 denotes "no edge": add a large
        # negative bias at those positions so that the softmax assigns them
        # negligible weight.
        #
        # TODO(avaswani): Better explanation of the rationale behind ignore_zero
        # here and throughout.
        bias = 0
        if ignore_zero:
            bias = tf.to_float(tf.equal(adjacency_matrix, 0)) * -1e9
        logits += bias

        # Turn the raw key-query products into a probability distribution (or,
        # in terms of attention, weights). The softmax is computed across the
        # last dimension of logits.
        compatibility = tf.nn.softmax(logits)  # Shape [B, N, N].

        # Computes a summary showing the attention matrix as an image. Does not do
        # any work toward actually performing attention.
        common_attention.attention_image_summary(
            tf.expand_dims(compatibility, axis=1), None)

        # Repeats the attention matrix T times for each batch, producing
        # a tensor with shape [B, T, N, N] where the [N, N] component is T
        # repeats of the values found in compatibility.
        edge_compatibility = tf.tile(tf.expand_dims(compatibility, axis=1),
                                     [1, num_edge_types, 1, 1])

        # Zeroes out the entries in edge_compatibility that do not correspond to
        # actual edges.
        edge_compatibility *= adjacency_matrix_one_hot  # Shape [B, T, N, N].

        # Computes the incoming value vectors for each node by weighting them
        # according to the attention weights. These values are still segregated by
        # edge type.
        all_edge_values = tf.matmul(edge_compatibility,
                                    v)  # Shape = [B, T, N, V].

        # Combines the weighted value vectors together across edge types into a
        # single N x V matrix for each batch.
        output = tf.reduce_sum(all_edge_values, axis=1)  # Shape [B, N, V].

        return output
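
# A tiny self-contained sketch (illustration only) of the ignore_zero masking
# used above: adding -1e9 to positions whose edge type is 0 drives their
# softmax weight to (effectively) zero, so attention is distributed only over
# real edges.
def _example_ignore_zero_bias():
  logits = tf.constant([[2.0, 1.0, 0.5]])
  edge_types = tf.constant([[1, 0, 2]])       # 0 means "no edge"
  bias = tf.to_float(tf.equal(edge_types, 0)) * -1e9
  return tf.nn.softmax(logits + bias)         # approx. [[0.82, 0.0, 0.18]]
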