def test_running_test_after_fitting(tmpdir):
    """Verify test() on fitted model."""
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_test_tube_logger(tmpdir, False)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        default_save_path=tmpdir,
        show_progress_bar=False,
        max_epochs=4,
        train_percent_check=0.4,
        val_percent_check=0.2,
        test_percent_check=0.2,
        checkpoint_callback=checkpoint,
        logger=logger
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, 'training failed to complete'

    trainer.test()

    # test we have good test accuracy
    tutils.assert_ok_model_acc(trainer)
def test_optimizer_return_options():
    tutils.reset_seed()

    trainer = Trainer()
    model, hparams = tutils.get_model()

    # single optimizer
    opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
    opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
    optim, lr_sched = trainer.init_optimizers(opt_a)
    assert len(optim) == 1 and len(lr_sched) == 0

    # opt tuple
    opts = (opt_a, opt_b)
    optim, lr_sched = trainer.init_optimizers(opts)
    assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
    assert len(lr_sched) == 0

    # opt list
    opts = [opt_a, opt_b]
    optim, lr_sched = trainer.init_optimizers(opts)
    assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
    assert len(lr_sched) == 0

    # opt tuple of lists
    opts = ([opt_a], ['lr_scheduler'])
    optim, lr_sched = trainer.init_optimizers(opts)
    assert len(optim) == 1 and len(lr_sched) == 1
    assert optim[0] == opts[0][0] and lr_sched[0] == 'lr_scheduler'
def test_reduce_lr_on_plateau_scheduling(tmpdir):
    tutils.reset_seed()

    class CurrentTestModel(
            LightTestReduceLROnPlateauMixin,
            LightTrainDataloader,
            LightValidationMixin,
            LightValidationStepMixin,
            TestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model)

    assert trainer.lr_schedulers[0] == \
        dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='val_loss',
             interval='epoch', frequency=1, reduce_on_plateau=True), \
        'lr scheduler was not correctly converted to dict'
def test_mlflow_logger(tmpdir):
    """Verify that basic functionality of mlflow logger works."""
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    mlflow_dir = os.path.join(tmpdir, 'mlruns')
    logger = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')

    # Test already exists
    logger2 = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')
    _ = logger2.run_id

    # Try logging string
    logger.log_metrics({'acc': 'test'})

    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           train_percent_check=0.05,
                           logger=logger)

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, 'Training failed'
def test_optimizer_with_scheduling(tmpdir):
    """ Verify that learning rate scheduling is working """
    tutils.reset_seed()

    class CurrentTestModel(LightTestOptimizerWithSchedulingMixin,
                           LightTrainDataloader,
                           TestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model)

    init_lr = hparams.learning_rate
    adjusted_lr = [pg['lr'] for pg in trainer.optimizers[0].param_groups]

    assert len(trainer.lr_schedulers) == 1, \
        'lr scheduler not initialized properly, it has %i elements instead of 1' % len(trainer.lr_schedulers)

    assert all(a == adjusted_lr[0] for a in adjusted_lr), \
        'Lr not equally adjusted for all param groups'
    adjusted_lr = adjusted_lr[0]

    assert init_lr * 0.1 == adjusted_lr, \
        'Lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr)
def test_inf_test_dataloader(tmpdir):
    """Test inf test data loader (e.g. IterableDataset)"""
    tutils.reset_seed()

    class CurrentTestModel(LightInfTestDataloader,
                           LightningTestModel,
                           LightTestFitSingleTestDataloadersMixin):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # fit model
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(default_save_path=tmpdir,
                          max_epochs=1,
                          test_percent_check=0.5)
        trainer.test(model)

    # logger file to get meta
    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)
    result = trainer.fit(model)
    trainer.test(model)

    # verify training completed
    assert result == 1
