Esempio n. 1
0
 def __init__(self, path: str):
     """Build the GTN IMDB dataset as a single homogeneous static graph.

     Node tensors come from the first item of the ``"gtn-imdb"`` GTN data
     source. Feature/label keys follow the active backend: ``'feat'`` /
     ``'label'`` under DGL, ``'x'`` / ``'y'`` under PyG. Position and the
     train/val/test masks keep their names under both backends.

     :param path: root directory handed to ``_GTNDataSource``.
     """
     gtn_graph = _GTNDataSource(path, "gtn-imdb")[0]
     if _backend.DependentBackend.is_dgl():
         feature_key, label_key = 'feat', 'label'
     elif _backend.DependentBackend.is_pyg():
         feature_key, label_key = 'x', 'y'
     else:
         # NOTE(review): no backend matched, so the base class is never
         # initialised here — confirm this is intended.
         return
     node_data = {
         feature_key: gtn_graph.x,
         label_key: gtn_graph.y,
         'pos': gtn_graph.pos,
         'train_mask': gtn_graph.train_mask,
         'val_mask': gtn_graph.val_mask,
         'test_mask': gtn_graph.test_mask,
     }
     super(GTNIMDBDataset, self).__init__([
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             node_data, gtn_graph.edge_index)
     ])
Esempio n. 2
0
 def __init__(self, path: str):
     """Build the Wikipedia (node2vec ``POS``) dataset as one static graph.

     The ``POS`` matrix is obtained through ``_MATLABMatrix`` from the
     node2vec site. Node labels go under ``'label'`` (DGL) or ``'y'``
     (PyG); edge attributes go under ``'attr'`` for both backends.

     :param path: root directory for the downloaded/processed matrix.
     """
     source_url = "http://snap.stanford.edu/node2vec/"
     matrix_graph = _MATLABMatrix(path, "POS", source_url)[0]
     if _backend.DependentBackend.is_dgl():
         label_key = 'label'
     elif _backend.DependentBackend.is_pyg():
         label_key = 'y'
     else:
         # NOTE(review): no backend matched; base class left uninitialised.
         return
     super(WIKIPEDIADataset, self).__init__([
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             {label_key: matrix_graph.y},
             matrix_graph.edge_index,
             {'attr': matrix_graph.edge_attr})
     ])
Esempio n. 3
0
 def __init__(self, path: str):
     """Build the BlogCatalog dataset as one homogeneous static graph.

     The ``blogcatalog`` matrix is obtained through ``_MATLABMatrix`` from
     the social-dimension data site. Node labels go under ``'label'``
     (DGL) or ``'y'`` (PyG); edge attributes go under ``'edge_attr'``.

     :param path: root directory for the downloaded/processed matrix.
     """
     source_url: str = "http://leitang.net/code/social-dimension/data/"
     matrix_graph = _MATLABMatrix(path, "BlogCatalog".lower(), source_url)[0]
     if _backend.DependentBackend.is_dgl():
         label_key = 'label'
     elif _backend.DependentBackend.is_pyg():
         label_key = 'y'
     else:
         # NOTE(review): no backend matched; base class left uninitialised.
         return
     super(BlogCatalogDataset, self).__init__([
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             {label_key: matrix_graph.y},
             matrix_graph.edge_index,
             {'edge_attr': matrix_graph.edge_attr})
     ])
