# Variant written against the Strategy API: a DDPSpawnStrategy instance is
# passed via `strategy=`. MNISTDataModule, SyncBNModule and
# set_random_main_port are helpers from the Lightning test suite.
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPSpawnStrategy


def test_sync_batchnorm_ddp(tmpdir):
    seed_everything(234)
    set_random_main_port()

    # define datamodule and dataloader
    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup(stage=None)

    train_dataloader = dm.train_dataloader()
    model = SyncBNModule()

    bn_outputs = []

    # shuffle is false by default
    for batch_idx, batch in enumerate(train_dataloader):
        x, _ = batch

        _, out_bn = model.forward(x, batch_idx)
        bn_outputs.append(out_bn)

        # get 3 steps
        if batch_idx == 2:
            break

    bn_outputs = [x.cuda() for x in bn_outputs]

    # reset datamodule
    # batch-size = 16 because 2 GPUs in DDP
    dm = MNISTDataModule(batch_size=16, dist_sampler=True)
    dm.prepare_data()
    dm.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)
    ddp = DDPSpawnStrategy(
        parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
        num_nodes=1,
        sync_batchnorm=True,
        cluster_environment=LightningEnvironment(),
        find_unused_parameters=True,
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        gpus=2,
        num_nodes=1,
        strategy=ddp,
        max_epochs=1,
        max_steps=3,
        sync_batchnorm=True,
        num_sanity_val_steps=0,
        replace_sampler_ddp=False,
    )
    trainer.fit(model, dm)

    # the strategy is responsible for tearing down the batchnorm wrappers
    assert not isinstance(model.bn_layer, torch.nn.modules.batchnorm.SyncBatchNorm)
    assert isinstance(model.bn_layer, torch.nn.modules.batchnorm._BatchNorm)
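# All three variants of this test rely on a `SyncBNModule` helper from the
# Lightning test suite. The sketch below is a reconstruction from how the
# tests use it (a `bn_layer` attribute, `forward(x, batch_idx)` returning the
# logits plus the raw batchnorm output, and `bn_targets` compared against the
# synced output during training); names and details are assumptions, not the
# actual helper.
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class SyncBNModule(LightningModule):
    def __init__(self, gpu_count=1, bn_targets=None):
        super().__init__()
        self.gpu_count = gpu_count
        self.bn_targets = bn_targets
        self.bn_layer = nn.BatchNorm1d(28 * 28)
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x, batch_idx):
        # capture the raw batchnorm output so the test can compare it
        # against the single-process reference
        with torch.no_grad():
            out_bn = self.bn_layer(x.view(x.size(0), -1))
        out = self.linear(out_bn)
        return out, out_bn

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat, out_bn = self(x, batch_idx)
        if self.bn_targets is not None:
            # each rank sees an interleaved shard of the reference batch, so
            # slice the single-process target the same way before comparing
            bn_target = self.bn_targets[batch_idx]
            bn_target = bn_target[self.trainer.local_rank :: self.gpu_count]
            assert torch.sum(torch.abs(bn_target.to(out_bn.device) - out_bn)) < 0.01
        return F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)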
# Variant written against the plugin API: a DDPSpawnPlugin instance is passed
# via `plugins=` together with `accelerator="ddp_spawn"`.
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.plugins.environments import LightningEnvironment


def test_sync_batchnorm_ddp(tmpdir):
    seed_everything(234)
    set_random_master_port()

    # define datamodule and dataloader
    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup(stage=None)

    train_dataloader = dm.train_dataloader()
    model = SyncBNModule()

    bn_outputs = []

    # shuffle is false by default
    for batch_idx, batch in enumerate(train_dataloader):
        x, _ = batch

        _, out_bn = model.forward(x, batch_idx)
        bn_outputs.append(out_bn)

        # get 3 steps
        if batch_idx == 2:
            break

    bn_outputs = [x.cuda() for x in bn_outputs]

    # reset datamodule
    # batch-size = 16 because 2 GPUs in DDP
    dm = MNISTDataModule(batch_size=16, dist_sampler=True)
    dm.prepare_data()
    dm.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)
    ddp = DDPSpawnPlugin(
        parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
        num_nodes=1,
        sync_batchnorm=True,
        cluster_environment=LightningEnvironment(),
        find_unused_parameters=True,
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        gpus=2,
        num_nodes=1,
        accelerator="ddp_spawn",
        max_epochs=1,
        max_steps=3,
        sync_batchnorm=True,
        num_sanity_val_steps=0,
        replace_sampler_ddp=False,
        plugins=[ddp],
    )
    trainer.fit(model, dm)

    assert trainer.state.finished, "Sync batchnorm failing with DDP"
# Oldest variant: a bare DDPPlugin is passed via `plugins=` and the trainer
# state is compared against the TrainerState enum.
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.trainer.states import TrainerState


def test_sync_batchnorm_ddp(tmpdir):
    seed_everything(234)
    set_random_master_port()

    # define datamodule and dataloader
    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup(stage=None)

    train_dataloader = dm.train_dataloader()
    model = SyncBNModule()

    bn_outputs = []

    # shuffle is false by default
    for batch_idx, batch in enumerate(train_dataloader):
        x, _ = batch

        _, out_bn = model.forward(x, batch_idx)
        bn_outputs.append(out_bn)

        # get 3 steps
        if batch_idx == 2:
            break

    bn_outputs = [x.cuda() for x in bn_outputs]

    # reset datamodule
    # batch-size = 16 because 2 GPUs in DDP
    dm = MNISTDataModule(batch_size=16, dist_sampler=True)
    dm.prepare_data()
    dm.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)

    trainer = Trainer(
        gpus=2,
        num_nodes=1,
        accelerator="ddp_spawn",
        max_epochs=1,
        max_steps=3,
        sync_batchnorm=True,
        num_sanity_val_steps=0,
        replace_sampler_ddp=False,
        plugins=[DDPPlugin(find_unused_parameters=True)],
    )
    trainer.fit(model, dm)

    assert trainer.state == TrainerState.FINISHED, "Sync batchnorm failing with DDP"
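# `set_random_main_port` / `set_random_master_port` are small test-suite
# helpers; the rename tracks Lightning's main-process terminology change.
# Below is a minimal sketch of what such a helper could do — the exact port
# range is an assumption. Randomizing MASTER_PORT keeps concurrent DDP test
# runs on the same machine from colliding on the rendezvous port.
import os
import random


def set_random_main_port():
    # pick a random rendezvous port for torch.distributed initialization
    os.environ["MASTER_PORT"] = str(random.randint(10000, 19999))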