def test_running_test_without_val(tmpdir):
    """Verify `test()` works on a model with no `val_loader`."""
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_testtube_logger(tmpdir, False)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        test_percent_check=0.2,
        checkpoint_callback=checkpoint,
        logger=logger,
        early_stop_callback=False
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result == 1, 'training failed to complete'

    trainer.test()

    # test we have good test accuracy
    tutils.assert_ok_model_acc(trainer)


def test_optimizer_with_scheduling(tmpdir):
    """Verify that learning rate scheduling is working."""
    tutils.reset_seed()

    class CurrentTestModel(
            LightTestOptimizerWithSchedulingMixin,
            LightTrainDataloader,
            TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model)

    init_lr = hparams.learning_rate
    adjusted_lr = [pg['lr'] for pg in trainer.optimizers[0].param_groups]

    assert len(trainer.lr_schedulers) == 1, \
        'lr scheduler not initialized properly, it has %i elements instead of 1' % len(trainer.lr_schedulers)

    assert all(a == adjusted_lr[0] for a in adjusted_lr), \
        'Lr not equally adjusted for all param groups'
    adjusted_lr = adjusted_lr[0]

    assert init_lr * 0.1 == adjusted_lr, \
        'Lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr)


def test_mixing_of_dataloader_options(tmpdir):
    """Verify that dataloaders can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        LightTestFitSingleTestDataloadersMixin,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloaders=model._dataloader(train=False))
    results = trainer.fit(model, **fit_options)

    # fit model
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloaders=model._dataloader(train=False),
                       test_dataloaders=model._dataloader(train=False))
    _ = trainer.fit(model, **fit_options)
    trainer.test()

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_all_dataloaders_passed_to_fit(tmpdir):
    """Verify train, val & test dataloaders can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        LightTestFitSingleTestDataloadersMixin,
        LightEmptyTestStep,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(default_root_dir=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # train, val and test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=model._dataloader(train=False))
    test_options = dict(test_dataloaders=model._dataloader(train=False))

    result = trainer.fit(model, **fit_options)
    trainer.test(**test_options)

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_no_val_module(tmpdir):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    model = CurrentTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    trainer = Trainer(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )
    # fit model
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # assert ckpt has hparams
    ckpt = torch.load(new_weights_path)
    assert 'hparams' in ckpt.keys(), 'hparams missing from checkpoints'

    # won't load without hparams in the ckpt
    model_2 = LightningTestModel.load_from_checkpoint(
        checkpoint_path=new_weights_path,
    )
    model_2.eval()


def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Make sure DDP + AMP work."""
    tutils.reset_seed()

    # simulate setting slurm flags
    tutils.set_random_master_port()
    os.environ['SLURM_LOCALID'] = str(0)

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # fit model
    trainer = Trainer(
        max_epochs=1,
        gpus=[0],
        distributed_backend='ddp',
        precision=16,
        checkpoint_callback=checkpoint,
        logger=logger,
    )
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'amp + ddp model failed to complete'

    # test root model address
    assert trainer.resolve_root_node_address('abc') == 'abc'
    assert trainer.resolve_root_node_address('abc[23]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'


def test_running_test_pretrained_model(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_testtube_logger(tmpdir, False)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        show_progress_bar=False,
        max_epochs=4,
        train_percent_check=0.4,
        val_percent_check=0.2,
        checkpoint_callback=checkpoint,
        logger=logger
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'training failed to complete'
    pretrained_model = tutils.load_model(
        logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
    )

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer)


def test_warning_with_few_workers(tmpdir):
    """Test that a warning is raised if a dataloader with only a few workers is used."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        LightTestFitSingleTestDataloadersMixin,
        LightEmptyTestStep,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(default_root_dir=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=model._dataloader(train=False))
    test_options = dict(test_dataloaders=model._dataloader(train=False))

    trainer = Trainer(**trainer_options)

    # fit model
    with pytest.warns(UserWarning, match='train'):
        trainer.fit(model, **fit_options)

    with pytest.warns(UserWarning, match='val'):
        trainer.fit(model, **fit_options)

    with pytest.warns(UserWarning, match='test'):
        trainer.test(**test_options)


def test_multiple_test_dataloader(tmpdir):
    """Verify multiple test_dataloaders."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightTestMultipleDataloadersMixin,
        LightEmptyTestStep,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(default_root_dir=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)
    trainer.test()

    # verify there are 2 test loaders
    assert len(trainer.test_dataloaders) == 2, \
        'Multiple test_dataloaders not initiated properly'

    # make sure predictions are good for each test set
    for dataloader in trainer.test_dataloaders:
        tutils.run_prediction(dataloader, trainer.model)

    # run the test method
    trainer.test()


def test_loading_meta_tags(tmpdir):
    """Test backward compatibility with the legacy meta_tags.csv format."""
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()

    # save tags
    logger = tutils.get_default_logger(tmpdir)
    logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0))
    logger.log_hyperparams(hparams)
    logger.save()

    # load hparams
    path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE)
    hparams = load_hparams_from_yaml(hparams_path)

    # save as legacy meta_tags.csv
    tags_path = os.path.join(path_expt_dir, 'meta_tags.csv')
    save_hparams_to_tags_csv(tags_path, hparams)

    tags = load_hparams_from_tags_csv(tags_path)

    assert hparams == tags


def test_no_val_end_module(tmpdir):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, LightValidationStepMixin, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
    model_2 = LightningTestModel.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        tags_csv=tags_path
    )
    model_2.eval()


def test_lr_logger_multi_lrs(tmpdir):
    """Test that learning rates are extracted and logged for multiple lr schedulers."""
    tutils.reset_seed()

    class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin,
                           LightTrainDataloader,
                           TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    lr_logger = LearningRateLogger()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      val_percent_check=0.1,
                      train_percent_check=0.5,
                      callbacks=[lr_logger])
    results = trainer.fit(model)

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
        'Names of learning rates not set correctly'


def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      train_percent_check=0.05,
                      logger=logger)
    result = trainer.fit(model)

    print('result finished')
    logger.finalize()
    assert result == 1, "Training failed"


def test_train_dataloaders_passed_to_fit(tmpdir):
    """Verify that train dataloader can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # only train passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True))
    result = trainer.fit(model, **fit_options)

    assert result == 1


def test_multiple_val_dataloader(tmpdir):
    """Verify multiple val_dataloader."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValidationMultipleDataloadersMixin,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=1.0,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # verify training completed
    assert result == 1

    # verify there are 2 val loaders
    assert len(trainer.val_dataloaders) == 2, \
        'Multiple val_dataloaders not initiated properly'

    # make sure predictions are good for each val set
    for dataloader in trainer.val_dataloaders:
        tutils.run_prediction(dataloader, trainer.model)


def test_neptune_leave_open_experiment_after_fit(tmpdir):
    """Verify that the Neptune experiment is closed after fit by default and stays open with `close_after_fit=False`."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    def _run_training(logger):
        logger._experiment = MagicMock()

        trainer_options = dict(default_root_dir=tmpdir,
                               max_epochs=1,
                               train_percent_check=0.05,
                               logger=logger)
        trainer = Trainer(**trainer_options)
        trainer.fit(model)
        return logger

    logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
    assert logger_close_after_fit._experiment.stop.call_count == 1

    logger_open_after_fit = _run_training(
        NeptuneLogger(offline_mode=True, close_after_fit=False))
    assert logger_open_after_fit._experiment.stop.call_count == 0


def test_trainer_reset_correctly(tmpdir):
    """Check that all trainer parameters are reset correctly after lr_find()."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1
    )

    changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
                          'progress_bar_refresh_rate', 'accumulate_grad_batches',
                          'checkpoint_callback']

    attributes_before = {}
    for ca in changed_attributes:
        attributes_before[ca] = getattr(trainer, ca)

    _ = trainer.lr_find(model, num_training=5)

    attributes_after = {}
    for ca in changed_attributes:
        attributes_after[ca] = getattr(trainer, ca)

    for key in changed_attributes:
        assert attributes_before[key] == attributes_after[key], \
            f'Attribute {key} was not reset correctly after learning rate finder'


def test_trainer_arg_str(tmpdir):
    """Test that `auto_lr_find` accepts the name of a custom hparams field to tune."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    hparams.__dict__['my_fancy_lr'] = 1.0  # update with non-standard field
    model = CurrentTestModel(hparams)
    before_lr = hparams.my_fancy_lr

    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        auto_lr_find='my_fancy_lr'
    )

    trainer.fit(model)
    after_lr = model.hparams.my_fancy_lr
    assert before_lr != after_lr, \
        'Learning rate was not altered after running learning rate finder'


def test_wandb_logger(wandb):
    """Verify that basic functionality of wandb logger works.

    wandb doesn't work well with pytest so we have to mock it out here.
    """
    tutils.reset_seed()

    logger = WandbLogger(anonymous=True, offline=True)

    logger.log_metrics({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0})

    wandb.init().log.reset_mock()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})

    logger.log_hyperparams({'test': None})
    wandb.init().config.update.assert_called_once_with({'test': None})

    logger.watch('model', 'log', 10)
    wandb.watch.assert_called_once_with('model', log='log', log_freq=10)

    logger.finalize('fail')
    wandb.join.assert_called_once_with(1)
    wandb.join.reset_mock()

    logger.finalize('success')
    wandb.join.assert_called_once_with(0)
    wandb.join.reset_mock()

    wandb.join.side_effect = TypeError
    with pytest.raises(TypeError):
        logger.finalize('any')
    wandb.join.assert_called()

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id


def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """Verify that multiple val & test dataloaders can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightningTestModel,
        LightValStepFitMultipleDataloadersMixin,
        LightTestFitMultipleTestDataloadersMixin,
    ):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # train, multiple val and multiple test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=[model._dataloader(train=False),
                                        model._dataloader(train=False)],
                       test_dataloaders=[model._dataloader(train=False),
                                         model._dataloader(train=False)])
    results = trainer.fit(model, **fit_options)
    trainer.test()

    assert len(trainer.val_dataloaders) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_tbptt_cpu_model(tmpdir):
    """Test truncated back propagation through time works."""
    tutils.reset_seed()

    truncated_bptt_steps = 2
    sequence_size = 30
    batch_size = 30

    x_seq = torch.rand(batch_size, sequence_size, 1)
    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()

    class MockSeq2SeqDataset(torch.utils.data.Dataset):
        def __getitem__(self, i):
            return x_seq, y_seq_list

        def __len__(self):
            return 1

    class BpttTestModel(LightTrainDataloader, TestModelBase):
        def __init__(self, hparams):
            super().__init__(hparams)
            self.test_hidden = None

        def training_step(self, batch, batch_idx, hiddens):
            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
            self.test_hidden = torch.rand(1)

            x_tensor, y_list = batch
            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"

            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"

            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
            loss_val = torch.nn.functional.mse_loss(
                pred, y_tensor.view(batch_size, truncated_bptt_steps))
            return {
                'loss': loss_val,
                'hiddens': self.test_hidden,
            }

        def train_dataloader(self):
            return torch.utils.data.DataLoader(
                dataset=MockSeq2SeqDataset(),
                batch_size=batch_size,
                shuffle=False,
                sampler=None,
            )

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        truncated_bptt_steps=truncated_bptt_steps,
        val_percent_check=0,
        weights_summary=None,
        early_stop_callback=False
    )

    hparams = tutils.get_default_hparams()
    hparams.batch_size = batch_size
    hparams.in_features = truncated_bptt_steps
    hparams.hidden_dim = truncated_bptt_steps
    hparams.out_features = truncated_bptt_steps

    model = BpttTestModel(hparams)

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, 'training failed to complete'


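# A minimal sketch (not part of the original suite) of the tensor splitting that the
# assertions in `test_tbptt_cpu_model` rely on: a batch of shape
# (batch, seq_len, features) is chunked along the time dimension into pieces whose
# second dimension equals `truncated_bptt_steps`. Pure PyTorch, using the module-level
# `torch` import; `_tbptt_split_sketch` is an illustrative helper only.
def _tbptt_split_sketch():
    batch_size, sequence_size, truncated_bptt_steps = 30, 30, 2
    x_seq = torch.rand(batch_size, sequence_size, 1)

    # split along the time axis; each chunk keeps the full batch dimension
    splits = torch.split(x_seq, truncated_bptt_steps, dim=1)

    assert len(splits) == sequence_size // truncated_bptt_steps
    assert all(chunk.shape == (batch_size, truncated_bptt_steps, 1) for chunk in splits)

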
def test_dataloader_config_errors(tmpdir):
    """Test that invalid percent-check and val_check_interval settings raise ValueError."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # percent check < 0
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=-0.1,
    )

    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)

    # percent check > 1
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=1.1,
    )

    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)

    # int val_check_interval > num batches
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_check_interval=10000)

    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)

    # float val_check_interval > 1
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_check_interval=1.1)

    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)


def test_cpu_slurm_save_load(tmpdir):
    """Verify model save/load/checkpoint on CPU."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_testtube_logger(tmpdir, False)
    version = logger.version

    trainer_options = dict(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'training failed to complete'

    # predict with trained model before saving
    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        for batch in dataloader:
            break

    x, y = batch
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)

    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.hpc_save(tmpdir, logger)
    assert os.path.exists(saved_filepath)

    # new logger file to get meta
    logger = tutils.get_default_testtube_logger(tmpdir, False, version=version)

    trainer_options = dict(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)


def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        max_epochs=1,
        gpus=2,
        distributed_backend='dp',
    )

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['checkpoint_callback'] = checkpoint

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert result == 1, 'dp model failed to complete'

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
    trainer_options['train_percent_check'] = 0.5
    trainer_options['val_percent_check'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        dp_model = new_trainer.model
        dp_model.eval()

        dataloader = trainer.train_dataloader
        tutils.run_prediction(dataloader, dp_model, dp=True)

    # new model
    model = LightningTestModel(hparams)
    model.on_train_start = assert_good_acc

    # fit new model which should load hpc weights
    new_trainer.fit(model)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()


def test_optimizer_return_options():
    """Test the different return formats supported by `configure_optimizers`."""
    tutils.reset_seed()

    trainer = Trainer()
    model, hparams = tutils.get_default_model()

    # create optimizers and lr schedulers
    opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
    opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
    scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10)
    scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10)

    # single optimizer
    model.configure_optimizers = lambda: opt_a
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0

    # opt tuple
    model.configure_optimizers = lambda: (opt_a, opt_b)
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b
    assert len(lr_sched) == 0 and len(freq) == 0

    # opt list
    model.configure_optimizers = lambda: [opt_a, opt_b]
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b
    assert len(lr_sched) == 0 and len(freq) == 0

    # opt tuple of 2 lists
    model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
                               frequency=1, reduce_on_plateau=False, monitor='val_loss')

    # opt single dictionary
    model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a}
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
                               frequency=1, reduce_on_plateau=False, monitor='val_loss')

    # opt multiple dictionaries with frequencies
    model.configure_optimizers = lambda: (
        {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
        {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
    )
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
                               frequency=1, reduce_on_plateau=False, monitor='val_loss')
    assert freq == [1, 5]


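# A minimal sketch (not part of the original suite) of the single-dictionary return
# format exercised by `test_optimizer_return_options` above: an "optimizer" entry plus an
# optional "lr_scheduler" entry. `_configure_optimizers_dict_sketch` is an illustrative
# helper, not a Lightning API; only `torch.optim` is used.
def _configure_optimizers_dict_sketch(parameters):
    # e.g. _configure_optimizers_dict_sketch(torch.nn.Linear(4, 2).parameters())
    optimizer = torch.optim.Adam(parameters, lr=0.002)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    # the same dictionary shape the test above asserts is accepted by init_optimizers
    return {"optimizer": optimizer, "lr_scheduler": scheduler}

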
def test_resume_from_checkpoint_epoch_restored(tmpdir):
    """Verify resuming from checkpoint runs the right number of epochs."""
    import types

    tutils.reset_seed()

    hparams = tutils.get_default_hparams()

    def _new_model():
        # Create a model that tracks epochs and batches seen
        model = LightningTestModel(hparams)
        model.num_epochs_seen = 0
        model.num_batches_seen = 0

        def increment_epoch(self):
            self.num_epochs_seen += 1

        def increment_batch(self, _):
            self.num_batches_seen += 1

        # Bind the increment_epoch function to on_epoch_end so that the
        # model keeps track of the number of epochs it has seen.
        model.on_epoch_end = types.MethodType(increment_epoch, model)
        model.on_batch_start = types.MethodType(increment_batch, model)
        return model

    model = _new_model()

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.65,
        val_percent_check=1,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        logger=False,
        default_save_path=tmpdir,
        early_stop_callback=False,
        val_check_interval=1.,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    training_batches = trainer.num_training_batches

    assert model.num_epochs_seen == 2
    assert model.num_batches_seen == training_batches * 2

    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
    checkpoints = sorted(
        glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))

    for check in checkpoints:
        next_model = _new_model()
        state = torch.load(check)

        # Resume training
        trainer_options['max_epochs'] = 2
        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
        new_trainer.fit(next_model)
        assert state['global_step'] + next_model.num_batches_seen == \
            training_batches * trainer_options['max_epochs']


def test_gradient_accumulation_scheduling(tmpdir):
    """Test grad accumulation by the freq of optimizer updates."""
    tutils.reset_seed()

    # test incorrect configs
    with pytest.raises(IndexError):
        assert Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6})
        assert Trainer(accumulate_grad_batches={-2: 3})

    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={})
        assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
        assert Trainer(accumulate_grad_batches={1: 2, 3.: 4})
        assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

    # test optimizer call freq matches scheduler
    def _optimizer_step(self, epoch, batch_idx, optimizer,
                        optimizer_idx, second_order_closure=None):
        # only test the first 12 batches in epoch
        if batch_idx < 12:
            if epoch == 0:
                # reset counter when starting epoch
                if batch_idx == 0:
                    self.prev_called_batch_idx = 0

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 1

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 1

            elif 1 <= epoch <= 2:
                # reset counter when starting epoch
                if batch_idx == 1:
                    self.prev_called_batch_idx = 1

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 2

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 2

            else:
                if batch_idx == 3:
                    self.prev_called_batch_idx = 3

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 4

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 3

        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)
    schedule = {1: 2, 3: 4}

    trainer = Trainer(accumulate_grad_batches=schedule,
                      train_percent_check=0.1,
                      val_percent_check=0.1,
                      max_epochs=2,
                      default_save_path=tmpdir)

    # for the test
    trainer.optimizer_step = _optimizer_step
    model.prev_called_batch_idx = 0

    trainer.fit(model)


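# A minimal sketch (not part of the original suite) of how a dict-based
# `accumulate_grad_batches` schedule maps an epoch to its accumulation factor, matching
# the expectations asserted in `test_gradient_accumulation_scheduling` above
# (epoch 0 -> 1, epochs 1-2 -> 2, epochs >= 3 -> 4 for the schedule {1: 2, 3: 4}).
# `_resolve_accumulation_factor` is an illustrative helper, not a Lightning API.
def _resolve_accumulation_factor(schedule, epoch):
    factor = 1  # default: no accumulation until the first scheduled epoch
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            factor = schedule[start_epoch]
    return factor
    # e.g. _resolve_accumulation_factor({1: 2, 3: 4}, 0) == 1
    #      _resolve_accumulation_factor({1: 2, 3: 4}, 2) == 2
    #      _resolve_accumulation_factor({1: 2, 3: 4}, 5) == 4

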
def test_numpy_metric_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()
    world_size = 2
    mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)