def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
        **kwargs_dict,
    ) -> Dict[NodeType, Tensor]:
        r"""Runs every registered edge-type convolution and aggregates the
        outputs per destination node type.

        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
                holding graph connectivity information for each individual
                edge type. Edge types without a registered convolution are
                skipped.
            **kwargs_dict (optional): Additional forward arguments
                of individual :class:`torch_geometric.nn.conv.MessagePassing`
                layers. Every keyword must end in :obj:`_dict` and map edge
                types to values; the suffix is stripped to obtain the name
                of the argument passed on to the convolution.
                For example, if a specific GNN layer at edge type
                :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
                forward argument, then you can pass them to
                :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
                :obj:`edge_attr_dict = { edge_type: edge_attr }`.

        Returns:
            Dict[str, Tensor]: For every destination node type that received
            at least one convolution output, those outputs combined via
            :obj:`group` using :obj:`self.aggr`.

        Raises:
            ValueError: If a keyword argument does not end in :obj:`_dict`.
        """
        # Validate once up front: the `arg[:-5]` slice below assumes the
        # `_dict` suffix, and a misnamed keyword would otherwise be silently
        # truncated into a wrong argument name.
        for arg in kwargs_dict:
            if not arg.endswith('_dict'):
                raise ValueError(
                    f"Keyword arguments in '{self.__class__.__name__}' "
                    f"need to end with '_dict' (got '{arg}')")

        out_dict = defaultdict(list)
        for edge_type, edge_index in edge_index_dict.items():
            src, _, dst = edge_type

            # Convolutions are looked up by the string key 'src__rel__dst'.
            str_edge_type = '__'.join(edge_type)
            if str_edge_type not in self.convs:
                continue

            # Forward only the values present for this edge type, renamed
            # from `{*}_dict` to `{*}`.
            kwargs = {
                arg[:-5]: value[edge_type]
                for arg, value in kwargs_dict.items() if edge_type in value
            }

            conv = self.convs[str_edge_type]

            if src == dst:
                # Same node type on both ends: pass a single feature input.
                out = conv(x=x_dict[src], edge_index=edge_index, **kwargs)
            else:
                # Different node types: pass (source, destination) features.
                out = conv(x=(x_dict[src], x_dict[dst]),
                           edge_index=edge_index,
                           **kwargs)

            out_dict[dst].append(out)

        # Return a plain dict so that looking up a node type that received
        # no outputs raises KeyError instead of silently yielding [].
        return {
            node_type: group(outs, self.aggr)
            for node_type, outs in out_dict.items()
        }
# Example No. 2
# 0
    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        edge_weight_dict: Optional[Dict[EdgeType, Tensor]] = None,
        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
    ) -> Dict[NodeType, Tensor]:
        r"""Runs the convolution of every edge type in :obj:`self.keys` and
        aggregates the outputs per destination node type.

        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
                holding graph connectivity information for each individual
                edge type. Must contain every edge type in :obj:`self.keys`.
            edge_weight_dict (Dict[Tuple[str, str, str], Tensor], optional): A
                dictionary holding one-dimensional edge weight information
                for individual edge types. (default: :obj:`None`)
            edge_attr_dict (Dict[Tuple[str, str, str], Tensor], optional): A
                dictionary holding multi-dimensional edge feature information
                for individual edge types. (default: :obj:`None`)

        Returns:
            Dict[str, Tensor]: For every destination node type that received
            at least one convolution output, those outputs combined via
            :obj:`group` using :obj:`self.aggr`.

        Raises:
            KeyError: If :obj:`edge_index_dict` (or :obj:`x_dict`) lacks an
                entry required by one of the edge types in :obj:`self.keys`.
        """
        out_dict = defaultdict(list)
        for edge_type in self.keys:
            src, _, dst = edge_type
            conv = self.convs['__'.join(edge_type)]

            # Each optional dict is guarded by its own None check; guarding
            # one on the other either crashes on `in None` or silently drops
            # the values.
            kwargs = {}
            if edge_weight_dict is not None and edge_type in edge_weight_dict:
                kwargs['edge_weight'] = edge_weight_dict[edge_type]
            if edge_attr_dict is not None and edge_type in edge_attr_dict:
                kwargs['edge_attr'] = edge_attr_dict[edge_type]

            edge_index = edge_index_dict[edge_type]
            if src == dst:
                # Same node type on both ends: pass a single feature input.
                out = conv(x=x_dict[src], edge_index=edge_index, **kwargs)
            else:
                # Different node types: pass (source, destination) features.
                out = conv(x=(x_dict[src], x_dict[dst]),
                           edge_index=edge_index,
                           **kwargs)

            out_dict[dst].append(out)

        # Return a plain dict so that looking up a node type that received
        # no outputs raises KeyError instead of silently yielding [].
        return {
            node_type: group(outs, self.aggr)
            for node_type, outs in out_dict.items()
        }