Exemplo n.º 1
0
def send(src, dst, nfeat, efeat, message_func):
    """Build messages by applying ``message_func`` to endpoint features.

    ``nfeat`` is wrapped twice with ``op.RowReader`` -- once with the ``src``
    indices and once with the ``dst`` indices -- and both readers are handed
    to ``message_func`` together with ``efeat``.

    Args:
        src: Indices given to ``op.RowReader`` for the source side
            (presumably the source node ids of the edges -- confirm with
            the caller).
        dst: Indices given to ``op.RowReader`` for the destination side.
        nfeat: Node features read through ``op.RowReader``.
        efeat: Edge features, forwarded to ``message_func`` untouched.
        message_func: Callable invoked as
            ``message_func(src_feat, dst_feat, efeat)``.

    Return:
        Whatever ``message_func`` returns.
    """
    return message_func(
        op.RowReader(nfeat, src),
        op.RowReader(nfeat, dst),
        efeat)
Exemplo n.º 2
0
    def recv(self, reduce_func, msg, recv_mode="dst"):
        """Receive messages and aggregate them per node with ``reduce_func``.

        The messages are re-read in the order of the edges sorted by
        ``recv_mode``, grouped by the unique receiving node, reduced by
        ``reduce_func`` and written into a zero-initialised output that has
        one row per node of the graph.

        The UDF reduce_func function should have the following format.

        .. code-block:: python

            def reduce_func(msg):
                '''
                    Args:

                        msg: A ``Message`` object built from the edge-ordered
                             messages and their segment ids (one segment per
                             unique receiving node).

                    Return:

                        It should return a tensor with shape
                        (batch_size, out_dims), where batch_size is the
                        number of unique receiving nodes.
                '''
                pass

        Args:

            reduce_func: A callable UDF reduce function.

            msg: A dictionary of tensor created by the send function.

            recv_mode: Which end of each edge receives the message, either
                ``"dst"`` (default) or ``"src"``. It is also used as the sort
                key passed to ``sorted_edges``.

        Return:

            A tensor with shape (num_nodes, out_dims). The output for nodes
            with no message will be zeros.

        Raises:

            ValueError: If the graph has not been converted with
                ``Graph.tensor()``, or if ``recv_mode`` is neither ``"dst"``
                nor ``"src"``.

            TypeError: If ``msg`` is not a dict or ``reduce_func`` is not
                callable.
        """
        if not self._is_tensor:
            raise ValueError("You must call Graph.tensor()")

        if not isinstance(msg, dict):
            raise TypeError(
                "The input of msg should be a dict, but receives a %s" %
                (type(msg)))

        if not callable(reduce_func):
            raise TypeError("reduce_func should be callable")

        src, dst, eid = self.sorted_edges(sort_by=recv_mode)

        # Pick the receiving end explicitly. Without the final branch an
        # unknown mode would leave the segment variables unbound and fail
        # later with an unrelated UnboundLocalError.
        if recv_mode == "dst":
            recv_nodes = dst
        elif recv_mode == "src":
            recv_nodes = src
        else:
            raise ValueError(
                "recv_mode should be 'dst' or 'src', but receives %r" %
                (recv_mode, ))

        # Re-read the messages following the sorted edge ids.
        msg = op.RowReader(msg, eid)

        # uniq_ind: unique receiving node ids; segment_ids: for every sorted
        # edge, the position of its receiving node inside uniq_ind.
        uniq_ind, segment_ids = paddle.unique(recv_nodes, return_inverse=True)

        bucketed_msg = Message(msg, segment_ids)
        output = reduce_func(bucketed_msg)

        # Start from zeros for every node, then fill in the rows of the
        # nodes that actually received something.
        output_dim = output.shape[-1]
        init_output = paddle.zeros(shape=[self._num_nodes, output_dim],
                                   dtype=output.dtype)
        final_output = scatter(init_output, uniq_ind, output)

        return final_output
Exemplo n.º 3
0
    def send(
        self,
        message_func,
        src_feat=None,
        dst_feat=None,
        edge_feat=None,
        node_feat=None,
    ):
        """Send message from all src nodes to dst nodes.

        The UDF message function should have the following format.

        .. code-block:: python

            def message_func(src_feat, dst_feat, edge_feat):
                '''
                    Args:
                        src_feat: the node feat dict attached to the src nodes.
                        dst_feat: the node feat dict attached to the dst nodes.
                        edge_feat: the edge feat dict attached to the
                                   corresponding (src, dst) edges.

                    Return:
                        It should return a dictionary of tensor. And each tensor
                        should have a shape of (num_edges, dims).
                '''
                return {'msg': src_feat['h']}

        Args:
            message_func: UDF function.
            src_feat: a dict {name: tensor,} to build src node feat
            dst_feat: a dict {name: tensor,} to build dst node feat
            edge_feat: a dict {name: tensor,} to build edge feat
            node_feat: a dict {name: tensor,} to build both src and dst node
                feat. It can not be combined with src_feat / dst_feat.

        Return:
            A dictionary of tensor representing the message. Each of the values
            in the dictionary has a shape (num_edges, dim) which should be collected
            by :code:`recv` function.

        Raises:
            ValueError: If the graph has not been converted with
                ``Graph.tensor()``, or if ``node_feat`` is given together with
                ``src_feat`` / ``dst_feat``.
            TypeError: If any given feature argument is not a dict, or if
                ``message_func`` does not return a dict.
        """
        if not self._is_tensor:
            raise ValueError("You must call Graph.tensor() first")

        if (src_feat is not None
                or dst_feat is not None) and node_feat is not None:
            raise ValueError(
                "Can not use src/dst feat and node feat at the same time")

        # Validate with real exceptions rather than `assert`, so the checks
        # are still enforced when Python runs with -O.
        for feat_name, feat in (("node_feat", node_feat),
                                ("src_feat", src_feat),
                                ("dst_feat", dst_feat),
                                ("edge_feat", edge_feat)):
            if feat is not None and not isinstance(feat, dict):
                raise TypeError("The input %s must be a dict" % feat_name)

        # Work on shallow copies so the caller's dicts are never shared with
        # the readers handed to message_func.
        if node_feat is not None:
            src_feat_temp = dict(node_feat)
            dst_feat_temp = dict(node_feat)
        else:
            src_feat_temp = dict(src_feat) if src_feat is not None else {}
            dst_feat_temp = dict(dst_feat) if dst_feat is not None else {}
        edge_feat_temp = dict(edge_feat) if edge_feat is not None else {}

        # Column 0 of the edge list is used for the src side, column 1 for
        # the dst side.
        src = self.edges[:, 0]
        dst = self.edges[:, 1]

        src_feat = op.RowReader(src_feat_temp, src)
        dst_feat = op.RowReader(dst_feat_temp, dst)
        msg = message_func(src_feat, dst_feat, edge_feat_temp)

        if not isinstance(msg, dict):
            # functools.partial objects and callable instances carry no
            # __name__; fall back to repr() so the intended TypeError is not
            # masked by an AttributeError.
            func_name = getattr(message_func, "__name__",
                                repr(message_func))
            raise TypeError(
                "The outputs of the %s function is expected to be a dict, but got %s"
                % (func_name, type(msg)))
        return msg