Esempio n. 4
0
 def __init__(self, path: str):
     """Load the OGB ``ogbg-code2`` graph-property-prediction dataset.

     Every graph yielded by ``GraphPropPredDataset`` becomes a homogeneous
     static graph. Its node tensors are ``node_feat`` (stored as ``'feat'``
     under DGL, ``'x'`` otherwise), ``node_is_attributed``,
     ``node_dfs_order`` and ``node_depth``. The per-graph label yielded
     alongside each graph is not stored. Train/val/test splits come from
     ``get_idx_split()``.

     :param path: root directory passed to ``GraphPropPredDataset``.
     """
     # The dataset name must match this class; the node keys read below
     # (node_is_attributed, node_dfs_order, node_depth) are code2-specific.
     ogbg_dataset = GraphPropPredDataset("ogbg-code2", path)
     idx_split: _typing.Mapping[str, np.ndarray] = ogbg_dataset.get_idx_split()
     train_index: _typing.Any = idx_split['train'].tolist()
     test_index: _typing.Any = idx_split['test'].tolist()
     val_index: _typing.Any = idx_split['valid'].tolist()
     # DGL-style datasets name node features 'feat'; PyG uses 'x'.
     feature_key: str = 'feat' if _backend.DependentBackend.is_dgl() else 'x'
     super(OGBGCode2Dataset, self).__init__([
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             {
                 feature_key: torch.from_numpy(data['node_feat']),
                 "node_is_attributed":
                 torch.from_numpy(data["node_is_attributed"]),
                 "node_dfs_order": torch.from_numpy(data["node_dfs_order"]),
                 "node_depth": torch.from_numpy(data["node_depth"]),
             },
             torch.from_numpy(data['edge_index']))
         for data, _label in ogbg_dataset
     ], train_index, val_index, test_index)
Esempio n. 5
0
 def __init__(self, path: str):
     """Build the TU ``COLLAB`` graph-classification dataset.

     Each PyG graph becomes a static graph with no node data. Its
     ``edge_index`` is used as topology and its ``y`` is stored as
     graph-level data.

     :param path: base directory; PyG files live under ``<path>/_pyg``.
     """
     tu_dataset = TUDataset(os.path.join(path, '_pyg'), "COLLAB")
     # Remove cached data-list attributes if present.
     # NOTE(review): presumably avoids keeping per-item cached copies — confirm.
     for cache_attr in ("__data_list__", "_data_list"):
         if hasattr(tu_dataset, cache_attr):
             delattr(tu_dataset, cache_attr)
     static_graphs = []
     for graph in tu_dataset:
         static_graphs.append(
             GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                 {}, graph.edge_index, graph_data={'y': graph.y}))
     super(COLLABDataset, self).__init__(static_graphs)
Esempio n. 6
0
 def __init__(self, path: str):
     """Build the PyG ``PPI`` dataset from its train/val/test splits.

     The three splits are loaded in order and concatenated into one list of
     static graphs, each carrying node ``'x'`` and ``'y'``. The split
     indices are consecutive ``range`` objects over that list: train first,
     then val, then test.

     :param path: base directory; PyG files live under ``<path>/_pyg``.
     """
     pyg_root = os.path.join(path, '_pyg')
     split_datasets = []
     for split_name in ('train', 'val', 'test'):
         split_dataset = PPI(pyg_root, split_name)
         # Remove cached data-list attributes if present.
         for cache_attr in ("__data_list__", "_data_list"):
             if hasattr(split_dataset, cache_attr):
                 delattr(split_dataset, cache_attr)
         split_datasets.append(split_dataset)

     n_train, n_val, n_test = (len(ds) for ds in split_datasets)
     train_index = range(n_train)
     val_index = range(n_train, n_train + n_val)
     test_index = range(n_train + n_val, n_train + n_val + n_test)

     static_graphs = [
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             {'x': graph.x, 'y': graph.y}, graph.edge_index)
         for split_dataset in split_datasets
         for graph in split_dataset
     ]
     super(PPIDataset, self).__init__(
         static_graphs, train_index, val_index, test_index)
Esempio n. 7
0
 def __init__(self, path: str):
     """Build the PyG ``Coauthor`` CS dataset as one static graph.

     Only the first (single) graph is used, with node ``'x'`` and ``'y'``.

     :param path: base directory; PyG files live under ``<path>/_pyg``.
     """
     coauthor = Coauthor(os.path.join(path, '_pyg'), "CS")
     # Remove cached data-list attributes if present.
     for cache_attr in ("__data_list__", "_data_list"):
         if hasattr(coauthor, cache_attr):
             delattr(coauthor, cache_attr)
     graph = coauthor[0]
     node_data = {'x': graph.x, 'y': graph.y}
     super(CoauthorCSDataset, self).__init__([
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             node_data, graph.edge_index)
     ])
