Esempio n. 1
0
def test_lightning_hetero_node_data():
    """Fit a linear hetero-node model on DBLP through ``LightningNodeData``.

    The test checks two things after ``trainer.fit``:

    * every ``'author'`` feature value has grown by (almost) the number of
      sanity/train/val steps that were run, which the original comment
      describes as ensuring the data is shared with the training processes;
    * the trainer's data connector reports validation and test dataloader
      sources as defined.
    """
    import pytorch_lightning as pl

    max_epochs = 5
    batch_size = 32

    # Load DBLP from a randomly named scratch directory and delete that
    # directory again as soon as the first graph has been fetched.
    scratch_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dblp = DBLP(scratch_dir)
    data = dblp[0]
    shutil.rmtree(scratch_dir)

    author = data['author']
    num_classes = int(author.y.max()) + 1
    model = LinearHeteroNodeModule(author.num_features, num_classes)

    gpus = torch.cuda.device_count()
    trainer = pl.Trainer(
        strategy=pl.plugins.DDPSpawnPlugin(find_unused_parameters=False),
        gpus=gpus,
        max_epochs=max_epochs,
        log_every_n_steps=1,
    )
    datamodule = LightningNodeData(
        data,
        loader='neighbor',
        num_neighbors=[5],
        batch_size=batch_size,
        num_workers=3,
    )

    features_before = data['author'].x.clone()
    trainer.fit(model, datamodule)
    features_after = data['author'].x

    # Steps run in one stage (`train` or `val`) over all epochs and devices.
    # NOTE(review): 400 is presumably the number of labelled author nodes per
    # split -- confirm against the dataset.
    steps_per_stage = (
        max_epochs * gpus * math.ceil(400 / (gpus * batch_size)))
    sanity_steps = gpus * 2  # `sanity`
    expected_growth = sanity_steps + steps_per_stage + steps_per_stage

    # Ensure shared data: a slack of 4 is tolerated below the expected growth.
    assert torch.all(features_after > (features_before + expected_growth - 4))
    assert trainer._data_connector._val_dataloader_source.is_defined()
    assert trainer._data_connector._test_dataloader_source.is_defined()
Esempio n. 2
0
def load_dataset(root: str, name: str, *args, **kwargs) -> Dataset:
    r"""Returns a variety of datasets according to :obj:`name`.

    The dataset class is imported lazily, so only the package providing the
    requested dataset needs to be importable.

    Name matching differs per dataset family:

    * ``KarateClub``: any name containing ``'karate'`` (case-insensitive).
    * ``Planetoid``: ``'cora'``, ``'citeseer'``, ``'pubmed'``
      (case-insensitive).
    * ``TUDataset``: ``'BZR'``, ``'ENZYMES'``, ``'IMDB-BINARY'``, ``'MUTAG'``
      (case-sensitive).
    * ``SNAPDataset``: ``'ego-facebook'``, ``'soc-Slashdot0811'``,
      ``'wiki-vote'`` (case-sensitive).
    * ``BAShapes``: ``'bashapes'`` (case-insensitive).
    * ``DBLP``: ``'dblp'`` (case-insensitive).
    * ``SuiteSparseMatrixCollection``: ``'citationCiteseer'``,
      ``'illc1850'`` (case-sensitive).

    Args:
        root: Base directory under which a per-family sub-directory is used
            as the dataset root. Not used for ``KarateClub`` and ``BAShapes``.
        name: Name of the dataset to load (see the matching rules above).
        *args: Extra positional arguments forwarded to the dataset class.
        **kwargs: Extra keyword arguments forwarded to the dataset class.

    Returns:
        The instantiated dataset.

    Raises:
        NotImplementedError: If :obj:`name` matches none of the supported
            datasets.
    """
    lowered = name.lower()

    if 'karate' in lowered:
        from torch_geometric.datasets import KarateClub
        return KarateClub(*args, **kwargs)
    if lowered in ['cora', 'citeseer', 'pubmed']:
        from torch_geometric.datasets import Planetoid
        path = osp.join(root, 'Planetoid', name)
        return Planetoid(path, name, *args, **kwargs)
    if name in ['BZR', 'ENZYMES', 'IMDB-BINARY', 'MUTAG']:
        from torch_geometric.datasets import TUDataset
        path = osp.join(root, 'TUDataset')
        return TUDataset(path, name, *args, **kwargs)
    if name in ['ego-facebook', 'soc-Slashdot0811', 'wiki-vote']:
        from torch_geometric.datasets import SNAPDataset
        path = osp.join(root, 'SNAPDataset')
        return SNAPDataset(path, name, *args, **kwargs)
    if lowered in ['bashapes']:
        from torch_geometric.datasets import BAShapes
        return BAShapes(*args, **kwargs)
    if lowered in ['dblp']:
        from torch_geometric.datasets import DBLP
        path = osp.join(root, 'DBLP')
        return DBLP(path, *args, **kwargs)
    if name in ['citationCiteseer', 'illc1850']:
        from torch_geometric.datasets import SuiteSparseMatrixCollection
        path = osp.join(root, 'SuiteSparseMatrixCollection')
        # `name` is passed by keyword on purpose: any extra positional
        # arguments fill the parameters that follow `root`.
        return SuiteSparseMatrixCollection(path, name=name, *args, **kwargs)

    # Name the rejected dataset so the failure can be diagnosed (several
    # families above only match with exact casing).
    raise NotImplementedError(f"Dataset '{name}' is not supported")
import os.path as osp

import torch
import torch.nn.functional as F

from torch_geometric.datasets import DBLP
from torch_geometric.nn import HeteroConv, Linear, SAGEConv

# Location of the DBLP dataset, relative to this script's directory.
path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP')
# NOTE(review): this hard-coded absolute path overrides the script-relative
# path computed on the previous line, so that value is never used -- confirm
# which location is intended.
path = '/data/datasets/DBLP'
# Load the dataset from `path` and work with its first graph.
dataset = DBLP(path)
data = dataset[0]
print(data)

# We initialize conference node features with a single feature.
data['conference'].x = torch.ones(data['conference'].num_nodes, 1)


class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels, num_layers):
        """Create the stacked heterogeneous conv layers and the output layer.

        Args:
            metadata: Indexable object whose element ``[1]`` is iterated to
                obtain the edge types; one ``SAGEConv`` is created per edge
                type in every layer.
            hidden_channels: Output size of each ``SAGEConv`` and input size
                of the final ``Linear`` layer.
            out_channels: Output size of the final ``Linear`` layer.
            num_layers: Number of ``HeteroConv`` layers to build.
        """
        super().__init__()

        # One HeteroConv per layer, each holding a SAGEConv for every edge
        # type. NOTE(review): the (-1, -1) input sizes presumably request
        # lazy shape inference from torch_geometric -- confirm.
        hetero_layers = [
            HeteroConv({
                edge_type: SAGEConv((-1, -1), hidden_channels)
                for edge_type in metadata[1]
            })
            for _ in range(num_layers)
        ]
        self.convs = torch.nn.ModuleList(hetero_layers)

        self.lin = Linear(hidden_channels, out_channels)