def forward(
    self,
    x_dict: Dict[NodeType, Tensor],
    edge_index_dict: Dict[EdgeType, Adj],
    **kwargs_dict,
) -> Dict[NodeType, Tensor]:
    r"""
    Args:
        x_dict (Dict[str, Tensor]): Node features, keyed by node type.
        edge_index_dict (Dict[Tuple[str, str, str], Tensor]): Graph
            connectivity, keyed by :obj:`(src, rel, dst)` edge type. Edge
            types without a registered conv are skipped.
        **kwargs_dict (optional): Extra per-edge-type forward arguments for
            the underlying :class:`torch_geometric.nn.conv.MessagePassing`
            layers. Each keyword must be named :obj:`<arg>_dict` and map
            edge types to values; the :obj:`_dict` suffix is stripped
            before forwarding. For example, a layer at edge type
            :obj:`edge_type` that takes :obj:`edge_attr` can receive it via
            :obj:`edge_attr_dict = { edge_type: edge_attr }` passed to
            :meth:`~torch_geometric.nn.conv.HeteroConv.forward`.
    """
    # Per destination node type, collect the output of every relation that
    # writes into it; they are combined after the loop.
    messages = defaultdict(list)

    for edge_type, edge_index in edge_index_dict.items():
        src, _, dst = edge_type
        conv_key = '__'.join(edge_type)

        if conv_key in self.convs:
            # Forward only the extra arguments that have an entry for this
            # edge type, renaming `<arg>_dict` -> `<arg>` (drops 5 chars).
            extra = {}
            for name, per_type in kwargs_dict.items():
                if edge_type in per_type:
                    short_name = name[:-5]
                    extra[short_name] = per_type[edge_type]

            conv = self.convs[conv_key]

            # Same node type on both ends -> plain features; otherwise the
            # conv receives a bipartite (source, destination) pair.
            if src == dst:
                node_input = x_dict[src]
            else:
                node_input = (x_dict[src], x_dict[dst])

            messages[dst].append(
                conv(x=node_input, edge_index=edge_index, **extra))

    # Reduce each destination's list of outputs in place with the
    # configured aggregation.
    for dst_type, outputs in messages.items():
        messages[dst_type] = group(outputs, self.aggr)

    return messages
def forward(
    self,
    x_dict: Dict[NodeType, Tensor],
    edge_index_dict: Dict[EdgeType, Tensor],
    edge_weight_dict: Optional[Dict[EdgeType, Tensor]] = None,
    edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
) -> Dict[NodeType, Tensor]:
    r"""
    Args:
        x_dict (Dict[str, Tensor]): A dictionary holding node feature
            information for each individual node type.
        edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
            holding graph connectivity information for each individual
            edge type. Every edge type in :obj:`self.keys` must be present.
        edge_weight_dict (Dict[Tuple[str, str, str], Tensor], optional): A
            dictionary holding one-dimensional edge weight information
            for individual edge types. Edge types missing from it are
            called without :obj:`edge_weight`. (default: :obj:`None`)
        edge_attr_dict (Dict[Tuple[str, str, str], Tensor], optional): A
            dictionary holding multi-dimensional edge feature information
            for individual edge types. Edge types missing from it are
            called without :obj:`edge_attr`. (default: :obj:`None`)

    Returns:
        A dictionary mapping each destination node type that received at
        least one message to its aggregated output.
    """
    out_dict = defaultdict(list)
    # NOTE(review): `self.keys` presumably lists the registered
    # (src, rel, dst) edge types, matching `self.convs` keys joined by
    # '__' — confirm against __init__.
    for key in self.keys:
        src, _, dst = key
        conv = self.convs['__'.join(key)]

        # Each optional dict is checked independently, so either may be
        # supplied without the other.
        kwargs = {}
        if edge_weight_dict is not None and key in edge_weight_dict:
            kwargs['edge_weight'] = edge_weight_dict[key]
        if edge_attr_dict is not None and key in edge_attr_dict:
            kwargs['edge_attr'] = edge_attr_dict[key]

        # Homogeneous relation: pass a single feature tensor; otherwise
        # pass the bipartite (source, destination) feature pair.
        if src == dst:
            out = conv(x=x_dict[src], edge_index=edge_index_dict[key],
                       **kwargs)
        else:
            out = conv(x=(x_dict[src], x_dict[dst]),
                       edge_index=edge_index_dict[key], **kwargs)

        out_dict[dst].append(out)

    # Combine all relation outputs per destination node type in place.
    for key, values in out_dict.items():
        out_dict[key] = group(values, self.aggr)

    return out_dict