示例#1
0
 def ogbl_data_to_general_static_graph(
     cls,
     ogbl_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]],
     heterogeneous_edges: _typing.Mapping[_typing.Tuple[
         str, str, str], _typing.Union[torch.Tensor, _typing.Tuple[
             torch.Tensor,
             _typing.Optional[_typing.Mapping[str, torch.Tensor]]]]] = ...,
     nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...,
     graph_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...
 ) -> GeneralStaticGraph:
     """Build a ``GeneralStaticGraph`` from a raw OGB-L data dict.

     Selected entries of ``ogbl_data`` are converted to tensors (singleton
     dimensions squeezed away), renamed according to the key mappings, and
     handed to
     ``GeneralStaticGraphGenerator.create_heterogeneous_static_graph``.
     All node data is attached to the single node type ``''``.

     :param ogbl_data: raw data dict; values are numpy arrays or ints.
     :param heterogeneous_edges: edges keyed by
         ``(src_type, relation, dst_type)``; forwarded unchanged.
     :param nodes_data_key_mapping: ``{key in ogbl_data: node-data key}``.
         When this is not a mapping (``None`` or the ``...`` default), no
         node data is attached.
     :param graph_data_key_mapping: ``{key in ogbl_data: graph-data key}``.
         When this is not a mapping, ``...`` is forwarded as the graph-data
         argument.
     :return: the graph built by ``create_heterogeneous_static_graph``.
     :raises KeyError: if a mapped source key is missing from ``ogbl_data``.
     """
     def _convert(key_mapping):
         # Convert and rename the mapped entries of ogbl_data into tensors.
         # np.asarray lets plain ints (permitted by the ogbl_data annotation)
         # pass through torch.from_numpy as 0-d arrays. Existing ndarrays
         # are returned as-is by np.asarray, so no copy is made.
         return {
             target_key: torch.from_numpy(
                 np.asarray(ogbl_data[source_key])).squeeze()
             for source_key, target_key in key_mapping.items()
         }

     # A missing node mapping (None / ...) means "no node data" rather than
     # failing on `.items()`. This mirrors how the graph-data mapping is
     # guarded below.
     # NOTE(review): confirm create_heterogeneous_static_graph accepts an
     # empty node-data dict for type ''.
     nodes_data = (_convert(nodes_data_key_mapping)
                   if isinstance(nodes_data_key_mapping, _typing.Mapping)
                   else {})
     graph_data = (_convert(graph_data_key_mapping)
                   if isinstance(graph_data_key_mapping, _typing.Mapping)
                   else ...)
     return GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
         {'': nodes_data}, heterogeneous_edges, graph_data)
示例#2
0
 def __init__(self, path: str):
     """Load the ``ogbl-ddi`` link-prediction dataset found under ``path``.

     The dataset holds one static graph with a single node type ``''``.
     Its only node feature is ``'_NID'``, the ids ``0 .. num_nodes - 1``.
     The full edge set is stored under ``('', '', '')``. Each OGB edge split
     is stored as its own relation: ``train_pos_edge``, ``val_pos_edge``,
     ``val_neg_edge``, ``test_pos_edge`` and ``test_neg_edge``.
     """
     dataset = LinkPropPredDataset("ogbl-ddi", path)
     splits = dataset.get_edge_split()
     raw_graph = dataset[0]

     # (relation name, OGB split name, key inside that split). Note that the
     # OGB split is called 'valid' while the relations use the 'val_' prefix.
     split_relations = (
         ('train_pos_edge', 'train', 'edge'),
         ('val_pos_edge', 'valid', 'edge'),
         ('val_neg_edge', 'valid', 'edge_neg'),
         ('test_pos_edge', 'test', 'edge'),
         ('test_neg_edge', 'test', 'edge_neg'),
     )
     edges = {('', '', ''): torch.from_numpy(raw_graph['edge_index'])}
     for relation, split_name, array_key in split_relations:
         # NOTE(review): OGB split arrays may be laid out (num_edges, 2)
         # while 'edge_index' is (2, num_edges). Confirm which layout the
         # graph expects.
         edges[('', relation, '')] = torch.from_numpy(
             splits[split_name][array_key])

     node_data = {'': {'_NID': torch.arange(raw_graph['num_nodes'])}}
     graph = GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
         node_data, edges)
     super(OGBLDDIDataset, self).__init__([graph])