def dot_product_highorder_attention(q,
                                    k,
                                    v,
                                    bias,
                                    dropout_rate=0.0,
                                    image_shapes=None,
                                    attention_order=2,
                                    name=None,
                                    make_image_summary=True):
    """High order dot-product attention.

    Attention is applied repeatedly to generate query vectors. For example,
    2-order attention uses q,k,v to generate a new query vector q'. The final
    attention result is computed with q',k,v.

    q and k are split into `attention_order` equal pieces along the depth
    axis (axis 3); step `idx` uses the idx-th piece of each, so depth_k must
    be divisible by `attention_order`.

    Args:
      q: a Tensor with shape [batch, heads, length_q, depth_k]
      k: a Tensor with shape [batch, heads, length_kv, depth_k]
      v: a Tensor with shape [batch, heads, length_kv, depth_v]
      bias: bias Tensor (see attention_bias())
      dropout_rate: a floating point number
      image_shapes: optional tuple of integer scalars.
        see comments for attention_image_summary()
      attention_order (int): Attention order (number of steps)
      name: an optional string
      make_image_summary: True if you want an image summary.

    Returns:
      A Tensor.
    """
    # NOTE(review): the parameter defaults and the variable-scope default
    # name were reconstructed from the docstring -- confirm against callers.
    if attention_order == 1:
        # First-order attention is the ordinary dot-product attention.
        return common_attention.dot_product_attention(
            q, k, v, bias,
            dropout_rate=dropout_rate,
            image_shapes=image_shapes,
            name=name,
            make_image_summary=make_image_summary)
    # Split q, k in attention_order pieces
    qs = tf.split(q, attention_order, axis=3)
    ks = tf.split(k, attention_order, axis=3)
    with tf.variable_scope(name,
                           default_name="dot_product_highorder_attention",
                           values=[q, k, v]):
        weights = None
        for idx in range(attention_order):
            # The first step uses the first query piece directly; later steps
            # build the query by attending with the previous step's weights.
            # NOTE(review): matmul(weights, qs[idx]) only type-checks when
            # length_kv == length_q (self-attention) -- confirm this is the
            # intended use.
            q = qs[0] if idx == 0 else tf.matmul(weights, qs[idx])
            # [batch, num_heads, query_length, memory_length]
            logits = tf.matmul(q, ks[idx], transpose_b=True)
            if bias is not None:
                logits += bias
            weights = tf.nn.softmax(logits, name="attention_weights")
        # dropping out the attention links for each of the heads
        # (second argument is the keep probability)
        weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
        if (not tf.get_variable_scope().reuse and
                # Summaries don't work well within tf.while_loop()
                "/while/" not in tf.contrib.framework.get_name_scope()
                and make_image_summary):
            common_attention.attention_image_summary(weights, image_shapes)
        return tf.matmul(weights, v)
