Exemplo n.º 1
0
 def ogbl_data_to_general_static_graph(
     cls,
     ogbl_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]],
     heterogeneous_edges: _typing.Mapping[_typing.Tuple[
         str, str, str], _typing.Union[torch.Tensor, _typing.Tuple[
             torch.Tensor,
             _typing.Optional[_typing.Mapping[str, torch.Tensor]]]]] = ...,
     nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...,
     graph_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...
 ) -> GeneralStaticGraph:
     """Build a ``GeneralStaticGraph`` from a raw OGB link-prediction data dict.

     All node data is placed under the single node type ``''``.

     :param ogbl_data: raw data mapping; every entry referenced by a key
         mapping is passed to ``torch.from_numpy`` and so must be a numpy
         array.
     :param heterogeneous_edges: edges keyed by
         ``(source_type, edge_type, target_type)``; forwarded unchanged to
         ``GeneralStaticGraphGenerator.create_heterogeneous_static_graph``.
     :param nodes_data_key_mapping: ``{key in ogbl_data: node-data key}``.
         Required despite the ``...`` default.
     :param graph_data_key_mapping: ``{key in ogbl_data: graph-data key}``;
         when it is not a mapping, ``...`` is forwarded to the generator.
     :raises TypeError: if ``nodes_data_key_mapping`` is not a mapping.
     :raises KeyError: if a mapped source key is missing from ``ogbl_data``.
     """
     if not isinstance(nodes_data_key_mapping, _typing.Mapping):
         # Without this check an omitted argument surfaces as an opaque
         # AttributeError ('ellipsis' object has no attribute 'items').
         raise TypeError(
             "nodes_data_key_mapping must be a mapping from source data keys "
             f"to target data keys, got {type(nodes_data_key_mapping).__name__}"
         )

     def _squeezed_tensors(key_mapping):
         # Wrap each referenced numpy array as a tensor (torch.from_numpy
         # shares memory) and drop its size-1 dimensions.
         return {
             target_key: torch.from_numpy(ogbl_data[source_key]).squeeze()
             for source_key, target_key in key_mapping.items()
         }

     # Node data is converted before graph data, as in the original order.
     nodes_data = _squeezed_tensors(nodes_data_key_mapping)
     graph_data = (
         _squeezed_tensors(graph_data_key_mapping)
         if isinstance(graph_data_key_mapping, _typing.Mapping) else ...
     )
     return GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
         {'': nodes_data}, heterogeneous_edges, graph_data)
Exemplo n.º 2
0
 def ogbn_data_to_general_static_graph(
     cls,
     ogbn_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]],
     nodes_label: np.ndarray = ...,
     nodes_label_key: str = ...,
     train_index: _typing.Optional[np.ndarray] = ...,
     val_index: _typing.Optional[np.ndarray] = ...,
     test_index: _typing.Optional[np.ndarray] = ...,
     nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...,
     edges_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...,
     graph_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...
 ) -> GeneralStaticGraph:
     homogeneous_static_graph: GeneralStaticGraph = (
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             dict([(target_key, torch.from_numpy(ogbn_data[source_key]))
                   for source_key, target_key in
                   nodes_data_key_mapping.items()]),
             torch.from_numpy(ogbn_data['edge_index']),
             dict([(target_key, torch.from_numpy(ogbn_data[source_key]))
                   for source_key, target_key in
                   edges_data_key_mapping.items()]) if isinstance(
                       edges_data_key_mapping, _typing.Mapping) else ...,
             dict([(target_key, torch.from_numpy(ogbn_data[source_key]))
                   for source_key, target_key in
                   graph_data_key_mapping.items()]) if isinstance(
                       graph_data_key_mapping, _typing.Mapping) else ...))
     if isinstance(nodes_label, np.ndarray) and isinstance(
             nodes_label_key, str):
         if ' ' in nodes_label_key:
             raise ValueError("Illegal nodes label key")
         homogeneous_static_graph.nodes.data[nodes_label_key] = (
             torch.from_numpy(nodes_label.squeeze()).squeeze())
     if isinstance(train_index, np.ndarray):
         homogeneous_static_graph.nodes.data['train_mask'] = index_to_mask(
             torch.from_numpy(train_index), ogbn_data['num_nodes'])
     if isinstance(val_index, np.ndarray):
         homogeneous_static_graph.nodes.data['val_mask'] = index_to_mask(
             torch.from_numpy(val_index), ogbn_data['num_nodes'])
     if isinstance(test_index, np.ndarray):
         homogeneous_static_graph.nodes.data['test_mask'] = index_to_mask(
             torch.from_numpy(test_index), ogbn_data['num_nodes'])
     return homogeneous_static_graph
Exemplo n.º 3
0
 def __init__(self, path: str):
     """Load ``ogbl-ddi`` from *path* as a single heterogeneous static graph.

     The graph has one node type ``''`` whose only node data is ``_NID``
     (``torch.arange(num_nodes)``). It contains the full edge set as edge
     type ``('', '', '')``. Each split returned by ``get_edge_split()`` is
     stored as its own edge type: ``train_pos_edge``, ``val_pos_edge``,
     ``val_neg_edge``, ``test_pos_edge`` and ``test_neg_edge``.
     """
     ogbl_dataset = LinkPropPredDataset("ogbl-ddi", path)
     edge_split = ogbl_dataset.get_edge_split()
     graph = ogbl_dataset[0]
     # (edge type, split name, key within that split) for each split edge set.
     split_edge_sources = (
         ('train_pos_edge', 'train', 'edge'),
         ('val_pos_edge', 'valid', 'edge'),
         ('val_neg_edge', 'valid', 'edge_neg'),
         ('test_pos_edge', 'test', 'edge'),
         ('test_neg_edge', 'test', 'edge_neg'),
     )
     heterogeneous_edges = {
         ('', '', ''): torch.as_tensor(graph['edge_index'])
     }
     for edge_type, split_name, split_key in split_edge_sources:
         # torch.as_tensor wraps numpy arrays without copying (like
         # torch.from_numpy) and also accepts tensors, which from_numpy
         # rejects.
         # NOTE(review): some ogb versions may already return split entries
         # as tensors. Also confirm that their layout matches edge_index,
         # which may be (2, E) vs (E, 2).
         heterogeneous_edges[('', edge_type, '')] = torch.as_tensor(
             edge_split[split_name][split_key])
     super(OGBLDDIDataset, self).__init__([
         GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
             {'': {'_NID': torch.arange(graph['num_nodes'])}},
             heterogeneous_edges)
     ])