Example #1
import warnings

import paddle.fluid as fluid
from pgl.utils import op
from pgl.utils import paddle_helper


def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, node_ids):
    """Receive the messages in `msg` and reduce them onto the `dst` nodes.
    """
    if reduce_function == "sum":
        if isinstance(msg, dict):
            raise TypeError("The message for the built-in function"
                            " should be a Tensor, not a dict.")

        try:
            # Fast path: scatter-add every message onto its destination
            # row of an all-zero (num_nodes, out_dims) tensor.
            out_dims = msg.shape[-1]
            init_output = fluid.layers.fill_constant_batch_size_like(
                node_ids, shape=[1, out_dims], value=0, dtype="float32")
            init_output.stop_gradient = False
            output = paddle_helper.scatter_add(init_output, dst, msg)
            return output
        except TypeError:
            warnings.warn(
                "scatter_add is not supported with paddle version <= 1.5")

            # Fall back to the sequence-pool based reduction below.
            def sum_func(message):
                return fluid.layers.sequence_pool(message, "sum")

            reduce_function = sum_func

    # Convert msg into a LoDTensor so that all messages addressed to the
    # same node form one sequence, then reduce each sequence to one row.
    bucketed_msg = op.nested_lod_reset(msg, bucketing_index)
    output = reduce_function(bucketed_msg)
    out_dims = output.shape[-1]

    # Scatter the reduced rows back to their (unique) destination nodes;
    # nodes that received no message keep the zero initialization.
    init_output = fluid.layers.fill_constant_batch_size_like(
        node_ids, shape=[1, out_dims], value=0, dtype="float32")
    init_output.stop_gradient = False
    output = fluid.layers.scatter(init_output, uniq_dst, output)
    return output
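
For intuition, here is a minimal NumPy sketch of the "sum" path above, i.e. what paddle_helper.scatter_add computes; the toy dst and msg values are assumed, not from the source:

import numpy as np

num_nodes = 3
dst = np.array([0, 2, 2, 1])             # destination node of each message
msg = np.ones((4, 2), dtype="float32")   # one 2-dim message per edge

# Scatter-add every message onto its destination row, starting from zeros.
out = np.zeros((num_nodes, 2), dtype="float32")
np.add.at(out, dst, msg)
print(out)  # [[1. 1.] [1. 1.] [2. 2.]] -- node 2 received two messages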
Example #2
import paddle.fluid as fluid
from pgl.utils import op


def graph_pooling(gw, node_feat, pool_type):
    """Implementation of graph pooling.

    Pools the node features of each graph in the batch into a single
    graph-level feature vector.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        node_feat: A tensor with shape (num_nodes, feature_size).

        pool_type: The type of pooling ("sum", "average", "min")

    Return:
        A tensor with shape (num_graphs, feature_size)
    """
    # Reset the LoD so that the nodes of each graph form one sequence,
    # then pool every sequence into a single feature vector.
    graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod)
    graph_feat = fluid.layers.sequence_pool(graph_feat, pool_type)
    return graph_feat
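
What this computes, mirrored in NumPy; the graph_lod and node_feat values below are an assumed toy batch:

import numpy as np

# graph_lod [0, 3, 5]: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
graph_lod = [0, 3, 5]
node_feat = np.arange(10, dtype="float32").reshape(5, 2)

# "sum" pooling: one row per graph, summed over that graph's nodes.
pooled = np.stack([
    node_feat[start:end].sum(axis=0)
    for start, end in zip(graph_lod[:-1], graph_lod[1:])
])
print(pooled)  # shape (2, 2): [[6. 9.] [14. 16.]]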
Example #3
import warnings

import paddle.fluid.layers as L
from pgl.utils import op
from pgl.utils import paddle_helper


def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
         num_edges):
    """Receive the messages in `msg` and reduce them onto the `dst` nodes.
    """
    if reduce_function == "sum":
        if isinstance(msg, dict):
            raise TypeError("The message for the built-in function"
                            " should be a Tensor, not a dict.")

        try:
            # Fast path: scatter-add every message onto its destination
            # row of an all-zero (num_nodes, out_dim) tensor.
            out_dim = msg.shape[-1]
            init_output = L.fill_constant(shape=[num_nodes, out_dim],
                                          value=0,
                                          dtype=msg.dtype)
            init_output.stop_gradient = False
            # Zero the messages when the batch has no edges, so the
            # output stays well-defined for an edgeless graph.
            empty_msg_flag = L.cast(num_edges > 0, dtype=msg.dtype)
            msg = msg * empty_msg_flag
            output = paddle_helper.scatter_add(init_output, dst, msg)
            return output
        except TypeError:
            warnings.warn(
                "scatter_add is not supported with paddle version <= 1.5")

            # Fall back to the sequence-pool based reduction below.
            def sum_func(message):
                return L.sequence_pool(message, "sum")

            reduce_function = sum_func

    # Convert msg into a LoDTensor so that all messages addressed to the
    # same node form one sequence, then reduce each sequence to one row.
    bucketed_msg = op.nested_lod_reset(msg, bucketing_index)
    output = reduce_function(bucketed_msg)
    output_dim = output.shape[-1]

    # Zero the reduced messages when the batch has no edges.
    empty_msg_flag = L.cast(num_edges > 0, dtype=output.dtype)
    output = output * empty_msg_flag

    # Scatter the reduced rows back to their (unique) destination nodes;
    # nodes that received no message keep the zero initialization.
    init_output = L.fill_constant(shape=[num_nodes, output_dim],
                                  value=0,
                                  dtype=output.dtype)
    init_output.stop_gradient = True
    final_output = L.scatter(init_output, uniq_dst, output)
    return final_output
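
Example #3 extends Example #1 with an empty-message guard: the 0/1 empty_msg_flag zeroes the reduced messages whenever the batch has no edges. A NumPy sketch of that guard, with assumed toy values:

import numpy as np

num_edges = 0                                # an edgeless (sub)graph
output = np.random.rand(3, 2).astype("float32")

empty_msg_flag = np.float32(num_edges > 0)   # 0.0 here, 1.0 otherwise
print(output * empty_msg_flag)               # all zeros when num_edges == 0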
Example #4
import paddle.fluid.layers as L
from pgl.utils import op
# graph_pooling is the function shown in Example #2 above.


def graph_norm(gw, feature):
    """Implementation of graph normalization.

    Reference paper: "Benchmarking Graph Neural Networks".

    Each node's features are divided by the square root of the number of
    nodes in its graph.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, hidden_size)

    Return:
        A tensor with shape (num_nodes, hidden_size)
    """
    # Count the nodes of each graph by sum-pooling a vector of ones,
    # then take the square root to get the per-graph normalizer.
    nodes = L.fill_constant([gw.num_nodes, 1], dtype="float32", value=1.0)
    norm = graph_pooling(gw, nodes, pool_type="sum")
    norm = L.sqrt(norm)
    # Broadcast the per-graph normalizer to every node of that graph.
    feature_lod = op.nested_lod_reset(feature, gw.graph_lod)
    norm = L.sequence_expand_as(norm, feature_lod)
    norm.stop_gradient = True
    return feature_lod / norm
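
The same normalization in NumPy terms; the two-graph batch below is an assumed toy example:

import numpy as np

# Batch: graph 0 has 4 nodes, graph 1 has 1 node.
graph_lod = [0, 4, 5]
feature = np.ones((5, 2), dtype="float32")

# Divide each node's features by sqrt(#nodes in its graph).
sizes = np.diff(graph_lod)                        # [4, 1]
norm = np.repeat(np.sqrt(sizes), sizes)[:, None]  # per-node sqrt(num_nodes)
print(feature / norm)  # graph-0 rows become 0.5, the graph-1 row stays 1.0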