import os.path as osp

import torch
import pytorch_lightning as pl
import torch_geometric.transforms as T
from pytorch_lightning import seed_everything
from pytorch_lightning.plugins import DDPSpawnPlugin
from torch_geometric.data import LightningDataset
from torch_geometric.datasets import TUDataset


def main():
    seed_everything(42)

    # IMDB-BINARY has no node features, so one-hot encode the node degree.
    root = osp.join('data', 'TUDataset')
    dataset = TUDataset(root, 'IMDB-BINARY', pre_transform=T.OneHotDegree(135))
    dataset = dataset.shuffle()

    # 10% test, 10% validation, 80% train split.
    test_dataset = dataset[:len(dataset) // 10]
    val_dataset = dataset[len(dataset) // 10:2 * len(dataset) // 10]
    train_dataset = dataset[2 * len(dataset) // 10:]

    datamodule = LightningDataset(train_dataset, val_dataset, test_dataset,
                                  batch_size=64, num_workers=4)

    model = Model(dataset.num_node_features, dataset.num_classes)

    devices = torch.cuda.device_count()  # available GPUs (unused; Trainer is pinned to 2 below)
    # ddp = DDPStrategy(process_group_backend='nccl')

    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_acc', save_top_k=1)
    trainer = pl.Trainer(gpus=2,
                         plugins=DDPSpawnPlugin(find_unused_parameters=False),
                         max_epochs=5000, log_every_n_steps=5,
                         callbacks=[checkpoint])

    trainer.fit(model, datamodule)
    trainer.test(ckpt_path='best', datamodule=datamodule)
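# -----------------------------------------------------------------------------
# The `Model` referenced in main() is not defined in this snippet. Below is a
# minimal sketch of what such a graph-classification LightningModule could look
# like: a two-layer GCN with global mean pooling that logs `val_acc` to match
# the ModelCheckpoint monitor. The class layout, hidden size, and optimizer
# settings are illustrative assumptions, not taken from the source.
# -----------------------------------------------------------------------------
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


class Model(pl.LightningModule):
    def __init__(self, num_features: int, num_classes: int, hidden: int = 64):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin = torch.nn.Linear(hidden, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)  # one embedding per graph
        return self.lin(x)

    def training_step(self, data, batch_idx):
        logits = self(data.x, data.edge_index, data.batch)
        loss = F.cross_entropy(logits, data.y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, data, batch_idx):
        logits = self(data.x, data.edge_index, data.batch)
        acc = (logits.argmax(dim=-1) == data.y).float().mean()
        self.log('val_acc', acc)  # monitored by the checkpoint callback above

    def test_step(self, data, batch_idx):
        logits = self(data.x, data.edge_index, data.batch)
        acc = (logits.argmax(dim=-1) == data.y).float().mean()
        self.log('test_acc', acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)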
# `MNISTDataModule`, `SyncBNModule`, and `set_random_master_port` are helpers
# defined elsewhere in the test suite.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import LightningEnvironment


def test_sync_batchnorm_ddp(tmpdir):
    """Verify SyncBatchNorm training under two-GPU DDP spawn against
    single-process BatchNorm outputs recorded below."""
    seed_everything(234)
    set_random_master_port()

    # define datamodule and dataloader
    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup(stage=None)

    train_dataloader = dm.train_dataloader()
    model = SyncBNModule()

    bn_outputs = []

    # shuffle is false by default
    for batch_idx, batch in enumerate(train_dataloader):
        x, _ = batch

        _, out_bn = model.forward(x, batch_idx)
        bn_outputs.append(out_bn)

        # get 3 steps
        if batch_idx == 2:
            break

    bn_outputs = [x.cuda() for x in bn_outputs]

    # reset datamodule
    # batch-size = 16 because 2 GPUs in DDP
    dm = MNISTDataModule(batch_size=16, dist_sampler=True)
    dm.prepare_data()
    dm.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)
    ddp = DDPSpawnPlugin(
        parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
        num_nodes=1,
        sync_batchnorm=True,
        cluster_environment=LightningEnvironment(),
        find_unused_parameters=True,
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        gpus=2,
        num_nodes=1,
        accelerator="ddp_spawn",
        max_epochs=1,
        max_steps=3,
        sync_batchnorm=True,
        num_sanity_val_steps=0,
        replace_sampler_ddp=False,
        plugins=[ddp],
    )

    trainer.fit(model, dm)
    assert trainer.state.finished, "Sync batchnorm failing with DDP"
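# -----------------------------------------------------------------------------
# The sketch below shows the role `SyncBNModule` plays in the test above,
# assuming it wraps a single BatchNorm1d layer over flattened MNIST images and,
# when `bn_targets` is supplied, asserts that each rank's synchronized outputs
# match the single-process reference recorded earlier. The names, shapes, and
# slicing logic are assumptions for illustration, not the actual helper.
# -----------------------------------------------------------------------------
import torch.nn as nn


class SyncBNModuleSketch(pl.LightningModule):
    def __init__(self, gpu_count=1, bn_targets=None):
        super().__init__()
        self.gpu_count = gpu_count
        self.bn_targets = bn_targets
        self.bn_layer = nn.BatchNorm1d(28 * 28)
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x, batch_idx):
        out_bn = self.bn_layer(x.view(x.size(0), -1))
        return self.linear(out_bn), out_bn

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat, out_bn = self(x, batch_idx)
        if self.bn_targets is not None:
            # A non-shuffling DistributedSampler hands every `gpu_count`-th
            # sample to this rank, so slice the single-process reference batch
            # the same way before comparing.
            target = self.bn_targets[batch_idx][self.global_rank::self.gpu_count]
            assert torch.allclose(out_bn, target.to(out_bn.device), atol=1e-3)
        return F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.02)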
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default


# `BoringModel` is the minimal LightningModule helper shipped with the Lightning
# test suite.
def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
    """Test for DDP Spawn FP16 compress hook."""
    model = BoringModel()
    training_type_plugin = DDPSpawnPlugin(
        ddp_comm_hook=default.fp16_compress_hook,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        strategy=training_type_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
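# -----------------------------------------------------------------------------
# `default.fp16_compress_hook` is the built-in PyTorch DDP communication hook
# from torch.distributed.algorithms.ddp_comm_hooks.default_hooks: gradients are
# cast to FP16 before the all-reduce and decompressed afterwards, halving
# communication volume at some cost in gradient precision. Below is a minimal
# plain-PyTorch sketch of registering the same hook on a DistributedDataParallel
# model, assuming two CUDA devices and an NCCL process group; it is not how
# Lightning wires the hook up internally.
# -----------------------------------------------------------------------------
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def _fp16_hook_demo(rank: int, world_size: int = 2):
    # assumes MASTER_ADDR / MASTER_PORT are already set in the environment
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    ddp_model = DistributedDataParallel(
        torch.nn.Linear(32, 32).to(rank), device_ids=[rank]
    )
    # compress gradients to FP16 for the all-reduce, then decompress
    ddp_model.register_comm_hook(state=None, hook=default.fp16_compress_hook)

    out = ddp_model(torch.randn(8, 32, device=rank))
    out.sum().backward()
    dist.destroy_process_group()


# e.g. torch.multiprocessing.spawn(_fp16_hook_demo, args=(2,), nprocs=2)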