def dot_product_mpnn_attention(q, k, v, adjacency_matrix, num_edge_types,
                               ignore_zero=True, name=None):
  """Dot product attention with edge vectors.

  For every pair of nodes (i, j) the logit is the dot product of q_i with the
  key of node j under the edge type given by adjacency_matrix[i, j]; the
  output for node i is the softmax-weighted sum of the matching
  edge-type-specific values.

  Args:
    q: [batch, length, key_depth] tensor
    k: [batch, num_edge_types, length, key_depth]
    v: [batch, num_edge_types, length, depth]
    adjacency_matrix: [batch, length, length] tensor of int edge types
    num_edge_types: an int, specifying number of edge types
    ignore_zero: A flag that says that edge type 0 should be ignored
    name: optional string

  Returns:
    A tensor of shape [batch, length, depth(q)]
  """
  with tf.variable_scope(
      name, default_name="dot_product_mpnn_attention",
      values=[q, k, v, adjacency_matrix, num_edge_types]):
    # Computing attention mask
    # all edge logits will have shape [batch, edge_types, len, len]
    all_edge_logits = tf.matmul(
        tf.tile(tf.expand_dims(q, axis=1), [1, num_edge_types, 1, 1]),
        k, transpose_b=True)
    # adjacency_matrix_one_hot has shape [batch, len, len, num_edge_types]
    adjacency_matrix_one_hot = tf.one_hot(adjacency_matrix, num_edge_types)
    # making adjacency_matrix_one_hot [batch, edge_types, len, len]
    adjacency_matrix_one_hot = tf.transpose(adjacency_matrix_one_hot,
                                            [0, 3, 1, 2])
    # getting dot products for q_i, k_j, and e_{ij}. This assumes that for
    # edge type 0, the dot products are 0
    all_edge_logits *= adjacency_matrix_one_hot
    # logits will be [batch, length, length] after reducing along
    # axis 1 which has dimension num_edge_types.
    logits = tf.reduce_sum(all_edge_logits, axis=1)
    # ignoring edges if needed: a large negative bias drives the softmax
    # weight of every edge-type-0 position to (numerically) zero.
    bias = 0
    if ignore_zero:
      bias = tf.to_float(tf.equal(adjacency_matrix, 0)) * -1e9
    logits += bias
    # getting compatibilities
    compatibility = tf.nn.softmax(logits)
    # The summary helper is called with a heads axis, so add one at axis 1.
    common_attention.attention_image_summary(
        tf.expand_dims(compatibility, axis=1), None)
    # getting edge compatibilities ready to compute values.
    # after tiling, edge_compatibility will be
    # [batch, num_edge_types, length, length]
    edge_compatibility = tf.tile(
        tf.expand_dims(compatibility, axis=1), [1, num_edge_types, 1, 1])
    # computing values: keep each weight only in the slice of its own
    # edge type.
    edge_compatibility *= adjacency_matrix_one_hot
    # all edge values will be [batch, num_edge_types, length, depth]
    # We also assumed that the linear transformations for edge_type 0 will
    # all be zeros. That is [batch, 0] is a length*depth tensor of 0's
    all_edge_values = tf.matmul(edge_compatibility, v)
    # reducing along the num_edge_types dimension
    output = tf.reduce_sum(all_edge_values, axis=1)
    return output
