import torch

from torch_geometric.data import LightningDataset

# `get_dataset` and `LinearGraphModule` are a fixture and a helper defined
# alongside this test (a sketch of the module follows the test below).


def test_lightning_dataset(get_dataset, strategy_type):
    import pytorch_lightning as pl

    dataset = get_dataset(name='MUTAG').shuffle()
    train_dataset = dataset[:50]
    val_dataset = dataset[50:80]
    test_dataset = dataset[80:90]

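    # One process on a single GPU unless a distributed strategy is requested: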
    devices = 1 if strategy_type is None else torch.cuda.device_count()
    if strategy_type == 'ddp_spawn':
        strategy = pl.strategies.DDPSpawnStrategy(find_unused_parameters=False)
    else:
        strategy = None

    model = LinearGraphModule(dataset.num_features, 64, dataset.num_classes)

    trainer = pl.Trainer(strategy=strategy,
                         accelerator='gpu',
                         devices=devices,
                         max_epochs=1,
                         log_every_n_steps=1)
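    # LightningDataset wires the three splits into PyG `DataLoader`s behind
    # the corresponding Lightning hooks: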
    datamodule = LightningDataset(train_dataset,
                                  val_dataset,
                                  test_dataset,
                                  batch_size=5,
                                  num_workers=3)
    old_x = train_dataset.data.x.clone()
    assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
                               'val_dataset=MUTAG(30), '
                               'test_dataset=MUTAG(10), batch_size=5, '
                               'num_workers=3, pin_memory=True, '
                               'persistent_workers=True)')
    trainer.fit(model, datamodule)
    new_x = train_dataset.data.x
    offset = 10 + 6 + 2 * devices  # `train_steps` + `val_steps` + `sanity`
    assert torch.all(new_x > (old_x + offset - 4))  # Ensure shared data.
    if strategy_type is None:
        assert trainer._data_connector._val_dataloader_source.is_defined()
        assert trainer._data_connector._test_dataloader_source.is_defined()

    # Test with `val_dataset=None` and `test_dataset=None`:
    if strategy_type == 'ddp_spawn':
        strategy = pl.strategies.DDPSpawnStrategy(find_unused_parameters=False)
    else:
        strategy = None

    trainer = pl.Trainer(strategy=strategy,
                         accelerator='gpu',
                         devices=devices,
                         max_epochs=1,
                         log_every_n_steps=1)
    datamodule = LightningDataset(train_dataset, batch_size=5, num_workers=3)
    assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
                               'batch_size=5, num_workers=3, '
                               'pin_memory=True, persistent_workers=True)')
    trainer.fit(model, datamodule)
    if strategy_type is None:
        assert not trainer._data_connector._val_dataloader_source.is_defined()
        assert not trainer._data_connector._test_dataloader_source.is_defined()
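For context, here is a minimal sketch of the `LinearGraphModule` this test assumes; the real helper is defined in the original test suite, so everything below is an assumption. The in-place `x.add_(1)` is inferred from the shared-data assertion above: each step bumps the shared node features, which is how the parent process can observe mutations made by spawned workers.

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch.nn import Linear

from torch_geometric.nn import global_mean_pool


class LinearGraphModule(LightningModule):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, batch):
        # In-place update so the test can verify that workers operate on
        # the same shared storage (see the `offset` assertion above):
        x.add_(1)
        x = self.lin1(x).relu()
        x = global_mean_pool(x, batch)  # Graph-level readout.
        return self.lin2(x)

    def training_step(self, data, batch_idx):
        return F.cross_entropy(self(data.x, data.batch), data.y)

    def validation_step(self, data, batch_idx):
        loss = F.cross_entropy(self(data.x, data.batch), data.y)
        self.log('val_loss', loss, batch_size=data.num_graphs)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)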
Example #2
import os.path as osp
import random
import shutil
import sys
import warnings

import torch

from torch_geometric.data import LightningDataset
from torch_geometric.datasets import TUDataset

# `LinearGraphModule` is the same test helper sketched above.


def test_lightning_dataset(strategy):
    import pytorch_lightning as pl

    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = TUDataset(root, name='MUTAG').shuffle()
    train_dataset = dataset[:50]
    val_dataset = dataset[50:80]
    test_dataset = dataset[80:90]
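    # TUDataset is an in-memory dataset, so the on-disk root is disposable: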
    shutil.rmtree(root)

    gpus = 1 if strategy is None else torch.cuda.device_count()
    if strategy == 'ddp_spawn':
        strategy = pl.plugins.DDPSpawnPlugin(find_unused_parameters=False)

    model = LinearGraphModule(dataset.num_features, 64, dataset.num_classes)

    trainer = pl.Trainer(strategy=strategy,
                         gpus=gpus,
                         max_epochs=1,
                         log_every_n_steps=1)
    datamodule = LightningDataset(train_dataset,
                                  val_dataset,
                                  test_dataset,
                                  batch_size=5,
                                  num_workers=3)
    old_x = train_dataset.data.x.clone()
    assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
                               'val_dataset=MUTAG(30), '
                               'test_dataset=MUTAG(10), batch_size=5, '
                               'num_workers=3, pin_memory=True, '
                               'persistent_workers=True)')
    trainer.fit(model, datamodule)
    new_x = train_dataset.data.x
    offset = 10 + 6 + 2 * gpus  # `train_steps` + `val_steps` + `sanity`
    assert torch.all(new_x > (old_x + offset - 4))  # Ensure shared data.
    assert trainer._data_connector._val_dataloader_source.is_defined()
    assert trainer._data_connector._test_dataloader_source.is_defined()

    # Test with `val_dataset=None` and `test_dataset=None`:
    warnings.filterwarnings('ignore', '.*Skipping val loop.*')
    trainer = pl.Trainer(strategy=strategy,
                         gpus=gpus,
                         max_epochs=1,
                         log_every_n_steps=1)
    datamodule = LightningDataset(train_dataset, batch_size=5, num_workers=3)
    assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
                               'batch_size=5, num_workers=3, '
                               'pin_memory=True, persistent_workers=True)')
    trainer.fit(model, datamodule)
    assert not trainer._data_connector._val_dataloader_source.is_defined()
    assert not trainer._data_connector._test_dataloader_source.is_defined()
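The two tests above target different PyTorch Lightning generations: this one uses the pre-1.6 `pl.plugins.DDPSpawnPlugin` together with the `gpus` Trainer argument, while the first uses `pl.strategies.DDPSpawnStrategy` with `accelerator='gpu'` and `devices`. A small shim (a sketch; the helper name and version cut-off are assumptions, and note that `DDPSpawnStrategy` was removed again in PL 2.0) could pick the right one at runtime:

import pytorch_lightning as pl
from packaging import version


def make_spawn_strategy():
    # `pl.strategies` replaced the `pl.plugins` training-type plugins in
    # PL 1.6; this sketch assumes PL < 2.0 on the newer branch.
    if version.parse(pl.__version__) >= version.parse('1.6.0'):
        return pl.strategies.DDPSpawnStrategy(find_unused_parameters=False)
    return pl.plugins.DDPSpawnPlugin(find_unused_parameters=False)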
Example #3
import os.path as osp

import pytorch_lightning as pl
import torch
from pytorch_lightning import seed_everything

import torch_geometric.transforms as T
from torch_geometric.data import LightningDataset
from torch_geometric.datasets import TUDataset

# `Model` is the LightningModule sketched after this function.


def main():
    seed_everything(42)

    root = osp.join('data', 'TUDataset')
    dataset = TUDataset(root, 'IMDB-BINARY', pre_transform=T.OneHotDegree(135))

    dataset = dataset.shuffle()
    test_dataset = dataset[:len(dataset) // 10]
    val_dataset = dataset[len(dataset) // 10:2 * len(dataset) // 10]
    train_dataset = dataset[2 * len(dataset) // 10:]

    datamodule = LightningDataset(train_dataset,
                                  val_dataset,
                                  test_dataset,
                                  batch_size=64,
                                  num_workers=4)

    model = Model(dataset.num_node_features, dataset.num_classes)

    devices = torch.cuda.device_count()
    strategy = pl.strategies.DDPSpawnStrategy(find_unused_parameters=False)
    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_acc', save_top_k=1)
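    # ModelCheckpoint needs the model to log 'val_acc' during validation: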
    trainer = pl.Trainer(strategy=strategy,
                         accelerator='gpu',
                         devices=devices,
                         max_epochs=50,
                         log_every_n_steps=5,
                         callbacks=[checkpoint])

    trainer.fit(model, datamodule)
    trainer.test(ckpt_path='best', datamodule=datamodule)
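A minimal sketch of the `Model` referenced above (an assumption; the original script defines its own). Whatever its architecture, it has to log `val_acc` during validation so the `ModelCheckpoint` callback has a metric to monitor:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from torch_geometric.nn import GraphSAGE, global_mean_pool


class Model(pl.LightningModule):
    def __init__(self, in_channels, num_classes, hidden_channels=64):
        super().__init__()
        self.gnn = GraphSAGE(in_channels, hidden_channels, num_layers=3)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.gnn(x, edge_index)
        x = global_mean_pool(x, batch)  # Graph-level readout.
        return self.classifier(x)

    def training_step(self, data, batch_idx):
        logits = self(data.x, data.edge_index, data.batch)
        return F.cross_entropy(logits, data.y)

    def validation_step(self, data, batch_idx):
        logits = self(data.x, data.edge_index, data.batch)
        acc = (logits.argmax(dim=-1) == data.y).float().mean()
        self.log('val_acc', acc, batch_size=data.num_graphs)

    def test_step(self, data, batch_idx):
        logits = self(data.x, data.edge_index, data.batch)
        acc = (logits.argmax(dim=-1) == data.y).float().mean()
        self.log('test_acc', acc, batch_size=data.num_graphs)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)

Because `ddp_spawn` launches worker processes, the script should be run behind the usual entry-point guard:

if __name__ == '__main__':
    main()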