def ogbl_data_to_general_static_graph(
        cls,
        ogbl_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]],
        heterogeneous_edges: _typing.Mapping[
            _typing.Tuple[str, str, str],
            _typing.Union[
                torch.Tensor,
                _typing.Tuple[
                    torch.Tensor,
                    _typing.Optional[_typing.Mapping[str, torch.Tensor]]
                ]
            ]
        ] = ...,
        nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...,
        graph_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...
) -> GeneralStaticGraph:
    """Convert an OGB link-prediction data dict into a heterogeneous static graph.

    All node data is attached to the single node type ``''``.

    :param ogbl_data: raw OGB graph dict; every key referenced by the key
        mappings below must hold a numpy array.
    :param heterogeneous_edges: forwarded unchanged to
        ``GeneralStaticGraphGenerator.create_heterogeneous_static_graph``.
    :param nodes_data_key_mapping: ``{source key in ogbl_data: target
        node-data key}``. Each array is converted with ``torch.from_numpy``
        and squeezed. If this is not a mapping (e.g. the default ``...`` or
        ``None``), no node data is attached.
    :param graph_data_key_mapping: same as above for graph-level data. If
        this is not a mapping, ``...`` is forwarded to the generator.
    :return: the graph built by ``create_heterogeneous_static_graph``.
    """
    def _convert(key_mapping: _typing.Mapping[str, str]) -> _typing.Dict[str, torch.Tensor]:
        # Rename OGB keys to target keys and turn each array into a squeezed tensor.
        return {
            target_data_key: torch.from_numpy(ogbl_data[source_data_key]).squeeze()
            for source_data_key, target_data_key in key_mapping.items()
        }

    # The default ``...`` (and ``None``, allowed by the annotation) used to
    # crash on ``.items()``; treat a missing mapping as "no node data".
    # NOTE(review): assumes the generator accepts an empty node-data dict.
    nodes_data = (
        _convert(nodes_data_key_mapping)
        if isinstance(nodes_data_key_mapping, _typing.Mapping) else {}
    )
    graph_data = (
        _convert(graph_data_key_mapping)
        if isinstance(graph_data_key_mapping, _typing.Mapping) else ...
    )
    return GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
        {'': nodes_data}, heterogeneous_edges, graph_data
    )
def ogbn_data_to_general_static_graph(
        cls,
        ogbn_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]],
        nodes_label: np.ndarray = ...,
        nodes_label_key: str = ...,
        train_index: _typing.Optional[np.ndarray] = ...,
        val_index: _typing.Optional[np.ndarray] = ...,
        test_index: _typing.Optional[np.ndarray] = ...,
        nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...,
        edges_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...,
        graph_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...
) -> GeneralStaticGraph:
    """Convert an OGB node-property-prediction data dict into a homogeneous static graph.

    :param ogbn_data: raw OGB graph dict; must contain ``'edge_index'``,
        ``'num_nodes'`` when any split index is given, and every key
        referenced by the key mappings.
    :param nodes_label: optional label array. It is stored only when
        ``nodes_label_key`` is also a ``str``; the array is squeezed first.
    :param nodes_label_key: node-data key for the labels; must not contain a
        space.
    :param train_index: optional index array, turned into node data
        ``'train_mask'`` via ``index_to_mask``.
    :param val_index: optional index array, turned into ``'val_mask'``.
    :param test_index: optional index array, turned into ``'test_mask'``.
    :param nodes_data_key_mapping: ``{source key: target node-data key}``.
        If this is not a mapping (e.g. the default ``...``), no node data is
        attached.
    :param edges_data_key_mapping: same for edge data; if not a mapping,
        ``...`` is forwarded.
    :param graph_data_key_mapping: same for graph data; if not a mapping,
        ``...`` is forwarded.
    :raises ValueError: if the label key contains a space.
    :return: the built graph with labels and masks attached.
    """
    has_label = isinstance(nodes_label, np.ndarray) and isinstance(nodes_label_key, str)
    # Validate before doing the (potentially large) tensor conversions.
    if has_label and ' ' in nodes_label_key:
        raise ValueError("Illegal nodes label key")

    def _convert(key_mapping: _typing.Mapping[str, str]) -> _typing.Dict[str, torch.Tensor]:
        # Rename OGB keys to target keys and convert each array to a tensor.
        return {
            target_key: torch.from_numpy(ogbn_data[source_key])
            for source_key, target_key in key_mapping.items()
        }

    # The default ``...`` used to crash on ``.items()``; a missing mapping
    # now means "no node data".
    # NOTE(review): assumes the generator accepts an empty node-data dict.
    nodes_data = (
        _convert(nodes_data_key_mapping)
        if isinstance(nodes_data_key_mapping, _typing.Mapping) else {}
    )
    edge_index = torch.from_numpy(ogbn_data['edge_index'])
    edges_data = (
        _convert(edges_data_key_mapping)
        if isinstance(edges_data_key_mapping, _typing.Mapping) else ...
    )
    graph_data = (
        _convert(graph_data_key_mapping)
        if isinstance(graph_data_key_mapping, _typing.Mapping) else ...
    )
    homogeneous_static_graph: GeneralStaticGraph = (
        GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, edge_index, edges_data, graph_data
        )
    )

    if has_label:
        # A single squeeze is enough; squeezing the tensor again was a no-op.
        homogeneous_static_graph.nodes.data[nodes_label_key] = (
            torch.from_numpy(nodes_label.squeeze())
        )

    for mask_key, split_index in (
            ('train_mask', train_index),
            ('val_mask', val_index),
            ('test_mask', test_index),
    ):
        if isinstance(split_index, np.ndarray):
            # 'num_nodes' is read only when a split is actually requested.
            homogeneous_static_graph.nodes.data[mask_key] = index_to_mask(
                torch.from_numpy(split_index), ogbn_data['num_nodes']
            )
    return homogeneous_static_graph
def __init__(self, path: str):
    """Load ``ogbl-ddi`` from *path* and wrap it as a single heterogeneous graph.

    The node type ``''`` carries ``'_NID'`` (``0 .. num_nodes-1``). The full
    edge set is stored under edge type ``('', '', '')``. The OGB split edges
    are stored under the ``*_pos_edge`` / ``*_neg_edge`` edge types; only
    validation and test negatives are included.
    """
    ogbl_dataset = LinkPropPredDataset("ogbl-ddi", path)
    edge_split = ogbl_dataset.get_edge_split()
    raw_graph = ogbl_dataset[0]

    node_data = {'': {'_NID': torch.arange(raw_graph['num_nodes'])}}

    # (edge type name, split name in edge_split, field within that split)
    split_sources = (
        ('train_pos_edge', 'train', 'edge'),
        ('val_pos_edge', 'valid', 'edge'),
        ('val_neg_edge', 'valid', 'edge_neg'),
        ('test_pos_edge', 'test', 'edge'),
        ('test_neg_edge', 'test', 'edge_neg'),
    )
    edge_data = {('', '', ''): torch.from_numpy(raw_graph['edge_index'])}
    for relation, split_name, field in split_sources:
        edge_data[('', relation, '')] = torch.from_numpy(edge_split[split_name][field])

    static_graph = GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
        node_data, edge_data
    )
    super(OGBLDDIDataset, self).__init__([static_graph])