def main():
    """Train and evaluate a graph classifier on IMDB-BINARY with Lightning.

    Loads the TUDataset ``IMDB-BINARY`` collection with one-hot node-degree
    features (``OneHotDegree(135)`` applied as a pre-transform). It shuffles
    the graphs and splits them into roughly 10% test, 10% validation and 80%
    train. It then fits ``Model`` with DDP spawn on every visible CUDA
    device, keeps the checkpoint with the highest ``val_acc``, and evaluates
    that checkpoint on the test split.
    """
    seed_everything(42)

    root = osp.join('data', 'TUDataset')
    dataset = TUDataset(root, 'IMDB-BINARY', pre_transform=T.OneHotDegree(135))

    # Shuffle once (the RNG is seeded above), then carve out contiguous
    # test / val / train slices of the shuffled dataset.
    dataset = dataset.shuffle()
    num_graphs = len(dataset)
    test_dataset = dataset[:num_graphs // 10]
    val_dataset = dataset[num_graphs // 10:2 * num_graphs // 10]
    train_dataset = dataset[2 * num_graphs // 10:]

    datamodule = LightningDataset(train_dataset,
                                  val_dataset,
                                  test_dataset,
                                  batch_size=64,
                                  num_workers=4)

    model = Model(dataset.num_node_features, dataset.num_classes)

    # Use every visible GPU rather than a hard-coded count, so the script
    # neither crashes on single-GPU machines nor ignores extra devices.
    devices = torch.cuda.device_count()

    # ModelCheckpoint defaults to mode='min'. Assuming Model logs an accuracy
    # under 'val_acc' (higher is better), mode='max' is required; otherwise
    # ckpt_path='best' below would reload the *worst* epoch.
    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_acc',
                                              save_top_k=1,
                                              mode='max')
    trainer = pl.Trainer(gpus=devices,
                         plugins=DDPSpawnPlugin(find_unused_parameters=False),
                         max_epochs=5000,
                         log_every_n_steps=5,
                         callbacks=[checkpoint])

    trainer.fit(model, datamodule)
    trainer.test(ckpt_path='best', datamodule=datamodule)

# Example #2

def test_sync_batchnorm_ddp(tmpdir):
    """Train a SyncBN module on two GPUs against single-process BN references.

    First, the batch-norm outputs of three batches are collected by running a
    plain ``SyncBNModule`` in this process. They are moved to CUDA and passed
    as ``bn_targets`` to a second ``SyncBNModule``. That module is then fitted
    for three steps with DDP spawn on two GPUs and synchronised batch norm.
    Any comparison against the targets presumably happens inside
    ``SyncBNModule`` (not visible here). The test asserts that training
    finished.
    """
    seed_everything(234)
    set_random_master_port()

    # Reference pass: single process, default data module.
    datamodule = MNISTDataModule()
    datamodule.prepare_data()
    datamodule.setup(stage=None)

    loader = datamodule.train_dataloader()
    model = SyncBNModule()

    # The loader does not shuffle by default; take exactly the first 3 batches.
    outputs = []
    for step, (inputs, _) in zip(range(3), loader):
        _, bn_out = model.forward(inputs, step)
        outputs.append(bn_out)

    outputs = [out.cuda() for out in outputs]

    # DDP pass: per-process batch size 16 because two GPUs share the work,
    # with a distributed sampler supplied by the data module itself.
    datamodule = MNISTDataModule(batch_size=16, dist_sampler=True)
    datamodule.prepare_data()
    datamodule.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=outputs)
    gpu_devices = [torch.device("cuda", idx) for idx in (0, 1)]
    spawn_plugin = DDPSpawnPlugin(
        parallel_devices=gpu_devices,
        num_nodes=1,
        sync_batchnorm=True,
        cluster_environment=LightningEnvironment(),
        find_unused_parameters=True,
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="ddp_spawn",
        plugins=[spawn_plugin],
        gpus=2,
        num_nodes=1,
        sync_batchnorm=True,
        replace_sampler_ddp=False,
        max_epochs=1,
        max_steps=3,
        num_sanity_val_steps=0,
    )

    trainer.fit(model, datamodule)
    assert trainer.state.finished, "Sync batchnorm failing with DDP"


def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
    """Fit a BoringModel with DDP spawn using the FP16 compression comm hook.

    Runs a ``fast_dev_run`` on two GPUs with synchronised batch norm and
    checks that the trainer reaches the finished state.
    """
    model = BoringModel()
    spawn_strategy = DDPSpawnPlugin(
        ddp_comm_hook=default.fp16_compress_hook,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=spawn_strategy,
        gpus=2,
        max_epochs=1,
        fast_dev_run=True,
        sync_batchnorm=True,
    )

    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"