Esempio n. 8
0
 def __init__(self, path: str):
     """Build the ModelNet40 test-split dataset of mesh graphs.

     ``ModelNet`` is created with name ``'40'`` and a third positional
     argument of ``False`` (presumably ``train=False``, i.e. the test split —
     matches the class name; confirm against the PyG version in use). Faces
     are converted to edges by the ``FaceToEdge`` pre-transform. Each shape
     keeps node ``'pos'`` and graph-level ``'y'``.

     :param path: base directory; PyG files live under ``<path>/_pyg``.
     """
     modelnet = ModelNet(
         os.path.join(path, '_pyg'), '40', False,
         pre_transform=torch_geometric.transforms.FaceToEdge())
     # Remove cached data-list attributes if present.
     for cache_attr in ("__data_list__", "_data_list"):
         if hasattr(modelnet, cache_attr):
             delattr(modelnet, cache_attr)
     static_graphs = [
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             {'pos': shape.pos}, shape.edge_index, graph_data={'y': shape.y})
         for shape in modelnet
     ]
     super(ModelNet40TestDataset, self).__init__(static_graphs)
Esempio n. 9
0
    def __init__(self, path: str):
        """Build the PyG ``Reddit`` dataset as one static graph.

        Node data: ``'x'``, ``'y'`` and the dataset's own
        ``'train_mask'`` / ``'val_mask'`` / ``'test_mask'``.

        :param path: base directory; PyG files live under ``<path>/_pyg``.
        """
        reddit = Reddit(os.path.join(path, '_pyg'))
        # Remove cached data-list attributes if present.
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(reddit, cache_attr):
                delattr(reddit, cache_attr)
        graph = reddit[0]

        node_data = {'x': graph.x, 'y': graph.y}
        for mask_key in ('train_mask', 'val_mask', 'test_mask'):
            node_data[mask_key] = getattr(graph, mask_key)
        super(RedditDataset, self).__init__([
            GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                node_data, graph.edge_index)
        ])
Esempio n. 10
0
 def ogbn_data_to_general_static_graph(
     cls,
     ogbn_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]],
     nodes_label: np.ndarray = ...,
     nodes_label_key: str = ...,
     train_index: _typing.Optional[np.ndarray] = ...,
     val_index: _typing.Optional[np.ndarray] = ...,
     test_index: _typing.Optional[np.ndarray] = ...,
     nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...,
     edges_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...,
     graph_data_key_mapping: _typing.Optional[_typing.Mapping[str,
                                                              str]] = ...
 ) -> GeneralStaticGraph:
     """Convert a raw OGB node-property-prediction graph dict to a static graph.

     Each ``*_data_key_mapping`` maps a key of ``ogbn_data`` (source) to the
     key its tensor is stored under in the result (target).

     - A node mapping that is omitted (``...``) or ``None`` yields an empty
       node-data dict.
     - An omitted edge or graph mapping passes ``...`` through to
       ``create_homogeneous_static_graph``.

     :param ogbn_data: graph dict. It must contain ``'edge_index'``, and
         also ``'num_nodes'`` when any split index is given.
     :param nodes_label: optional node-label array. It is squeezed and
         stored under ``nodes_label_key`` when both arguments are supplied.
     :param nodes_label_key: key for node labels; must not contain spaces.
     :param train_index: optional node indices, turned into
         ``'train_mask'`` via ``index_to_mask``.
     :param val_index: same, producing ``'val_mask'``.
     :param test_index: same, producing ``'test_mask'``.
     :raises ValueError: if ``nodes_label_key`` contains a space.
     :return: the constructed homogeneous static graph.
     """
     def _tensors_for(key_mapping: _typing.Mapping[str, str]):
         # Copy ogbn_data[source_key] as a tensor under target_key.
         return {
             target_key: torch.from_numpy(ogbn_data[source_key])
             for source_key, target_key in key_mapping.items()
         }

     # Guard the node mapping like the edge/graph ones: its default is
     # `...`, which has no `.items()`.
     nodes_data = (_tensors_for(nodes_data_key_mapping)
                   if isinstance(nodes_data_key_mapping, _typing.Mapping)
                   else {})
     edge_index = torch.from_numpy(ogbn_data['edge_index'])
     edges_data = (_tensors_for(edges_data_key_mapping)
                   if isinstance(edges_data_key_mapping, _typing.Mapping)
                   else ...)
     graph_data = (_tensors_for(graph_data_key_mapping)
                   if isinstance(graph_data_key_mapping, _typing.Mapping)
                   else ...)
     homogeneous_static_graph: GeneralStaticGraph = (
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             nodes_data, edge_index, edges_data, graph_data))

     if isinstance(nodes_label, np.ndarray) and isinstance(
             nodes_label_key, str):
         if ' ' in nodes_label_key:
             raise ValueError("Illegal nodes label key")
         homogeneous_static_graph.nodes.data[nodes_label_key] = (
             torch.from_numpy(nodes_label.squeeze()))

     for mask_key, index in (('train_mask', train_index),
                             ('val_mask', val_index),
                             ('test_mask', test_index)):
         if isinstance(index, np.ndarray):
             homogeneous_static_graph.nodes.data[mask_key] = index_to_mask(
                 torch.from_numpy(index), ogbn_data['num_nodes'])
     return homogeneous_static_graph
