def __call__(self, data: HeteroData) -> HeteroData:
    """Add one new edge type to ``data`` for every entry of ``self.metapaths``.

    For each metapath, the sparse adjacency matrices of its edge types are
    multiplied in order, and the non-zero pattern of the product is stored
    as the ``edge_index`` of a new edge type
    ``(source of first hop, 'metapath_<j>', destination of last hop)``.
    The mapping from each new edge type to its metapath is written to
    ``data.metapath_dict``.

    Args:
        data: The heterogeneous graph; it is modified in place.

    Returns:
        The same ``data`` object that was passed in.

    Raises:
        AssertionError: If an edge type of a metapath is not among the
            edge types of ``data``.
    """
    # Edge types as read before any metapath edge type is added.
    original_edge_types = data.edge_types
    data.metapath_dict = {}

    for path_idx, metapath in enumerate(self.metapaths):
        # Every hop of the metapath must refer to an existing edge type.
        for hop in metapath:
            assert data._to_canonical(
                hop) in original_edge_types, f"'{hop}' not present"

        # Start from the adjacency of the first hop ...
        first_store = data[metapath[0]]
        composed = SparseTensor.from_edge_index(
            edge_index=first_store.edge_index,
            sparse_sizes=first_store.size())

        # ... and right-multiply by the adjacency of each following hop.
        for hop in metapath[1:]:
            hop_store = data[hop]
            hop_adj = SparseTensor.from_edge_index(
                edge_index=hop_store.edge_index,
                sparse_sizes=hop_store.size())
            composed = composed @ hop_adj

        src_idx, dst_idx, _ = composed.coo()
        metapath_edge_type = (
            metapath[0][0], f'metapath_{path_idx}', metapath[-1][-1])
        data[metapath_edge_type].edge_index = torch.vstack(
            [src_idx, dst_idx])
        data.metapath_dict[metapath_edge_type] = metapath

    if self.drop_orig_edges:
        for old_type in original_edge_types:
            # With ``keep_same_node_type`` set, edge types whose first and
            # last entries are equal are kept.
            if not self.keep_same_node_type or old_type[0] != old_type[-1]:
                del data[old_type]

    return data
def __call__(self, data: HeteroData) -> HeteroData:
    """Add one new edge type to ``data`` for every entry of ``self.metapaths``.

    For each metapath, the sparse adjacency matrices of its edge types are
    multiplied in order. Each matrix carries the values returned by
    ``self._get_edge_weight``. When ``self.max_sample`` is not ``None``,
    ``self.sample_adj`` is applied to the first matrix and again after
    every product. The result is stored as a new edge type
    ``(source of first hop, 'metapath_<j>', destination of last hop)``,
    and the mapping from new edge type to metapath is written to
    ``data.metapath_dict``.

    Depending on the instance flags, the original edge types and the node
    types that no remaining edge type touches are deleted afterwards.

    Args:
        data: The heterogeneous graph; it is modified in place.

    Returns:
        The same ``data`` object that was passed in.

    Raises:
        AssertionError: If an edge type of a metapath is not among the
            edge types of ``data``.
    """
    # Edge types as read before any metapath edge type is added.
    original_edge_types = data.edge_types
    data.metapath_dict = {}
    sampling_enabled = self.max_sample is not None

    for path_idx, metapath in enumerate(self.metapaths):
        # Every hop of the metapath must refer to an existing edge type.
        for hop in metapath:
            assert data._to_canonical(
                hop) in original_edge_types, f"'{hop}' not present"

        # Adjacency of the first hop.
        first_hop = metapath[0]
        hop_weight = self._get_edge_weight(data, first_hop)
        first_store = data[first_hop]
        composed = SparseTensor.from_edge_index(
            edge_index=first_store.edge_index,
            sparse_sizes=first_store.size(),
            edge_attr=hop_weight)
        if sampling_enabled:
            composed = self.sample_adj(composed)

        # Right-multiply by each following hop, re-sampling after every
        # product when sampling is enabled.
        for hop in metapath[1:]:
            hop_weight = self._get_edge_weight(data, hop)
            hop_store = data[hop]
            hop_adj = SparseTensor.from_edge_index(
                edge_index=hop_store.edge_index,
                sparse_sizes=hop_store.size(),
                edge_attr=hop_weight)
            composed = composed @ hop_adj
            if sampling_enabled:
                composed = self.sample_adj(composed)

        src_idx, dst_idx, path_weight = composed.coo()
        metapath_edge_type = (
            metapath[0][0], f'metapath_{path_idx}', metapath[-1][-1])
        data[metapath_edge_type].edge_index = torch.vstack(
            [src_idx, dst_idx])
        if self.weighted:
            data[metapath_edge_type].edge_weight = path_weight
        data.metapath_dict[metapath_edge_type] = metapath

    if self.drop_orig_edges:
        for old_type in original_edge_types:
            # With ``keep_same_node_type`` set, edge types whose first and
            # last entries are equal are kept.
            if not self.keep_same_node_type or old_type[0] != old_type[-1]:
                del data[old_type]

    # Remove node types that are not an endpoint of any remaining edge type.
    if self.drop_unconnected_nodes:
        connected_node_types = {
            endpoint
            for remaining_type in data.edge_types
            for endpoint in (remaining_type[0], remaining_type[-1])
        }
        for node_type in data.node_types:
            if node_type not in connected_node_types:
                del data[node_type]

    return data