def test_all_dataloaders_passed_to_fit(tmpdir):
    """ Verify train, val & test dataloader can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(
        LightningValStepFitSingleDataloaderMixin,
        LightningTestFitSingleTestDataloadersMixin,
        LightningTestModelBaseWithoutDataloader,
    ):
        pass

    hparams = tutils.get_hparams()

    # logger file to get meta
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # train, val and test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=model._dataloader(train=False),
                       test_dataloaders=model._dataloader(train=False))
    results = trainer.fit(model, **fit_options)
    trainer.test()

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
def test_benchmark_option(tmpdir):
    """Verify benchmark option."""
    tutils.reset_seed()

    class CurrentTestModel(LightValidationMultipleDataloadersMixin,
                           LightTrainDataloader,
                           TestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # verify torch.backends.cudnn.benchmark is not turned on
    assert not torch.backends.cudnn.benchmark

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        benchmark=True,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # verify training completed
    assert result == 1

    # verify torch.backends.cudnn.benchmark was turned on by the trainer
    assert torch.backends.cudnn.benchmark
def test_running_test_pretrained_model(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_test_tube_logger(tmpdir, False)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(show_progress_bar=False,
                           max_epochs=4,
                           train_percent_check=0.4,
                           val_percent_check=0.2,
                           checkpoint_callback=checkpoint,
                           logger=logger)

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'training failed to complete'
    pretrained_model = tutils.load_model(logger,
                                         trainer.checkpoint_callback.filepath,
                                         module_class=LightningTestModel)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer)
def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """ Verify that multiple val & test dataloaders can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBaseWithoutDataloader):
        pass

    hparams = tutils.get_hparams()

    # logger file to get meta
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # train, multiple val and multiple test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloader=[
                           model._dataloader(train=False),
                           model._dataloader(train=False)
                       ],
                       test_dataloader=[
                           model._dataloader(train=False),
                           model._dataloader(train=False)
                       ])
    results = trainer.fit(model, **fit_options)

    assert len(trainer.get_val_dataloaders()) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'
def test_mixing_of_dataloader_options(tmpdir):
    """Verify that dataloaders can be passed to fit"""
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False))
    results = trainer.fit(model, **fit_options)

    # fit model
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False),
                       test_dataloader=model._dataloader(train=False))
    results = trainer.fit(model, **fit_options)

    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'
def test_comet_logger(tmpdir, monkeypatch):
    """Verify that basic functionality of Comet.ml logger works."""

    # prevent comet logger from trying to print at exit, since
    # pytest's stdout/stderr redirection breaks it
    import atexit
    monkeypatch.setattr(atexit, "register", lambda _: None)

    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    comet_dir = os.path.join(tmpdir, "cometruns")

    # We test CometLogger in offline mode with local saves
    logger = CometLogger(
        save_dir=comet_dir,
        project_name="general",
        workspace="dummy-test",
    )

    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           train_percent_check=0.01,
                           logger=logger)

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    print('result finished')
    assert result == 1, "Training failed"
def test_neptune_leave_open_experiment_after_fit(tmpdir):
    """Verify that the Neptune experiment is closed after training unless close_after_fit=False."""
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    def _run_training(logger):
        logger._experiment = MagicMock()

        trainer_options = dict(
            default_save_path=tmpdir,
            max_epochs=1,
            train_percent_check=0.05,
            logger=logger
        )
        trainer = Trainer(**trainer_options)
        trainer.fit(model)
        return logger

    logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
    assert logger_close_after_fit._experiment.stop.call_count == 1

    logger_open_after_fit = _run_training(
        NeptuneLogger(offline_mode=True, close_after_fit=False))
    assert logger_open_after_fit._experiment.stop.call_count == 0
def test_wandb_logger(tmpdir):
    """Verify that the wandb logger can be instantiated with a local save dir."""
    tutils.reset_seed()

    wandb_dir = str(tmpdir)
    logger = WandbLogger(save_dir=wandb_dir, anonymous=True)
    assert logger is not None
def test_comet_pickle(tmpdir, monkeypatch):
    """Verify that pickling trainer with comet logger works."""

    # prevent comet logger from trying to print at exit, since
    # pytest's stdout/stderr redirection breaks it
    import atexit
    monkeypatch.setattr(atexit, "register", lambda _: None)

    tutils.reset_seed()

    # hparams = tutils.get_hparams()
    # model = LightningTestModel(hparams)

    comet_dir = os.path.join(tmpdir, "cometruns")

    # We test CometLogger in offline mode with local saves
    logger = CometLogger(
        save_dir=comet_dir,
        project_name="general",
        workspace="dummy-test",
    )

    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           logger=logger)

    trainer = Trainer(**trainer_options)
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({"acc": 1.0})
def test_wandb_pickle(wandb):
    """Verify that pickling trainer with wandb logger works.

    Wandb doesn't work well with pytest so we have to mock it out here."""
    tutils.reset_seed()

    class Experiment:
        id = 'the_id'

    wandb.init.return_value = Experiment()

    logger = WandbLogger(id='the_id', offline=True)

    trainer_options = dict(max_epochs=1, logger=logger)

    trainer = Trainer(**trainer_options)
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)

    assert os.environ['WANDB_MODE'] == 'dryrun'
    assert trainer2.logger.__class__.__name__ == WandbLogger.__name__

    _ = trainer2.logger.experiment
    wandb.init.assert_called()
    assert 'id' in wandb.init.call_args[1]
    assert wandb.init.call_args[1]['id'] == 'the_id'

    del os.environ['WANDB_MODE']