Esempio n. 11
0
 def __init__(self, path: str):
     """Load the OGB ``ogbg-ppa`` graph-property-prediction dataset.

     Each graph yielded by ``GraphPropPredDataset`` becomes a homogeneous
     static graph with:

     - node data ``'_NID'`` = ``arange(num_nodes)`` (no node features are read);
     - edge data ``'edge_feat'``;
     - the graph label stored under ``'label'`` (DGL) or ``'y'`` (otherwise).

     Train/val/test splits come from ``get_idx_split()``.

     :param path: root directory passed to ``GraphPropPredDataset``.
     """
     # The dataset name must match this class.
     ogbg_dataset = GraphPropPredDataset("ogbg-ppa", path)
     idx_split: _typing.Mapping[str, np.ndarray] = ogbg_dataset.get_idx_split()
     train_index: _typing.Any = idx_split['train'].tolist()
     test_index: _typing.Any = idx_split['test'].tolist()
     val_index: _typing.Any = idx_split['valid'].tolist()
     # DGL-style datasets name the target 'label'; PyG uses 'y'.
     label_key: str = 'label' if _backend.DependentBackend.is_dgl() else 'y'
     super(OGBGPPADataset, self).__init__([
         GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             {'_NID': torch.arange(data['num_nodes'])},
             torch.from_numpy(data['edge_index']),
             {'edge_feat': torch.from_numpy(data['edge_feat'])},
             {label_key: torch.from_numpy(label)})
         for data, label in ogbg_dataset
     ], train_index, val_index, test_index)
Esempio n. 12
0
 def __init__(self, path: str):
     """Build the PyG ``QM9`` molecular dataset.

     Each molecule keeps:

     - node ``'x'``, ``'pos'`` and ``'z'``;
     - edge ``'edge_attr'``;
     - graph-level ``'idx'`` and ``'y'``.

     :param path: base directory; PyG files live under ``<path>/_pyg``.
     """
     qm9 = QM9(os.path.join(path, '_pyg'))
     # Remove cached data-list attributes if present.
     for cache_attr in ("__data_list__", "_data_list"):
         if hasattr(qm9, cache_attr):
             delattr(qm9, cache_attr)

     def _to_static(molecule):
         # Convert one PyG molecule into a homogeneous static graph.
         return GeneralStaticGraphGenerator.create_homogeneous_static_graph(
             {'x': molecule.x, 'pos': molecule.pos, 'z': molecule.z},
             molecule.edge_index,
             edges_data={'edge_attr': molecule.edge_attr},
             graph_data={'idx': molecule.idx, 'y': molecule.y})

     super(QM9Dataset, self).__init__([_to_static(m) for m in qm9])