def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
    """Test that the new batch size gets written to the correct hyperparameter attribute."""
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()
    before_batch_size = hparams.get('batch_size')

    class HparamsEvalModelTemplate(EvalModelTemplate):

        def dataloader(self, *args, **kwargs):
            # artificially set batch_size so we can get a dataloader
            # remove it immediately after, because we want only self.hparams.batch_size
            setattr(self, "batch_size", before_batch_size)
            dataloader = super().dataloader(*args, **kwargs)
            del self.batch_size
            return dataloader

    datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111)  # this datamodule should get ignored!
    datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size)

    model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate
    model = model_class(**hparams)
    model.datamodule = datamodule_model  # unused when another module gets passed to .tune() / .fit()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
    trainer.tune(model, datamodule_fit)
    after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size

    assert trainer.datamodule == datamodule_fit
    assert before_batch_size != after_batch_size
    assert after_batch_size <= len(trainer.train_dataloader.dataset)
    assert datamodule_fit.batch_size == after_batch_size
    # should be left unchanged, since it was not passed to .tune()
    assert datamodule_model.batch_size == 111
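
# Illustrative sketch, not exercised by the test above: as the assertions demonstrate, the batch
# size finder invoked by ``trainer.tune`` looks for a ``batch_size`` attribute either directly on
# the module or inside ``module.hparams`` and writes the scaled value back there.
# ``_ExampleBatchSizeModule`` below is a hypothetical, minimal module compatible with
# ``auto_scale_batch_size``; its name, layer sizes, and random dataset are placeholders and not
# part of the Lightning test suite.
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class _ExampleBatchSizeModule(pl.LightningModule):
    """Hypothetical minimal module for illustrating ``auto_scale_batch_size``."""

    def __init__(self, batch_size=32):
        super().__init__()
        self.save_hyperparameters()  # exposes self.hparams.batch_size to the tuner
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
        # the tuner reads and overwrites self.hparams.batch_size through this loader
        return DataLoader(dataset, batch_size=self.hparams.batch_size)


# Typical (hedged) usage:
#     trainer = Trainer(auto_scale_batch_size=True, max_epochs=1)
#     trainer.tune(_ExampleBatchSizeModule())
# after which model.hparams.batch_size holds the batch size the finder settled on.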
def test_sync_batchnorm_ddp(tmpdir):
    seed_everything(234)
    set_random_master_port()

    # define datamodule and dataloader
    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup(stage=None)

    train_dataloader = dm.train_dataloader()
    model = SyncBNModule()

    bn_outputs = []

    # shuffle is false by default
    for batch_idx, batch in enumerate(train_dataloader):
        x, _ = batch

        _, out_bn = model.forward(x, batch_idx)
        bn_outputs.append(out_bn)

        # get 3 steps
        if batch_idx == 2:
            break

    bn_outputs = [x.cuda() for x in bn_outputs]

    # reset datamodule
    # batch-size = 16 because 2 GPUs in DDP
    dm = MNISTDataModule(batch_size=16, dist_sampler=True)
    dm.prepare_data()
    dm.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)

    trainer = Trainer(
        gpus=2,
        num_nodes=1,
        distributed_backend='ddp_spawn',
        max_epochs=1,
        max_steps=3,
        sync_batchnorm=True,
        num_sanity_val_steps=0,
        replace_sampler_ddp=False,
    )

    result = trainer.fit(model, dm)
    assert result == 1, "Sync batchnorm failing with DDP"
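
# Illustrative sketch, not exercised by the test above: the ``sync_batchnorm=True`` flag makes
# Lightning convert the model's ``BatchNorm*`` layers to ``torch.nn.SyncBatchNorm`` so that batch
# statistics are reduced across all DDP processes. The helper below shows the plain-PyTorch
# conversion; the network shape is a placeholder and ``_sync_bn_conversion_sketch`` is not part of
# the Lightning test suite.
def _sync_bn_conversion_sketch():
    import torch

    net = torch.nn.Sequential(
        torch.nn.Linear(28 * 28, 64),
        torch.nn.BatchNorm1d(64),
        torch.nn.ReLU(),
    )
    # convert_sync_batchnorm swaps the BatchNorm1d layer for SyncBatchNorm; inside an initialized
    # DDP process group its statistics are then computed over the global batch
    synced = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    assert isinstance(synced[1], torch.nn.SyncBatchNorm)
    return synced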