def ogbl_data_to_general_static_graph(
        cls,
        ogbl_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]],
        heterogeneous_edges: _typing.Mapping[
            _typing.Tuple[str, str, str],
            _typing.Union[
                torch.Tensor,
                _typing.Tuple[
                    torch.Tensor,
                    _typing.Optional[_typing.Mapping[str, torch.Tensor]]
                ]
            ]
        ] = ...,
        nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...,
        graph_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...
) -> GeneralStaticGraph:
    """Convert a raw OGB link-prediction data dict into a ``GeneralStaticGraph``.

    All node data is stored under the single node type ``''``. Edge data is
    forwarded unchanged to
    ``GeneralStaticGraphGenerator.create_heterogeneous_static_graph``.

    :param ogbl_data: raw data mapping; values are numpy arrays or ints.
    :param heterogeneous_edges: mapping from ``(src_type, relation, dst_type)``
        to an edge tensor, or to ``(edge_tensor, edge_data_mapping)``.
        Passed through as-is.
    :param nodes_data_key_mapping: ``{source_key_in_ogbl_data: target_node_data_key}``.
        If not a mapping (e.g. omitted / ``...`` / ``None``), no node data is
        attached.
    :param graph_data_key_mapping: ``{source_key_in_ogbl_data: target_graph_data_key}``.
        If not a mapping, ``...`` is forwarded as the graph data argument.
    :return: the graph built by ``create_heterogeneous_static_graph``.
    """

    def _convert(key_mapping: _typing.Mapping[str, str]) -> _typing.Dict[str, torch.Tensor]:
        # torch.as_tensor wraps ndarrays via from_numpy (no copy), like the
        # previous code did, but also accepts the plain ints allowed by the
        # ``ogbl_data`` type hint.
        return {
            target_key: torch.as_tensor(ogbl_data[source_key]).squeeze()
            for source_key, target_key in key_mapping.items()
        }

    # Previously an omitted/None node mapping crashed on ``.items()``; guard it
    # the same way the graph-level mapping is guarded.
    # NOTE(review): assumes the generator accepts an empty node-data dict for
    # type '' — confirm against create_heterogeneous_static_graph.
    nodes_data = (
        _convert(nodes_data_key_mapping)
        if isinstance(nodes_data_key_mapping, _typing.Mapping) else {}
    )
    graph_data = (
        _convert(graph_data_key_mapping)
        if isinstance(graph_data_key_mapping, _typing.Mapping) else ...
    )
    return GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
        {'': nodes_data},
        heterogeneous_edges,
        graph_data
    )
def __init__(self, path: str):
    """Load the ``ogbl-ddi`` link-prediction dataset stored under *path*.

    The dataset is wrapped as a one-graph collection. Its single graph has:

    * node type ``''`` with a ``'_NID'`` tensor ``arange(num_nodes)``;
    * the full edge set as edge type ``('', '', '')``;
    * the train/valid/test split edges as additional edge types
      ``('', 'train_pos_edge', '')``, ``('', 'val_pos_edge', '')``,
      ``('', 'val_neg_edge', '')``, ``('', 'test_pos_edge', '')`` and
      ``('', 'test_neg_edge', '')``.

    :param path: root directory passed to ``LinkPropPredDataset``.
    """
    raw_dataset = LinkPropPredDataset("ogbl-ddi", path)
    splits = raw_dataset.get_edge_split()
    raw_graph = raw_dataset[0]

    def _split_edges(split_name, field):
        # NOTE(review): OGB split arrays are usually laid out as
        # (num_edges, 2), while 'edge_index' is (2, num_edges); confirm the
        # generator accepts both orientations.
        return torch.from_numpy(splits[split_name][field])

    nodes_data = {'': {'_NID': torch.arange(raw_graph['num_nodes'])}}
    edges_data = {
        ('', '', ''): torch.from_numpy(raw_graph['edge_index']),
        ('', 'train_pos_edge', ''): _split_edges('train', 'edge'),
        ('', 'val_pos_edge', ''): _split_edges('valid', 'edge'),
        ('', 'val_neg_edge', ''): _split_edges('valid', 'edge_neg'),
        ('', 'test_pos_edge', ''): _split_edges('test', 'edge'),
        ('', 'test_neg_edge', ''): _split_edges('test', 'edge_neg'),
    }
    static_graph = GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
        nodes_data, edges_data
    )
    super(OGBLDDIDataset, self).__init__([static_graph])