def test_lightning_hetero_node_data():
    r"""Fits a linear model on DBLP ``'author'`` nodes via
    :obj:`LightningNodeData` and checks the outcome.

    The test asserts that

    * the ``'author'`` features held by the caller's :obj:`data` object have
      grown by (almost) the expected number of steps after
      :meth:`trainer.fit`, and
    * the trainer's validation and test dataloader sources are defined.

    NOTE(review): the feature check presumably relies on
    :obj:`LinearHeteroNodeModule` incrementing ``x`` in place on every step;
    that module is defined outside this block -- confirm.
    """
    import pytorch_lightning as pl

    max_epochs = 5
    batch_size = 32

    # The dataset is downloaded into a randomly named scratch directory.
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    try:
        dataset = DBLP(root)
        data = dataset[0]
    finally:
        # Clean up even when loading fails so that no scratch directory is
        # left behind. The existence check keeps a failure that happened
        # before the directory was created from being masked by `rmtree`.
        if osp.exists(root):
            shutil.rmtree(root)

    model = LinearHeteroNodeModule(
        data['author'].num_features,
        int(data['author'].y.max()) + 1,
    )

    gpus = torch.cuda.device_count()
    strategy = pl.plugins.DDPSpawnPlugin(find_unused_parameters=False)
    trainer = pl.Trainer(strategy=strategy, gpus=gpus, max_epochs=max_epochs,
                         log_every_n_steps=1)
    datamodule = LightningNodeData(data, loader='neighbor', num_neighbors=[5],
                                   batch_size=batch_size, num_workers=3)

    # Snapshot the features before training to measure the change afterwards.
    old_x = data['author'].x.clone()
    trainer.fit(model, datamodule)
    new_x = data['author'].x

    # NOTE(review): with no CUDA device `gpus` is 0 and the division below
    # raises `ZeroDivisionError`; presumably this test is skipped on CPU-only
    # machines by a marker outside this block -- confirm.
    # NOTE(review): 400 is presumably the number of seed nodes per stage --
    # confirm against the DBLP split.
    steps_per_epoch = gpus * math.ceil(400 / (gpus * batch_size))

    offset = gpus * 2  # `sanity`
    offset += max_epochs * steps_per_epoch  # `train`
    offset += max_epochs * steps_per_epoch  # `val`

    # NOTE(review): the `- 4` slack is taken over unchanged; its origin is
    # not visible from this block.
    assert torch.all(new_x > (old_x + offset - 4))  # Ensure shared data.
    assert trainer._data_connector._val_dataloader_source.is_defined()
    assert trainer._data_connector._test_dataloader_source.is_defined()
def load_dataset(root: str, name: str, *args, **kwargs) -> Dataset:
    r"""Returns a variety of datasets according to :obj:`name`.

    Matching is case-insensitive for KarateClub (any name containing
    ``'karate'``), the Planetoid names, ``'bashapes'`` and ``'dblp'``. The
    TUDataset, SNAPDataset and SuiteSparseMatrixCollection names must match
    exactly.

    Args:
        root: Base directory; each dataset family uses its own
            sub-directory below it. Not used for KarateClub and BAShapes.
        name: Name of the dataset to load.
        *args: Extra positional arguments forwarded to the dataset
            constructor.
        **kwargs: Extra keyword arguments forwarded to the dataset
            constructor.

    Returns:
        The constructed dataset.

    Raises:
        NotImplementedError: If :obj:`name` matches none of the supported
            datasets.
    """
    # Dataset classes are imported inside their branch, so only the class
    # that is actually requested gets imported.
    lowered = name.lower()

    if 'karate' in lowered:
        from torch_geometric.datasets import KarateClub
        return KarateClub(*args, **kwargs)

    if lowered in ('cora', 'citeseer', 'pubmed'):
        from torch_geometric.datasets import Planetoid
        # The directory and the dataset use `name` exactly as it was passed.
        path = osp.join(root, 'Planetoid', name)
        return Planetoid(path, name, *args, **kwargs)

    if name in ('BZR', 'ENZYMES', 'IMDB-BINARY', 'MUTAG'):
        from torch_geometric.datasets import TUDataset
        path = osp.join(root, 'TUDataset')
        return TUDataset(path, name, *args, **kwargs)

    if name in ('ego-facebook', 'soc-Slashdot0811', 'wiki-vote'):
        from torch_geometric.datasets import SNAPDataset
        path = osp.join(root, 'SNAPDataset')
        return SNAPDataset(path, name, *args, **kwargs)

    if lowered == 'bashapes':
        from torch_geometric.datasets import BAShapes
        return BAShapes(*args, **kwargs)

    if lowered == 'dblp':
        from torch_geometric.datasets import DBLP
        path = osp.join(root, 'DBLP')
        return DBLP(path, *args, **kwargs)

    if name in ('citationCiteseer', 'illc1850'):
        from torch_geometric.datasets import SuiteSparseMatrixCollection
        path = osp.join(root, 'SuiteSparseMatrixCollection')
        # `name` is passed by keyword so that `*args` can fill the positional
        # parameters that follow `path`.
        return SuiteSparseMatrixCollection(path, *args, name=name, **kwargs)

    raise NotImplementedError(
        f"Dataset '{name}' is not supported by 'load_dataset'")
import os.path as osp

import torch
import torch.nn.functional as F

from torch_geometric.datasets import DBLP
from torch_geometric.nn import HeteroConv, Linear, SAGEConv

# Dataset location relative to this file.
path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP')
# NOTE(review): this assignment unconditionally replaces the relative path
# computed above with a hard-coded absolute path, so the line above has no
# effect. It looks like a machine-specific leftover -- confirm and remove one
# of the two.
path = '/data/datasets/DBLP'
dataset = DBLP(path)
data = dataset[0]  # Only the first element of the dataset is used.
print(data)

# We initialize conference node features with a single feature.
data['conference'].x = torch.ones(data['conference'].num_nodes, 1)


class HeteroGNN(torch.nn.Module):
    """Stack of `HeteroConv` layers followed by a `Linear` layer.

    Only `__init__` is visible here; how the layers are applied is not shown
    in this view.

    Args:
        metadata: Indexable object whose element ``[1]`` is iterated to get
            the edge types (presumably the result of ``data.metadata()`` --
            confirm against the caller).
        hidden_channels: Output size of every `SAGEConv` and input size of
            the final `Linear` layer.
        out_channels: Output size of the final `Linear` layer.
        num_layers: Number of `HeteroConv` layers to create.
    """
    def __init__(self, metadata, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # One SAGEConv per edge type. The `(-1, -1)` input sizes
            # presumably ask for lazy inference of the source/target feature
            # sizes -- confirm against the SAGEConv documentation.
            conv = HeteroConv({
                edge_type: SAGEConv((-1, -1), hidden_channels)
                for edge_type in metadata[1]
            })
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)