def graph_attention(q,
                    k,
                    v,
                    bias,
                    dropout_rate=0.0,
                    image_shapes=None,
                    name=None,
                    make_image_summary=True,
                    save_weights_to=None,
                    dropout_broadcast_dims=None,
                    adjacency_matrix=None,
                    num_edge_types=5):
    """graph attention.

    Dot-product attention whose logits are optionally augmented with a
    per-edge term: when `adjacency_matrix` is given, each query is also
    dotted with a vector built by make_edge_vectors() for the edge between
    the query and key positions.

    Args:
      q: a Tensor with shape [batch, heads, length_q, depth_k]
      k: a Tensor with shape [batch, heads, length_kv, depth_k]
      v: a Tensor with shape [batch, heads, length_kv, depth_v]
      bias: bias Tensor (see attention_bias())
      dropout_rate: a floating point number
      image_shapes: optional tuple of integer scalars.
        see comments for attention_image_summary()
      name: an optional string
      make_image_summary: True if you want an image summary.
      save_weights_to: an optional dictionary to capture attention weights
        for visualization; the weights tensor will be appended there under
        a string key created from the variable scope (including name).
      dropout_broadcast_dims:  an optional list of integers less than 4
        specifying in which dimensions to broadcast the dropout decisions.
        saves memory.
      adjacency_matrix: optional matrix of [batch, length, length] ids
        indicating edge type
      num_edge_types: an int indicating number of edge types

    Returns:
      A Tensor of shape [batch, length, depth(q)]
    """
    # NOTE(review): parameter defaults, the scope default name and the
    # trailing make_edge_vectors() arguments were reconstructed -- confirm
    # against callers and against make_edge_vectors' signature.
    with tf.variable_scope(name,
                           default_name="dot_product_attention",
                           values=[q, k, v]) as scope:
        # [batch, num_heads, query_length, memory_length]
        logits = tf.matmul(q, k, transpose_b=True)
        if adjacency_matrix is not None:
            key_head_depth = common_layers.shape_list(q)[-1]
            adjacency_vectors = make_edge_vectors(adjacency_matrix,
                                                  num_edge_types,
                                                  key_head_depth,
                                                  name=name)
            # transposing q to be [batch, length_q, heads, depth_k]
            # to allow for matmul with [batch, length_q, length_q, depth_k]
            q_t = tf.transpose(q, [0, 2, 1, 3])
            adj_logits = tf.matmul(q_t, adjacency_vectors, transpose_b=True)
            # Swap axes 1 and 2 back so the edge term lines up with `logits`.
            logits += tf.transpose(adj_logits, [0, 2, 1, 3])
            # [batch, depth, num_nodes, num_nodes]
        if bias is not None:
            logits += bias
        weights = tf.nn.softmax(logits, name="attention_weights")
        if save_weights_to is not None:
            # Keyed by the variable scope name, as described in the docstring.
            save_weights_to[scope.name] = weights
        # dropping out the attention links for each of the heads
        # (second argument is the keep probability)
        weights = common_layers.dropout_with_broadcast_dims(
            weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
        if common_layers.should_generate_summaries() and make_image_summary:
            common_attention.attention_image_summary(weights, image_shapes)
        return tf.matmul(weights, v)
def dot_product_attention_mtsa(
        q,
        k,
        v,
        bias,
        dropout_rate=0.0,
        image_shapes=None,
        name=None,
        make_image_summary=True,
        save_weights_to=None,
        dropout_broadcast_dims=None,
        use_k_mtsa=True,
        afn_extra='none',
        afn_dot='exp',
        afn_multi='exp',
        bias_start=0.,
        bi_direction=False,
):
    """Dot-product attention combined with a per-feature (multi-dim) score.

    Two un-normalised score tensors are computed: token-to-token scores from
    q.k^T and source-to-token scores from a dense projection of k (or v).
    The result is matmul(dot_scores, v * multi_scores) divided element-wise
    by matmul(dot_scores, multi_scores).

    Args:
      q: Tensor with shape [..., length_q, depth_k].
      k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
        match with q.
      v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
        match with q.
      bias: bias Tensor (see attention_bias())
      dropout_rate: a float.
      image_shapes: optional tuple of integer scalars.
        see comments for attention_image_summary()
      name: an optional string
      make_image_summary: True if you want an image summary.
      save_weights_to: an optional dictionary to capture attention weights
        for visualization; the weights tensor will be appended there under
        a string key created from the variable scope (including name).
      dropout_broadcast_dims: an optional list of integers less than rank of
        q. Specifies in which dimensions to broadcast the dropout decisions.
      use_k_mtsa: if true the source-to-token scores are computed from k,
        otherwise from v.
      afn_extra: name resolved by afn_name2fn(); when it resolves to None the
        extra dense layer on the source-to-token branch is skipped.
      afn_dot: name (resolved by afn_name2fn()) of the function applied to
        the token-to-token logits.
      afn_multi: name (resolved by afn_name2fn()) of the function applied to
        the source-to-token logits; if the name starts with 'scaled' the
        logits are first multiplied by 1/sqrt(depth_v).
      bias_start: value forwarded to multi_head_dense_layer().
      bi_direction: if true, the first half of the heads is masked with a
        lower-triangular matrix and the second half with an upper-triangular
        one; requires a statically known, even number of heads.

    Returns:
      Tensor with shape [..., length_q, depth_v].
    """
    # NOTE(review): parameter defaults, the scope default name, the second
    # dense layer's scope name and the epsilon used for the visualisation
    # weights were reconstructed -- confirm against callers / checkpoints.
    with tf.variable_scope(name,
                           default_name="dot_product_attention",
                           values=[q, k, v]) as scope:
        # get dim
        dim_v = v.get_shape().as_list()[-1]
        # prepare: the scale factor must be derived while afn_multi is still
        # a string, i.e. before the names are resolved to callables below.
        multi_logits_scale_factor = 1. / math.sqrt(
            dim_v) if afn_multi.startswith('scaled') else 1.
        afn_extra, afn_dot, afn_multi = afn_name2fn(afn_extra), afn_name2fn(
            afn_dot), afn_name2fn(afn_multi)
        # if bias is not None:
        #   inp_mask_1d = tf.to_float(tf.equal(bias, 0.))  # bs,1,1,vl
        #   inp_mask_1d = tf.transpose(inp_mask_1d, [0, 1, 3, 2])   # bs,1,vl,1
        # else:
        #   inp_mask_1d = None

        # token2token self attention
        dot_logits = tf.matmul(q, k, transpose_b=True)  # bs,hd,ql,vl
        if bias is not None:
            bias = common_layers.cast_like(bias, dot_logits)  # 1/bs,1,ql/1,vl
            dot_logits += bias
        e_dot_logits = afn_dot(dot_logits)  # bs,hd,ql,vl
        if bi_direction:
            head_num = v.get_shape().as_list()[1]
            ql, vl = tf.shape(q)[-2], tf.shape(v)[-2]
            assert head_num is not None
            assert head_num % 2 == 0
            ones_mat = tf.ones([ql, vl], tf.float32)
            mul_mask_fw = tf.matrix_band_part(ones_mat, -1,
                                              0)  #  Lower triangular part.
            mul_mask_bw = tf.matrix_band_part(ones_mat, 0,
                                              -1)  #  Upper triangular part.
            mul_mask_fw_tile = tf.tile(tf.expand_dims(mul_mask_fw, 0),
                                       [head_num // 2, 1, 1])
            mul_mask_bw_tile = tf.tile(tf.expand_dims(mul_mask_bw, 0),
                                       [head_num // 2, 1, 1])
            # Stack forward heads then backward heads and add a leading
            # axis so the mask broadcasts over batch: 1,hd,ql,vl
            mul_mask = tf.expand_dims(tf.concat(
                [mul_mask_fw_tile, mul_mask_bw_tile], axis=0),
                0)
            e_dot_logits *= mul_mask

        # source2token self-attention
        multi_logits = multi_head_dense_layer(
            k if use_k_mtsa else v, dim_v, True,
            bias_start if afn_extra is None else 0., 'multi_logits1')
        if afn_extra is not None:  # use one extra layer for multi-dim
            multi_logits = multi_head_dense_layer(afn_extra(multi_logits),
                                                  dim_v, True, bias_start,
                                                  'multi_logits2')
        e_multi_logits = afn_multi(multi_logits *
                                   multi_logits_scale_factor)  # bs,hd,vl,vd
        # if inp_mask_1d is not None:  # use mask for exp_logits
        #   e_multi_logits *= inp_mask_1d

        # mtsa: the normaliser is computed before dropout is applied.
        accum_z_deno = tf.matmul(e_dot_logits, e_multi_logits)  # bs,hd,ql,vd
        # Any entry that is not strictly positive (including NaN, for which
        # the comparison is false) is replaced by 1 to keep the division safe.
        accum_z_deno = tf.where(
            tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
            accum_z_deno, tf.ones_like(accum_z_deno))

        # attention dropout: each of the two factors gets keep probability
        # sqrt(1 - dropout_rate).
        keep_prob = math.sqrt(1. - dropout_rate)
        e_dot_logits = common_layers.dropout_with_broadcast_dims(
            e_dot_logits,
            keep_prob,
            broadcast_dims=dropout_broadcast_dims)
        e_multi_logits = common_layers.dropout_with_broadcast_dims(
            e_multi_logits,
            keep_prob,
            broadcast_dims=dropout_broadcast_dims)
        rep_mul_score = v * e_multi_logits  # bs,hd,vl,vd
        accum_rep_mul_score = tf.matmul(e_dot_logits,
                                        rep_mul_score)  # bs,hd,ql,vd
        # calculate the final attention results
        attn_res = accum_rep_mul_score / accum_z_deno
        # if inp_mask_1d is not None:  # use mask for output
        #   attn_res *= inp_mask_1d

        # ============ for vis =======
        # Row-normalised token-to-token scores; the small constant guards
        # against division by zero on fully masked rows.
        weights = e_dot_logits / (tf.reduce_sum(
            e_dot_logits, axis=-1, keepdims=True, name="attention_weights") +
            0.00001)
        if save_weights_to is not None:
            save_weights_to[scope.name] = weights
            save_weights_to[scope.name + "/logits"] = dot_logits
        if common_layers.should_generate_summaries() and make_image_summary:
            common_attention.attention_image_summary(weights, image_shapes)
        return attn_res
