def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, node_ids):
    """Recv message from given msg to dst nodes.

    Aggregates per-edge messages onto their destination nodes. With the
    built-in ``"sum"`` reducer, ``paddle_helper.scatter_add`` is tried first.
    If that call raises ``TypeError``, a warning is emitted and the function
    falls back to a ``sequence_pool`` sum over the bucketed messages.

    Args:
        dst: Destination node index of every edge; the scatter index on
            the ``scatter_add`` fast path.
        uniq_dst: Unique destination node indices; the scatter index on
            the bucketed (``sequence_pool``) path.
        bucketing_index: LoD information handed to ``op.nested_lod_reset``
            to group ``msg`` before reduction.
        msg: Message tensor. Must be a Tensor (not a dict) when
            ``reduce_function`` is ``"sum"``.
        reduce_function: Either the string ``"sum"`` or a callable applied
            to the bucketed (LoD) messages.
        node_ids: Tensor passed to ``fill_constant_batch_size_like`` to size
            the zero-initialised output buffer.

    Returns:
        The zero-initialised buffer with the reduced messages scattered
        into it. Rows that are never indexed stay zero.

    Raises:
        TypeError: If ``reduce_function`` is ``"sum"`` and ``msg`` is a dict.
    """
    if reduce_function == "sum":
        if isinstance(msg, dict):
            raise TypeError("The message for build-in function"
                            " should be Tensor not dict.")

        out_dims = msg.shape[-1]
        # Follow the message dtype instead of hard-coding float32, so the
        # scatter does not mix dtypes for float16/float64 messages.
        init_output = fluid.layers.fill_constant_batch_size_like(
            node_ids, shape=[1, out_dims], value=0, dtype=msg.dtype)
        init_output.stop_gradient = False
        try:
            # Only this call is guarded: it is the one the fallback below
            # exists for (per the warning text, old paddle versions).
            return paddle_helper.scatter_add(init_output, dst, msg)
        except TypeError:
            warnings.warn(
                "scatter_add is not supported with paddle version <= 1.5")

        def sum_func(message):
            return fluid.layers.sequence_pool(message, "sum")

        reduce_function = sum_func

    # Group the messages with the bucketing LoD and reduce each group.
    bucketed_msg = op.nested_lod_reset(msg, bucketing_index)
    output = reduce_function(bucketed_msg)
    out_dims = output.shape[-1]
    init_output = fluid.layers.fill_constant_batch_size_like(
        node_ids, shape=[1, out_dims], value=0, dtype=output.dtype)
    init_output.stop_gradient = False
    return fluid.layers.scatter(init_output, uniq_dst, output)
def graph_pooling(gw, node_feat, pool_type):
    """Pool node features into one feature row per graph.

    The node features are re-segmented with ``gw.graph_lod`` (presumably
    one segment per graph) and each segment is reduced with
    ``fluid.layers.sequence_pool``.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or
            :code:`GraphWrapper`).
        node_feat: A tensor with shape (num_nodes, feature_size).
        pool_type: Pooling type forwarded unchanged to
            ``fluid.layers.sequence_pool`` (e.g. "sum", "average", "max").

    Return:
        A tensor with shape (num_graph, feature_size).
    """
    segmented = op.nested_lod_reset(node_feat, gw.graph_lod)
    return fluid.layers.sequence_pool(segmented, pool_type)
def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
         num_edges):
    """Aggregate edge messages onto their destination nodes.

    For ``reduce_function == "sum"``, ``paddle_helper.scatter_add`` is
    attempted first. If it raises ``TypeError``, a warning is emitted and a
    ``sequence_pool`` sum over the bucketed messages is used instead.

    Args:
        dst: Destination node index of every edge; the scatter index on
            the ``scatter_add`` path.
        uniq_dst: Unique destination node indices; the scatter index on
            the bucketed path.
        bucketing_index: LoD information handed to ``op.nested_lod_reset``
            to group ``msg`` before reduction.
        msg: Message tensor; must not be a dict when ``reduce_function``
            is ``"sum"``.
        reduce_function: ``"sum"`` or a callable applied to the bucketed
            (LoD) messages.
        num_nodes: Number of rows of the zero-initialised output buffer.
        num_edges: Edge count. When it is not greater than zero, the
            messages are multiplied by 0, so nothing non-zero is scattered.

    Returns:
        The result of scattering into a zero ``[num_nodes, dim]`` buffer.

    Raises:
        TypeError: If ``reduce_function`` is ``"sum"`` and ``msg`` is a dict.
    """
    if reduce_function == "sum":
        if isinstance(msg, dict):
            raise TypeError("The message for build-in function"
                            " should be Tensor not dict.")

        try:
            feat_dim = msg.shape[-1]
            zeros = L.fill_constant(
                shape=[num_nodes, feat_dim], value=0, dtype=msg.dtype)
            zeros.stop_gradient = False
            # 1 when the graph has edges, 0 otherwise.
            has_edges = L.cast(num_edges > 0, dtype=msg.dtype)
            msg = msg * has_edges
            return paddle_helper.scatter_add(zeros, dst, msg)
        except TypeError:
            warnings.warn(
                "scatter_add is not supported with paddle version <= 1.5")

            def _pool_sum(seq):
                return L.sequence_pool(seq, "sum")

            reduce_function = _pool_sum

    grouped = op.nested_lod_reset(msg, bucketing_index)
    pooled = reduce_function(grouped)
    pooled_dim = pooled.shape[-1]

    has_edges = L.cast(num_edges > 0, dtype=pooled.dtype)
    pooled = pooled * has_edges

    zeros = L.fill_constant(
        shape=[num_nodes, pooled_dim], value=0, dtype=pooled.dtype)
    zeros.stop_gradient = True
    return L.scatter(zeros, uniq_dst, pooled)
def graph_norm(gw, feature):
    """Implementation of graph normalization.

    Reference Paper: BENCHMARKING GRAPH NEURAL NETWORKS

    Each node's features are divided by the square root of the node count
    of its graph, as obtained by sum-pooling a column of ones with
    ``graph_pooling``.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or
            :code:`GraphWrapper`).
        feature: A tensor with shape (num_nodes, hidden_size).

    Return:
        A tensor with shape (num_nodes, hidden_size).
    """
    ones = L.fill_constant([gw.num_nodes, 1], dtype="float32", value=1.0)
    # Sum-pooling a column of ones gives the node count of every graph.
    graph_sizes = graph_pooling(gw, ones, pool_type="sum")
    sqrt_sizes = L.sqrt(graph_sizes)

    feature_lod = op.nested_lod_reset(feature, gw.graph_lod)
    # Broadcast each graph's sqrt(size) back onto that graph's nodes.
    divisor = L.sequence_expand_as(sqrt_sizes, feature_lod)
    divisor.stop_gradient = True
    return feature_lod / divisor