def test_no_val_end_module(tmpdir):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    tutils.reset_seed()

    class CurrentTestModel(LightningValidationStepMixin, LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_test_tube_logger(tmpdir, False)

    trainer_options = dict(max_epochs=1,
                           logger=logger,
                           checkpoint_callback=ModelCheckpoint(tmpdir))

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
    model_2 = LightningTestModel.load_from_metrics(
        weights_path=new_weights_path,
        tags_csv=tags_path)
    model_2.eval()
def test_multiple_test_dataloader(tmpdir):
    """Verify multiple test_dataloader."""
    tutils.reset_seed()

    class CurrentTestModel(LightningTestMultipleDataloadersMixin,
                           LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    trainer.test()

    # verify there are 2 test loaders
    assert len(trainer.test_dataloaders) == 2, \
        'Multiple test_dataloaders not initiated properly'

    # make sure predictions are good for each test set
    for dataloader in trainer.test_dataloaders:
        tutils.run_prediction(dataloader, trainer.model)

    # run the test method
    trainer.test()
def test_model_freeze_unfreeze():
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    model.freeze()
    model.unfreeze()
def test_inf_train_dataloader(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)"""
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModel):

        def train_dataloader(self):
            dataloader = self._dataloader(train=True)

            class CustomInfDataLoader:
                def __init__(self, dataloader):
                    self.dataloader = dataloader
                    self.iter = iter(dataloader)
                    self.count = 0

                def __iter__(self):
                    self.count = 0
                    return self

                def __next__(self):
                    if self.count >= 5:
                        raise StopIteration
                    self.count = self.count + 1
                    try:
                        return next(self.iter)
                    except StopIteration:
                        self.iter = iter(self.dataloader)
                        return next(self.iter)

            return CustomInfDataLoader(dataloader)

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # fit model
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            val_check_interval=0.5
        )
        trainer.fit(model)

    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_check_interval=50,
    )
    result = trainer.fit(model)

    # verify training completed
    assert result == 1
def test_model_saving_loading(tmpdir):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_test_tube_logger(tmpdir, False)

    trainer_options = dict(max_epochs=1,
                           logger=logger,
                           checkpoint_callback=ModelCheckpoint(tmpdir))

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        for batch in dataloader:
            break

    x, y = batch
    x = x.view(x.size(0), -1)

    # generate preds before saving model
    model.eval()
    pred_before_saving = model(x)

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
    model_2 = LightningTestModel.load_from_metrics(
        weights_path=new_weights_path,
        tags_csv=tags_path)
    model_2.eval()

    # make prediction
    # assert that both predictions are the same
    new_pred = model_2(x)
    assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
def _init_steps_model():
    """private method for initializing a model and trainer options that use 5% of the train data per epoch"""
    tutils.reset_seed()
    model, _ = tutils.get_model()

    # define train epoch to 5% of data
    train_percent = 0.05
    # get number of samples in 1 epoch
    num_train_samples = math.floor(len(model.train_dataloader()) * train_percent)

    trainer_options = dict(
        train_percent_check=train_percent,
    )
    return model, trainer_options, num_train_samples
def test_single_gpu_batch_parse():
    tutils.reset_seed()

    if not tutils.can_run_gpu_test():
        return

    trainer = Trainer()

    # batch is just a tensor
    batch = torch.rand(2, 3)
    batch = trainer.transfer_batch_to_gpu(batch, 0)
    assert batch.device.index == 0 and batch.type() == 'torch.cuda.FloatTensor'

    # tensor list
    batch = [torch.rand(2, 3), torch.rand(2, 3)]
    batch = trainer.transfer_batch_to_gpu(batch, 0)
    assert batch[0].device.index == 0 and batch[0].type() == 'torch.cuda.FloatTensor'
    assert batch[1].device.index == 0 and batch[1].type() == 'torch.cuda.FloatTensor'

    # tensor list of lists
    batch = [[torch.rand(2, 3), torch.rand(2, 3)]]
    batch = trainer.transfer_batch_to_gpu(batch, 0)
    assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor'
    assert batch[0][1].device.index == 0 and batch[0][1].type() == 'torch.cuda.FloatTensor'

    # tensor dict
    batch = [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)}]
    batch = trainer.transfer_batch_to_gpu(batch, 0)
    assert batch[0]['a'].device.index == 0 and batch[0]['a'].type() == 'torch.cuda.FloatTensor'
    assert batch[0]['b'].device.index == 0 and batch[0]['b'].type() == 'torch.cuda.FloatTensor'

    # tuple of tensor list and list of tensor dict
    batch = ([torch.rand(2, 3) for _ in range(2)],
             [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)} for _ in range(2)])
    batch = trainer.transfer_batch_to_gpu(batch, 0)
    assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor'

    assert batch[1][0]['a'].device.index == 0
    assert batch[1][0]['a'].type() == 'torch.cuda.FloatTensor'

    assert batch[1][0]['b'].device.index == 0
    assert batch[1][0]['b'].type() == 'torch.cuda.FloatTensor'
def test_neptune_pickle(tmpdir):
    """Verify that pickling trainer with neptune logger works."""
    tutils.reset_seed()

    logger = NeptuneLogger(offline_mode=True)

    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           logger=logger)

    trainer = Trainer(**trainer_options)
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({'acc': 1.0})
def test_cpu_model(tmpdir):
    """Make sure model trains on CPU."""
    tutils.reset_seed()

    trainer_options = dict(default_save_path=tmpdir,
                           show_progress_bar=False,
                           logger=tutils.get_test_tube_logger(tmpdir),
                           max_epochs=1,
                           train_percent_check=0.4,
                           val_percent_check=0.4)

    model, hparams = tutils.get_model()

    tutils.run_model_test(trainer_options, model, on_gpu=False)
def test_load_model_from_checkpoint(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        show_progress_bar=False,
        max_epochs=2,
        train_percent_check=0.4,
        val_percent_check=0.2,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        logger=False,
        default_save_path=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    trainer.test()

    # correct result and ok accuracy
    assert result == 1, 'training failed to complete'

    # load last checkpoint
    last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
    if not os.path.isfile(last_checkpoint):
        last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
    pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)

    # test that hparams loaded correctly
    for k, v in vars(hparams).items():
        assert getattr(pretrained_model.hparams, k) == v

    # assert weights are the same
    for (old_name, old_p), (new_name, new_p) in zip(model.named_parameters(),
                                                    pretrained_model.named_parameters()):
        assert torch.all(torch.eq(old_p, new_p)), 'loaded weights are not the same as the saved weights'

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer)
def test_mlflow_pickle(tmpdir):
    """Verify that pickling trainer with mlflow logger works."""
    tutils.reset_seed()

    mlflow_dir = os.path.join(tmpdir, 'mlruns')
    logger = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')

    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           logger=logger)

    trainer = Trainer(**trainer_options)
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({'acc': 1.0})
def test_lbfgs_cpu_model(tmpdir):
    """Verify LBFGS optimizer works with a CPU model."""
    tutils.reset_seed()

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=2,
        show_progress_bar=False,
        weights_summary='top',
        train_percent_check=1.0,
        val_percent_check=0.2,
    )

    model, hparams = tutils.get_model(use_test_model=True, lbfgs=True)
    tutils.run_model_test_no_loggers(trainer_options, model, min_acc=0.30)
def test_running_test_pretrained_model_ddp(tmpdir):
    """Verify `test()` on pretrained model."""
    if not tutils.can_run_gpu_test():
        return

    tutils.reset_seed()
    tutils.set_random_master_port()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # exp file to get meta
    logger = tutils.get_test_tube_logger(tmpdir, False)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(show_progress_bar=False,
                           max_epochs=1,
                           train_percent_check=0.4,
                           val_percent_check=0.2,
                           checkpoint_callback=checkpoint,
                           logger=logger,
                           gpus=[0, 1],
                           distributed_backend='ddp')

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # correct result and ok accuracy
    assert result == 1, 'training failed to complete'
    pretrained_model = tutils.load_model(logger,
                                         trainer.checkpoint_callback.filepath,
                                         module_class=LightningTestModel)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)

    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tutils.run_prediction(dataloader, pretrained_model)
def test_dp_output_reduce():
    mixin = TrainerLoggingMixin()
    tutils.reset_seed()

    # test identity when we have a single gpu
    out = torch.rand(3, 1)
    assert mixin.reduce_distributed_output(out, num_gpus=1) is out

    # average when we have multiples
    assert mixin.reduce_distributed_output(out, num_gpus=2) == out.mean()

    # when we have a dict of vals
    out = {'a': out, 'b': {'c': out}}
    reduced = mixin.reduce_distributed_output(out, num_gpus=3)
    assert reduced['a'] == out['a']
    assert reduced['b']['c'] == out['b']['c']