Code Example #1
def sgc(x, edge_index, edge_weight, K, kernel, bias=None, renorm=True, improved=False, cache=None):
    """
    Functional API for Simple Graph Convolution (SGC).

    :param x: Tensor, shape: [num_nodes, num_features], node features
    :param edge_index: Tensor, shape: [2, num_edges], edge information
    :param edge_weight: Tensor or None, shape: [num_edges]
    :param K: Number of hops (default: :obj:`1`).
    :param kernel: Tensor, shape: [num_features, num_output_features], weight.
    :param bias: Tensor, shape: [num_output_features], bias.
    :param renorm: Whether to use the renormalization trick (https://arxiv.org/pdf/1609.02907.pdf).
    :param improved: Whether to use the improved GCN variant.
    :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict.
    :return: Updated node features (x), shape: [num_nodes, num_output_features]
    """
    updated_edge_index, normed_edge_weight = gcn_norm_edge(edge_index, x.shape[0], edge_weight,
                                                           renorm, improved, cache)

    h = x
    for _ in range(K):
        h = aggregate_neighbors(
            h,
            updated_edge_index,
            normed_edge_weight,
            gcn_mapper,
            sum_reducer,
            identity_updater
        )

    h = h @ kernel

    if bias is not None:
        h += bias
    return h
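
A minimal usage sketch for the sgc function above (an illustrative addition, not part of the original example): it assumes sgc and its helpers (gcn_norm_edge, aggregate_neighbors, etc.) are in scope, e.g. imported from tf_geometric, that edge_weight=None is accepted, and uses purely illustrative shapes and data.

import tensorflow as tf

num_nodes, num_features, num_classes = 4, 8, 3
x = tf.random.normal([num_nodes, num_features])
# Four directed edges: 0->1, 1->0, 1->2, 2->1.
edge_index = tf.constant([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=tf.int32)
# In practice the kernel would be a trainable tf.Variable.
kernel = tf.random.normal([num_features, num_classes])

# K=2: propagate features over two hops, then apply the single linear transform.
logits = sgc(x, edge_index, None, K=2, kernel=kernel)  # shape: [num_nodes, num_classes]
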
Code Example #2
def gcn(x, edge_index, edge_weight, kernel, bias=None, activation=None, improved=False, cache=None):
    """

    :param x: Tensor, shape: [num_nodes, num_features], node features
    :param edge_index: Tensor, shape: [2, num_edges], edge information
    :param edge_weight: Tensor or None, shape: [num_edges]
    :param kernel: Tensor, shape: [num_features, num_output_features], weight
    :param bias: Tensor, shape: [num_output_features], bias
    :param activation: Activation function to use.
    :param improved: Whether to use the improved GCN variant.
    :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict.
    :return: Updated node features (x), shape: [num_nodes, num_output_features]
    """

    updated_edge_index, normed_edge_weight = gcn_norm_edge(edge_index, x.shape[0], edge_weight,
                                                           improved, cache)
    x = x @ kernel
    h = aggregate_neighbors(
        x, updated_edge_index, normed_edge_weight,
        gcn_mapper,
        sum_reducer,
        identity_updater
    )

    if bias is not None:
        h += bias

    if activation is not None:
        h = activation(h)

    return h
Code Example #3
File: gin.py  Project: zzhzaihq/tf_geometric
def gin(x,
        edge_index,
        edge_weight,
        mlp_model,
        eps=0.0,
        activation=None,
        cache=None):
    """

    :param x: Tensor, shape: [num_nodes, num_features], node features
    :param edge_index: Tensor, shape: [2, num_edges], edge information
    :param edge_weight: Tensor or None, shape: [num_edges]
    :param mlp_model: A neural network (multi-layer perceptrons).
    :param eps: float, optional, weight of the central node's own features (default: :obj:`0.`).
    :param activation: Activation function to use.
    :param cache: A dict for caching A' for GIN. Different graphs should not share the same cache dict.
    :return: Updated node features (x), shape: [num_nodes, num_output_features]
    """

    h = aggregate_neighbors(x, edge_index, edge_weight, identity_mapper,
                            sum_reducer, identity_updater)

    h = gin_updater(x, h, eps)

    h = mlp_model(h)

    if activation is not None:
        h = activation(h)

    return h
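
A minimal usage sketch for the gin function above (an illustrative addition, assuming gin and its helpers are in scope and edge_weight=None is accepted). Any small Keras model can play the role of mlp_model; a two-layer MLP is typical for GIN.

import tensorflow as tf

num_nodes, num_features, num_units = 4, 8, 16
x = tf.random.normal([num_nodes, num_features])
edge_index = tf.constant([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=tf.int32)

# MLP applied after computing (1 + eps) * x + sum of neighbor features.
mlp_model = tf.keras.Sequential([
    tf.keras.layers.Dense(num_units, activation=tf.nn.relu),
    tf.keras.layers.Dense(num_units),
])

h = gin(x, edge_index, None, mlp_model, eps=0.0)  # shape: [num_nodes, num_units]
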
Code Example #4
File: chebynet.py  Project: HoneyXjk/tf_geometric
def chebynet(x,
             edge_index,
             edge_weight,
             k,
             kernels,
             bias=None,
             activation=None,
             normalization_type="sym",
             use_dynamic_lambda_max=False,
             cache=None):
    num_nodes = tf.shape(x)[0]
    # lambda_max = chebynet_compute_lambda_max(x, edge_index, edge_weight, normalization_type, cache=cache)

    norm_edge_index, norm_edge_weight = chebynet_norm_edge(
        edge_index,
        num_nodes,
        edge_weight,
        normalization_type,
        use_dynamic_lambda_max=use_dynamic_lambda_max,
        cache=cache)

    T0_x = x
    T1_x = x
    out = tf.matmul(T0_x, kernels[0])

    if k > 1:
        T1_x = aggregate_neighbors(x, norm_edge_index, norm_edge_weight,
                                   gcn_mapper, sum_reducer, identity_updater)
        out += tf.matmul(T1_x, kernels[1])

    for i in range(2, k):
        T2_x = aggregate_neighbors(T1_x, norm_edge_index, norm_edge_weight,
                                   gcn_mapper, sum_reducer,
                                   identity_updater)  # L_hat @ T_{k-1}(L_hat), the Chebyshev recurrence term
        T2_x = 2.0 * T2_x - T0_x
        out += tf.matmul(T2_x, kernels[i])

        T0_x, T1_x = T1_x, T2_x

    if bias is not None:
        out += bias

    if activation is not None:
        out = activation(out)

    return out
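
A minimal usage sketch for the chebynet function above (illustrative only; it assumes the function and its helpers are in scope and that edge_weight=None is accepted). The key point is that kernels is a list holding one weight matrix per Chebyshev order T_0 ... T_{k-1}.

import tensorflow as tf

num_nodes, num_features, num_out = 4, 8, 16
k = 3
x = tf.random.normal([num_nodes, num_features])
edge_index = tf.constant([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=tf.int32)

# One weight matrix per Chebyshev polynomial order; trainable tf.Variables in practice.
kernels = [tf.random.normal([num_features, num_out]) for _ in range(k)]

h = chebynet(x, edge_index, None, k, kernels, activation=tf.nn.relu)  # [num_nodes, num_out]
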
Code Example #5
File: gat.py  Project: renzhenwen/tf_geometric
def gat(x,
        edge_index,
        query_kernel,
        query_bias,
        query_activation,
        key_kernel,
        key_bias,
        key_activation,
        kernel,
        bias=None,
        activation=None,
        num_heads=1,
        drop_rate=0.0,
        training=False):

    num_nodes = x.shape[0]

    # self-attention
    edge_index, edge_weight = add_self_loop_edge(edge_index, num_nodes)

    row, col = edge_index

    Q = tf.gather(x, row) @ query_kernel + query_bias
    Q = query_activation(Q)

    K = tf.gather(x, col) @ key_kernel + key_bias
    K = key_activation(K)

    V = x @ kernel

    # a trailing underscore (e.g. Q_) denotes the multi-head version of a tensor
    Q_ = tf.concat(tf.split(Q, num_heads, axis=-1), axis=0)
    K_ = tf.concat(tf.split(K, num_heads, axis=-1), axis=0)
    V_ = tf.concat(tf.split(V, num_heads, axis=-1), axis=0)
    edge_index_ = tf.concat(
        [edge_index + i * num_nodes for i in range(num_heads)], axis=1)

    att_score_ = tf.reduce_sum(Q_ * K_, axis=-1)
    normed_att_score_ = segment_softmax(att_score_, edge_index_[0],
                                        num_nodes * num_heads)

    if training and drop_rate > 0.0:
        normed_att_score_ = tf.compat.v2.nn.dropout(normed_att_score_,
                                                    drop_rate)

    h_ = aggregate_neighbors(V_, edge_index_, normed_att_score_, gcn_mapper,
                             sum_reducer, identity_updater)

    h = tf.concat(tf.split(h_, num_heads, axis=0), axis=-1)

    if bias is not None:
        h += bias

    if activation is not None:
        h = activation(h)

    return h
Code Example #6
File: gcn.py  Project: Miraclemin/tf_geometric
def gcn(x,
        edge_index,
        edge_weight,
        kernel,
        bias=None,
        activation=None,
        renorm=True,
        improved=False,
        cache=None):
    """
    Functional API for Graph Convolutional Networks.

    :param x: Tensor, shape: [num_nodes, num_features], node features
    :param edge_index: Tensor, shape: [2, num_edges], edge information
    :param edge_weight: Tensor or None, shape: [num_edges]
    :param kernel: Tensor, shape: [num_features, num_output_features], weight
    :param bias: Tensor, shape: [num_output_features], bias
    :param activation: Activation function to use.
    :param renorm: Whether to use the renormalization trick (https://arxiv.org/pdf/1609.02907.pdf).
    :param improved: Whether to use the improved GCN variant.
    :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict.
        To use @tf_utils.function with gcn, you should cache the normed edge information before the first call to gcn.

        - (1) If you're using OOP APIs tfg.layers.GCN:

              gcn_layer.cache_normed_edge(graph)

        - (2) If you're using functional API tfg.nn.gcn:

              from tf_geometric.nn.conv.gcn import gcn_cache_normed_edge
              gcn_cache_normed_edge(graph)

    :return: Updated node features (x), shape: [num_nodes, num_output_features]
    """

    num_nodes = tf.shape(x)[0]
    updated_edge_index, normed_edge_weight = gcn_norm_edge(
        edge_index, num_nodes, edge_weight, renorm, improved, cache)

    x = x @ kernel

    h = aggregate_neighbors(x,
                            updated_edge_index,
                            normed_edge_weight,
                            gcn_mapper,
                            sum_reducer,
                            identity_updater,
                            num_nodes=num_nodes)

    if bias is not None:
        h += bias

    if activation is not None:
        h = activation(h)

    return h
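
A minimal usage sketch for the gcn function above, showing the cache dict mentioned in the docstring (illustrative shapes; it assumes the function and its helpers are in scope and that edge_weight=None is accepted).

import tensorflow as tf

num_nodes, num_features, num_out = 4, 8, 16
x = tf.random.normal([num_nodes, num_features])
edge_index = tf.constant([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=tf.int32)
kernel = tf.random.normal([num_features, num_out])  # trainable tf.Variable in practice

# Reusing one cache dict across calls on the same graph lets gcn_norm_edge skip
# re-normalizing the adjacency; different graphs must get different dicts.
cache = {}
h1 = gcn(x, edge_index, None, kernel, activation=tf.nn.relu, cache=cache)  # fills the cache
h2 = gcn(x, edge_index, None, kernel, activation=tf.nn.relu, cache=cache)  # reuses cached edges
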
Code Example #7
File: chebynet.py  Project: rahul5757/tf_geometric
def chebynet(x,
             edge_index,
             edge_weight,
             K,
             lambda_max,
             kernel,
             bias=None,
             activation=None,
             normalization_type=None):
    num_nodes = x.shape[0]
    norm_edge_index, norm_edge_weight = chebynet_norm_edge(
        edge_index,
        num_nodes,
        edge_weight,
        lambda_max,
        normalization_type=normalization_type)

    T0_x = x
    T1_x = x
    out = tf.matmul(T0_x, kernel[0])

    if K > 1:
        T1_x = aggregate_neighbors(x, norm_edge_index, norm_edge_weight,
                                   gcn_mapper, sum_reducer, identity_updater)
        out += tf.matmul(T1_x, kernel[1])

    for i in range(2, K):
        T2_x = aggregate_neighbors(T1_x, norm_edge_index, norm_edge_weight,
                                   gcn_mapper, sum_reducer,
                                   identity_updater)  # L_hat @ T_{k-1}(L_hat), the Chebyshev recurrence term
        T2_x = 2.0 * T2_x - T0_x
        out += tf.matmul(T2_x, kernel[i])

        T0_x, T1_x = T1_x, T2_x

    if bias is not None:
        out += bias

    if activation is not None:
        out = activation(out)

    return out
Code Example #8
File: tagcn.py  Project: rahul5757/tf_geometric
def tagcn(x,
          edge_index,
          edge_weight,
          K,
          kernel,
          bias=None,
          activation=None,
          renorm=False,
          improved=False,
          cache=None):
    """
    Functional API for Topology Adaptive Graph Convolutional Network (TAGCN).

    :param x: Tensor, shape: [num_nodes, num_features], node features.
    :param edge_index: Tensor, shape: [2, num_edges], edge information.
    :param edge_weight: Tensor or None, shape: [num_edges].
    :param K: Number of hops (default: :obj:`3`).
    :param kernel: Tensor, shape: [num_features, num_output_features], weight.
    :param bias: Tensor, shape: [num_output_features], bias.
    :param activation: Activation function to use.
    :param renorm: Whether to use the renormalization trick (https://arxiv.org/pdf/1609.02907.pdf).
    :param improved: Whether to use the improved GCN variant.
    :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict.
    :return: Updated node features (x), shape: [num_nodes, num_output_features]
    """

    xs = [x]
    updated_edge_index, normed_edge_weight = gcn_norm_edge(
        edge_index, x.shape[0], edge_weight, renorm, improved, cache)
    for k in range(K):
        h = aggregate_neighbors(xs[-1], updated_edge_index, normed_edge_weight,
                                gcn_mapper, sum_reducer, identity_updater)

        xs.append(h)

    h = tf.concat(xs, axis=-1)

    out = h @ kernel
    if bias is not None:
        out += bias

    if activation is not None:
        out = activation(out)

    return out
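
A minimal usage sketch for the tagcn function above (illustrative only; it assumes the function and its helpers are in scope and that edge_weight=None is accepted). Since the input features and the K propagated feature blocks are concatenated before the linear transform, the kernel's first dimension must be num_features * (K + 1).

import tensorflow as tf

num_nodes, num_features, num_out = 4, 8, 16
K = 3
x = tf.random.normal([num_nodes, num_features])
edge_index = tf.constant([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=tf.int32)

# The concatenation [x, A'x, A'^2 x, A'^3 x] has num_features * (K + 1) columns.
kernel = tf.random.normal([num_features * (K + 1), num_out])

h = tagcn(x, edge_index, None, K, kernel)  # shape: [num_nodes, num_out]
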
Code Example #9
File: gcn.py  Project: renzhenwen/tf_geometric
def gcn(x, edge_index, edge_weight, kernel, bias=None, activation=None, improved=False, cache=None):
    updated_edge_index, normed_edge_weight = gcn_norm_edge(edge_index, x.shape[0], edge_weight,
                                                           improved, cache)
    x = x @ kernel
    h = aggregate_neighbors(
        x, updated_edge_index, normed_edge_weight,
        gcn_mapper,
        sum_reducer,
        identity_updater
    )

    if bias is not None:
        h += bias

    if activation is not None:
        h = activation(h)

    return h
Code Example #10
File: gin.py  Project: zhaoquan219/tf_geometric
def gin(x, edge_index, edge_weight, mlp_model, eps=0.0, training=None):
    """

    :param x: Tensor, shape: [num_nodes, num_features], node features
    :param edge_index: Tensor, shape: [2, num_edges], edge information
    :param edge_weight: Tensor or None, shape: [num_edges]
    :param mlp_model: A neural network (multi-layer perceptrons).
    :param eps: float, optional, (default: :obj:`0.`).
    :param training: Whether currently executing in training or inference mode.
    :return: Updated node features (x), shape: [num_nodes, num_output_features]
    """

    h = aggregate_neighbors(x, edge_index, edge_weight, identity_mapper,
                            sum_reducer, identity_updater)

    h = gin_updater(x, h, eps)

    h = mlp_model(h, training=training)

    return h
Code Example #11
def appnp(x,
          edge_index,
          edge_weight,
          kernels,
          biases,
          dense_activation=tf.nn.relu,
          activation=None,
          num_iterations=2,
          alpha=0.15,
          dense_drop_rate=0.0,
          edge_drop_rate=0.0,
          cache=None,
          training=False):
    """
    Functional API for Approximate Personalized Propagation of Neural Predictions (APPNP).

    :param x: Tensor, shape: [num_nodes, num_features], node features
    :param edge_index: Tensor, shape: [2, num_edges], edge information
    :param edge_weight: Tensor or None, shape: [num_edges]
    :param kernels: List[Tensor], shape of each Tensor: [num_features, num_output_features], weights
    :param biases: List[Tensor], shape of each Tensor: [num_output_features], biases
    :param dense_activation: Activation function to use for the dense layers,
        except for the last dense layer, which will not be activated.
    :param activation: Activation function to use for the output.
    :param num_iterations: Number of propagation power iterations.
    :param alpha: Teleport probability.
    :param dense_drop_rate: Dropout rate for the input of every dense layer.
    :param edge_drop_rate: Dropout rate for the edges/adj used for propagation.
    :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict.
        To use @tf_utils.function with gcn, you should cache the normed edge information before the first call to gcn.

        - (1) If you're using OOP APIs tfg.layers.GCN:

              gcn_layer.cache_normed_edge(graph)

        - (2) If you're using functional API tfg.nn.gcn:

              from tf_geometric.nn.conv.gcn import gcn_cache_normed_edge
              gcn_cache_normed_edge(graph)

    :param training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
    :return: Updated node features (x), shape: [num_nodes, num_output_features]
    """

    num_nodes = tf.shape(x)[0]
    updated_edge_index, normed_edge_weight = gcn_norm_edge(edge_index,
                                                           num_nodes,
                                                           edge_weight,
                                                           cache=cache)

    num_dense_layers = len(kernels)

    h = x
    for i, (kernel, bias) in enumerate(zip(kernels, biases)):
        if training and dense_drop_rate > 0.0:
            h = tf.compat.v2.nn.dropout(h, dense_drop_rate)
        h = h @ kernel + bias
        if dense_activation is not None and i < num_dense_layers - 1:
            h = dense_activation(h)

    if training and edge_drop_rate > 0.0:
        normed_edge_weight = tf.compat.v2.nn.dropout(normed_edge_weight,
                                                     edge_drop_rate)

    prop_h = h

    for i in range(num_iterations):
        prop_h = aggregate_neighbors(prop_h,
                                     updated_edge_index,
                                     normed_edge_weight,
                                     gcn_mapper,
                                     sum_reducer,
                                     identity_updater,
                                     num_nodes=num_nodes)
        prop_h = prop_h * (1.0 - alpha) + h * alpha

    if activation is not None:
        prop_h = activation(prop_h)

    return prop_h
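
A minimal usage sketch for the appnp function above (illustrative only; it assumes the function and its helpers are in scope and that edge_weight=None is accepted). kernels and biases define the dense layers applied before propagation; the last layer is left un-activated by the function itself.

import tensorflow as tf

num_nodes, num_features, num_hidden, num_classes = 4, 8, 16, 3
x = tf.random.normal([num_nodes, num_features])
edge_index = tf.constant([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=tf.int32)

# A two-layer MLP: features -> hidden -> classes (trainable tf.Variables in practice).
kernels = [
    tf.random.normal([num_features, num_hidden]),
    tf.random.normal([num_hidden, num_classes]),
]
biases = [tf.zeros([num_hidden]), tf.zeros([num_classes])]

logits = appnp(x, edge_index, None, kernels, biases,
               num_iterations=10, alpha=0.1)  # shape: [num_nodes, num_classes]
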
Code Example #12
File: asap.py  Project: rahul5757/tf_geometric
def asap(x,
         edge_index,
         edge_weight,
         node_graph_index,
         attention_gcn_kernel,
         attention_gcn_bias,
         attention_query_kernel,
         attention_query_bias,
         attention_score_kernel,
         attention_score_bias,
         le_conv_self_kernel,
         le_conv_self_bias,
         le_conv_aggr_self_kernel,
         le_conv_aggr_self_bias,
         le_conv_aggr_neighbor_kernel,
         le_conv_aggr_neighbor_bias,
         K=None,
         ratio=None,
         le_conv_activation=tf.nn.sigmoid,
         drop_rate=0.0,
         training=None,
         cache=None):
    """
    Functional API for ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representation

    :param x: Tensor, shape: [num_nodes, num_features], node features
    :param edge_index: Tensor, shape: [2, num_edges], edge information
    :param edge_weight: Tensor or None, shape: [num_edges]
    :param node_graph_index: Tensor/NDArray, shape: [num_nodes], graph index for each node
    :param K: Keep the top K nodes (ranked by node_score) for each graph.
    :param ratio: Keep num_nodes * ratio nodes for each graph.
    :param le_conv_activation: Activation to use for node_score before multiplying node_features with node_score
    :param drop_rate: Dropout rate applied to the normalized attention scores.
    :param training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
    :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict.
    :return: [pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index]
    """

    num_nodes = tf.shape(x)[0]
    # num_graphs = tf.reduce_max(node_graph_index) + 1

    edge_index, edge_weight = remove_self_loop_edge(edge_index, edge_weight)
    edge_index_with_self_loop, edge_weight_with_self_loop = add_self_loop_edge(
        edge_index, num_nodes=num_nodes, edge_weight=edge_weight)

    row_with_self_loop, col_with_self_loop = edge_index_with_self_loop[
        0], edge_index_with_self_loop[1]

    attention_h = gcn(x,
                      edge_index,
                      edge_weight,
                      attention_gcn_kernel,
                      attention_gcn_bias,
                      cache=cache)

    # max_pool -> query
    attention_query = aggregate_neighbors(attention_h,
                                          edge_index_with_self_loop,
                                          None,
                                          mapper=identity_mapper,
                                          reducer=max_reducer,
                                          updater=identity_updater,
                                          num_nodes=num_nodes)

    attention_query = attention_query @ attention_query_kernel + attention_query_bias

    repeated_attention_query = tf.gather(attention_query, row_with_self_loop)
    repeated_attention_h = tf.gather(attention_h, col_with_self_loop)

    attention_score_h = tf.concat(
        [repeated_attention_query, repeated_attention_h], axis=-1)
    attention_score = attention_score_h @ attention_score_kernel + attention_score_bias
    attention_score = tf.nn.leaky_relu(attention_score, alpha=0.2)

    normed_attention_score = segment_softmax(attention_score,
                                             row_with_self_loop, num_nodes)
    if training and drop_rate > 0:
        normed_attention_score = tf.compat.v2.nn.dropout(
            normed_attention_score, rate=drop_rate)

    # nodes are clusters
    cluster_h = aggregate_neighbors(x,
                                    edge_index_with_self_loop,
                                    tf.reshape(normed_attention_score, [-1]),
                                    gcn_mapper,
                                    sum_reducer,
                                    identity_updater,
                                    num_nodes=num_nodes)

    node_score = le_conv(cluster_h,
                         edge_index,
                         edge_weight,
                         le_conv_self_kernel,
                         le_conv_self_bias,
                         le_conv_aggr_self_kernel,
                         le_conv_aggr_self_bias,
                         le_conv_aggr_neighbor_kernel,
                         le_conv_aggr_neighbor_bias,
                         activation=None)

    topk_node_index = topk_pool(node_graph_index, node_score, K=K, ratio=ratio)
    topk_node_score = tf.gather(node_score, topk_node_index)
    if le_conv_activation is not None:
        topk_node_score = le_conv_activation(topk_node_score)

    pooled_x = tf.gather(cluster_h, topk_node_index) * topk_node_score

    num_clusters = tf.shape(topk_node_index)[0]
    # node->cluster
    cluster_reverse_index = tf.cast(tf.fill([num_nodes], -1), tf.int32)
    cluster_reverse_index = tf.tensor_scatter_nd_update(
        cluster_reverse_index, tf.expand_dims(topk_node_index, axis=-1),
        tf.range(num_clusters))

    # row, col = edge_index[0], edge_index[1]
    assign_row = tf.gather(cluster_reverse_index, row_with_self_loop)
    assign_mask = tf.greater_equal(assign_row, 0)

    assign_row = tf.boolean_mask(assign_row, assign_mask)
    assign_col = tf.boolean_mask(col_with_self_loop, assign_mask)
    assign_edge_index = tf.stack([assign_row, assign_col], axis=0)

    assign_edge_weight = tf.boolean_mask(normed_attention_score, assign_mask)
    assign_edge_weight = tf.reshape(assign_edge_weight, [-1])
    assign_edge_weight = tf.stop_gradient(assign_edge_weight)

    # Coarsen in a large BatchGraph.
    _, pooled_edge_index, pooled_edge_weight = cluster_pool(
        None,
        edge_index_with_self_loop,
        edge_weight_with_self_loop,
        assign_edge_index,
        assign_edge_weight,
        num_clusters,
        num_nodes=num_nodes)

    pooled_edge_index, pooled_edge_weight = remove_self_loop_edge(
        pooled_edge_index, pooled_edge_weight)
    pooled_edge_index, pooled_edge_weight = add_self_loop_edge(
        pooled_edge_index, num_clusters, pooled_edge_weight)

    pooled_node_graph_index = tf.gather(node_graph_index, topk_node_index)

    return pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index
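
A heavily hedged usage sketch for the asap function above: the parameter shapes below are inferred from how the function uses them (the edge attention score and the LEConv node score must each end up with a single channel) and may differ from what tf_geometric's OOP ASAP layer actually builds. Shapes and data are purely illustrative.

import tensorflow as tf

num_nodes, num_features, att_units = 6, 8, 16
x = tf.random.normal([num_nodes, num_features])
edge_index = tf.constant([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]], dtype=tf.int32)
node_graph_index = tf.constant([0, 0, 0, 1, 1, 1], dtype=tf.int32)  # two graphs in one batch

# Attention parameters (shapes inferred from the matrix products in asap above).
attention_gcn_kernel = tf.random.normal([num_features, att_units])
attention_gcn_bias = tf.zeros([att_units])
attention_query_kernel = tf.random.normal([att_units, att_units])
attention_query_bias = tf.zeros([att_units])
attention_score_kernel = tf.random.normal([att_units * 2, 1])  # concat(query, key) -> scalar score
attention_score_bias = tf.zeros([1])

# LEConv parameters producing one fitness score per cluster.
le_conv_self_kernel = tf.random.normal([num_features, 1])
le_conv_self_bias = tf.zeros([1])
le_conv_aggr_self_kernel = tf.random.normal([num_features, 1])
le_conv_aggr_self_bias = tf.zeros([1])
le_conv_aggr_neighbor_kernel = tf.random.normal([num_features, 1])
le_conv_aggr_neighbor_bias = tf.zeros([1])

pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index = asap(
    x, edge_index, None, node_graph_index,
    attention_gcn_kernel, attention_gcn_bias,
    attention_query_kernel, attention_query_bias,
    attention_score_kernel, attention_score_bias,
    le_conv_self_kernel, le_conv_self_bias,
    le_conv_aggr_self_kernel, le_conv_aggr_self_bias,
    le_conv_aggr_neighbor_kernel, le_conv_aggr_neighbor_bias,
    ratio=0.5)  # keep half of the nodes per graph as clusters
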
Code Example #13
def gat(x,
        edge_index,
        query_kernel,
        query_bias,
        query_activation,
        key_kernel,
        key_bias,
        key_activation,
        kernel,
        bias=None,
        activation=None,
        num_heads=1,
        drop_rate=0.0,
        training=False):
    """

    :param x: Tensor, shape: [num_nodes, num_features], node features
    :param edge_index: Tensor, shape: [2, num_edges], edge information
    :param query_kernel: Tensor, shape: [num_features, num_query_features], weight for Q in attention
    :param query_bias: Tensor, shape: [num_query_features], bias for Q in attention
    :param query_activation: Activation function for Q in attention.
    :param key_kernel: Tensor, shape: [num_features, num_key_features], weight for K in attention
    :param key_bias: Tensor, shape: [num_key_features], bias for K in attention
    :param key_activation: Activation function for K in attention.
    :param kernel: Tensor, shape: [num_features, num_output_features], weight
    :param bias: Tensor, shape: [num_output_features], bias
    :param activation: Activation function to use.
    :param num_heads: Number of attention heads.
    :param drop_rate: Dropout rate.
    :param training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
    :return: Updated node features (x), shape: [num_nodes, num_output_features]
    """

    num_nodes = x.shape[0]

    # self-attention
    edge_index, edge_weight = add_self_loop_edge(edge_index, num_nodes)

    row, col = edge_index

    Q = query_activation(x @ query_kernel + query_bias)
    Q = tf.gather(Q, row)

    K = key_activation(x @ key_kernel + key_bias)
    K = tf.gather(K, col)

    V = x @ kernel

    # a trailing underscore (e.g. Q_) denotes the multi-head version of a tensor
    Q_ = tf.concat(tf.split(Q, num_heads, axis=-1), axis=0)
    K_ = tf.concat(tf.split(K, num_heads, axis=-1), axis=0)
    V_ = tf.concat(tf.split(V, num_heads, axis=-1), axis=0)
    edge_index_ = tf.concat(
        [edge_index + i * num_nodes for i in range(num_heads)], axis=1)

    att_score_ = tf.reduce_sum(Q_ * K_, axis=-1)
    normed_att_score_ = segment_softmax(att_score_, edge_index_[0],
                                        num_nodes * num_heads)

    if training and drop_rate > 0.0:
        normed_att_score_ = tf.compat.v2.nn.dropout(normed_att_score_,
                                                    drop_rate)

    h_ = aggregate_neighbors(V_, edge_index_, normed_att_score_, gcn_mapper,
                             sum_reducer, identity_updater)

    h = tf.concat(tf.split(h_, num_heads, axis=0), axis=-1)

    if bias is not None:
        h += bias

    if activation is not None:
        h = activation(h)

    return h
Code Example #14
def gat(x,
        edge_index,
        query_kernel,
        query_bias,
        query_activation,
        key_kernel,
        key_bias,
        key_activation,
        kernel,
        bias=None,
        activation=None,
        num_heads=1,
        split_value_heads=True,
        drop_rate=0.0,
        training=False):
    """

    :param x: Tensor, shape: [num_nodes, num_features], node features
    :param edge_index: Tensor, shape: [2, num_edges], edge information
    :param query_kernel: Tensor, shape: [num_features, num_query_features], weight for Q in attention
    :param query_bias: Tensor, shape: [num_query_features], bias for Q in attention
    :param query_activation: Activation function for Q in attention.
    :param key_kernel: Tensor, shape: [num_features, num_key_features], weight for K in attention
    :param key_bias: Tensor, shape: [num_key_features], bias for K in attention
    :param key_activation: Activation function for K in attention.
    :param kernel: Tensor, shape: [num_features, num_output_features], weight
    :param bias: Tensor, shape: [num_output_features], bias
    :param activation: Activation function to use.
    :param num_heads: Number of attention heads.
    :param split_value_heads: Boolean. If true, split V as value attention heads, and then concatenate them as output.
        Else, num_heads replicas of V are used as value attention heads, and the mean of them are used as output.
    :param drop_rate: Dropout rate.
    :param training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
    :return: Updated node features (x), shape: [num_nodes, num_output_features]
    """

    num_nodes = tf.shape(x)[0]

    # self-attention
    edge_index, edge_weight = add_self_loop_edge(edge_index, num_nodes)

    row, col = edge_index[0], edge_index[1]

    Q = x @ query_kernel + query_bias
    if query_activation is not None:
        Q = query_activation(Q)
    Q = tf.gather(Q, row)

    K = x @ key_kernel + key_bias
    if key_activation is not None:
        K = key_activation(K)
    K = tf.gather(K, col)

    V = x @ kernel

    # a trailing underscore (e.g. Q_) denotes the multi-head version of a tensor
    Q_ = tf.concat(tf.split(Q, num_heads, axis=-1), axis=0)
    K_ = tf.concat(tf.split(K, num_heads, axis=-1), axis=0)
    # split queries and keys are modeled as virtual vertices
    qk_edge_index_ = tf.concat(
        [edge_index + i * num_nodes for i in range(num_heads)], axis=1)

    scale = tf.math.sqrt(tf.cast(tf.shape(Q_)[-1], tf.float32))
    att_score_ = tf.reduce_sum(Q_ * K_ / scale, axis=-1)
    normed_att_score_ = segment_softmax(att_score_, qk_edge_index_[0],
                                        num_nodes * num_heads)

    if training and drop_rate > 0.0:
        normed_att_score_ = tf.compat.v2.nn.dropout(normed_att_score_,
                                                    drop_rate)

    if split_value_heads:
        V_ = tf.concat(tf.split(V, num_heads, axis=-1), axis=0)
        edge_index_ = qk_edge_index_
    else:
        V_ = V
        edge_index_ = tf.tile(edge_index, [1, num_heads])

    h_ = aggregate_neighbors(V_, edge_index_, normed_att_score_, gcn_mapper,
                             sum_reducer, identity_updater)

    if split_value_heads:
        h = tf.concat(tf.split(h_, num_heads, axis=0), axis=-1)
    else:
        h = h_ / num_heads

    if bias is not None:
        h += bias

    if activation is not None:
        h = activation(h)

    return h
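
A minimal usage sketch for the multi-head gat variant above (illustrative only; it assumes the function and its helpers are in scope). With split_value_heads=True, the query, key, and value widths must all be divisible by num_heads.

import tensorflow as tf

num_nodes, num_features, num_heads = 4, 8, 4
att_units, out_units = 16, 16  # both divisible by num_heads
x = tf.random.normal([num_nodes, num_features])
edge_index = tf.constant([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=tf.int32)

# Trainable tf.Variables in practice.
query_kernel = tf.random.normal([num_features, att_units])
query_bias = tf.zeros([att_units])
key_kernel = tf.random.normal([num_features, att_units])
key_bias = tf.zeros([att_units])
kernel = tf.random.normal([num_features, out_units])

h = gat(x, edge_index,
        query_kernel, query_bias, tf.nn.relu,
        key_kernel, key_bias, tf.nn.relu,
        kernel,
        num_heads=num_heads,
        split_value_heads=True,
        drop_rate=0.5,
        training=False)  # dropout is skipped because training=False; shape: [num_nodes, out_units]
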