def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv("TORCH_HOME", tmpdir)

    model = BoringModel()
    # extra layer
    model.c_d3 = torch.nn.Linear(32, 32)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)

    # training complete
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), "hparams.yaml")
    # the server serves the saved checkpoint, so name the URL accordingly
    ckpt_url = f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
    ckpt_path = ckpt_url if url_ckpt else new_weights_path

    BoringModel.load_from_checkpoint(checkpoint_path=ckpt_path, hparams_file=hparams_path, strict=False)

    with pytest.raises(RuntimeError, match=r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        BoringModel.load_from_checkpoint(checkpoint_path=ckpt_path, hparams_file=hparams_path, strict=True)
def _test_loggers_save_dir_and_weights_save_path(tmpdir, logger_class):
    class TestLogger(logger_class):
        # for this test it does not matter what these attributes are
        # so we standardize them to make testing easier
        @property
        def version(self):
            return "version"

        @property
        def name(self):
            return "name"

    model = BoringModel()
    trainer_args = dict(default_root_dir=tmpdir, max_steps=3)

    # no weights_save_path given
    save_dir = tmpdir / "logs"
    weights_save_path = None
    logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
    trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
    trainer.fit(model)
    assert trainer._weights_save_path_internal == trainer.default_root_dir
    assert trainer.checkpoint_callback.dirpath == os.path.join(str(logger.save_dir), "name", "version", "checkpoints")
    assert trainer.default_root_dir == tmpdir

    # with weights_save_path given, the logger path and checkpoint path should be different
    save_dir = tmpdir / "logs"
    weights_save_path = tmpdir / "weights"
    logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
    with pytest.deprecated_call(match=r"Setting `Trainer\(weights_save_path=\)` has been deprecated in v1.6"):
        trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
    trainer.fit(model)
    assert trainer._weights_save_path_internal == weights_save_path
    assert trainer.logger.save_dir == save_dir
    assert trainer.checkpoint_callback.dirpath == weights_save_path / "name" / "version" / "checkpoints"
    assert trainer.default_root_dir == tmpdir

    # no logger given
    weights_save_path = tmpdir / "weights"
    with pytest.deprecated_call(match=r"Setting `Trainer\(weights_save_path=\)` has been deprecated in v1.6"):
        trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path)
    trainer.fit(model)
    assert trainer._weights_save_path_internal == weights_save_path
    assert trainer.checkpoint_callback.dirpath == weights_save_path / "checkpoints"
    assert trainer.default_root_dir == tmpdir
def test_correct_step_and_epoch(tmpdir):
    model = BoringModel()
    first_max_epochs = 2
    train_batches = 2
    trainer = Trainer(
        default_root_dir=tmpdir, max_epochs=first_max_epochs, limit_train_batches=train_batches, limit_val_batches=0
    )
    assert trainer.current_epoch == 0
    assert trainer.global_step == 0

    trainer.fit(model)
    # TODO(@carmocca): should not need `-1`
    assert trainer.current_epoch == first_max_epochs - 1
    assert trainer.global_step == first_max_epochs * train_batches

    # save checkpoint after loop ends, training end called, epoch count increased
    ckpt_path = str(tmpdir / "model.ckpt")
    trainer.save_checkpoint(ckpt_path)

    ckpt = torch.load(ckpt_path)
    assert ckpt["epoch"] == first_max_epochs
    # TODO(@carmocca): should not need `+1`
    assert ckpt["global_step"] == first_max_epochs * train_batches + 1

    max_epochs = first_max_epochs + 2
    trainer = Trainer(
        default_root_dir=tmpdir, max_epochs=max_epochs, limit_train_batches=train_batches, limit_val_batches=0
    )
    # the ckpt state is not loaded at this point
    assert trainer.current_epoch == 0
    assert trainer.global_step == 0

    class TestModel(BoringModel):
        def on_pretrain_routine_end(self) -> None:
            assert self.trainer.current_epoch == first_max_epochs
            # TODO(@carmocca): should not need `+1`
            assert self.trainer.global_step == first_max_epochs * train_batches + 1

    trainer.fit(TestModel(), ckpt_path=ckpt_path)
    # TODO(@carmocca): should not need `-1`
    assert trainer.current_epoch == max_epochs - 1
    # TODO(@carmocca): should not need `+1`
    assert trainer.global_step == max_epochs * train_batches + 1
def test_v1_5_0_old_on_test_epoch_end(tmpdir):
    callback_warning_cache.clear()

    class OldSignature(Callback):
        def on_test_epoch_end(self, trainer, pl_module):  # noqa
            ...

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.test(model)

    class OldSignatureModel(BoringModel):
        def on_test_epoch_end(self):  # noqa
            ...

    model = OldSignatureModel()

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.test(model)

    callback_warning_cache.clear()

    class NewSignature(Callback):
        def on_test_epoch_end(self, trainer, pl_module, outputs):
            ...

    trainer.callbacks = [NewSignature()]
    with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."):
        trainer.test(model)

    class NewSignatureModel(BoringModel):
        def on_test_epoch_end(self, outputs):
            ...

    model = NewSignatureModel()
    with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."):
        trainer.test(model)
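# `no_deprecated_call` is assumed to be a helper defined elsewhere in this
# suite. A minimal sketch of what it might look like: the inverse of
# `pytest.deprecated_call`, failing if a matching DeprecationWarning is emitted.
import re
import warnings
from contextlib import contextmanager


@contextmanager
def no_deprecated_call(match=None):
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        # fail on any DeprecationWarning whose message matches the given pattern
        if issubclass(w.category, DeprecationWarning) and (match is None or re.search(match, str(w.message))):
            raise AssertionError(f"Unexpected deprecation warning: {w.message}")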
def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_steps):
    """Tests that max_steps can be used without max_epochs."""
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        min_epochs=min_epochs,
        max_epochs=max_epochs,
        min_steps=min_steps,
        max_steps=max_steps,
        enable_model_summary=False,
    )
    trainer.fit(model)

    # check training stopped at max_epochs or max_steps
    if trainer.max_steps and not trainer.max_epochs:
        assert trainer.global_step == trainer.max_steps
def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
    """Test for DDP spawn FP16 compress hook."""
    model = BoringModel()
    training_type_plugin = DDPSpawnPlugin(
        ddp_comm_hook=default.fp16_compress_hook,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[training_type_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
def test_profiler_teardown(tmpdir, cls):
    """This test checks that the profiler teardown method is called when the trainer exits."""

    class TestCallback(Callback):
        def on_fit_end(self, trainer, pl_module) -> None:
            assert trainer.profiler.output_file is not None

    profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt"))
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()])
    trainer.fit(model)

    assert profiler.output_file is None
def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
    """Test that all callback states get saved even if the ModelCheckpoint is not given as last."""
    callback0 = StatefulCallback0()
    callback1 = StatefulCallback1()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_val_batches=1,
        callbacks=[callback0, checkpoint_callback, callback1],
    )
    trainer.fit(model)

    ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
    state0 = ckpt["callbacks"]["StatefulCallback0"]
    state1 = ckpt["callbacks"]["StatefulCallback1"]
    assert "content0" in state0 and state0["content0"] == 0
    assert "content1" in state1 and state1["content1"] == 1
    assert "ModelCheckpoint" in ckpt["callbacks"]
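# `StatefulCallback0` and `StatefulCallback1` are assumed to be defined
# elsewhere in this test module. A minimal sketch of what they might look like,
# using the `on_save_checkpoint` persistence hook of this Lightning version:
class StatefulCallback0(Callback):
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return {"content0": 0}


class StatefulCallback1(Callback):
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return {"content1": 1}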
def test_ddp_fp16_compress_comm_hook(tmpdir):
    """Test for DDP FP16 compress hook."""
    model = BoringModel()
    training_type_plugin = DDPPlugin(ddp_comm_hook=default.fp16_compress_hook)
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        strategy=training_type_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
    expected_comm_hook = default.fp16_compress_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
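# `default` in the DDP comm-hook tests above is assumed to be torch's built-in
# DDP comm hook module; the import would presumably look like this.
# `fp16_compress_hook` compresses gradients to FP16 before the all-reduce and
# decompresses them afterwards, halving communication volume.
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default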
def test_model_properties_resume_from_checkpoint(tmpdir):
    """Test that properties like `current_epoch` and `global_step` in model and trainer are always the same."""
    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=False,
        callbacks=[checkpoint_callback, ModelTrainerPropertyParity()],  # this performs the assertions
    )
    trainer = Trainer(**trainer_args)
    trainer.fit(model)

    trainer_args.update(max_epochs=2)
    trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt"))
    trainer.fit(model)
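# `ModelTrainerPropertyParity` is assumed to be a helper callback defined
# elsewhere in this module. A minimal sketch of what it might look like: at
# several hook points it asserts that the model's proxy properties mirror the
# trainer's.
class ModelTrainerPropertyParity(Callback):
    def _check_properties(self, trainer, pl_module):
        assert trainer.global_step == pl_module.global_step
        assert trainer.current_epoch == pl_module.current_epoch

    def on_train_start(self, trainer, pl_module):
        self._check_properties(trainer, pl_module)

    def on_train_batch_start(self, trainer, pl_module, *args, **kwargs):
        self._check_properties(trainer, pl_module)

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self._check_properties(trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        self._check_properties(trainer, pl_module)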
def test_profiler_teardown(tmpdir, cls):
    """This test checks that the profiler teardown method is called when the trainer exits."""

    class TestCallback(Callback):
        def on_fit_end(self, trainer, *args, **kwargs) -> None:
            # describe sets it to None
            assert trainer.profiler._output_file is None

    profiler = cls(dirpath=tmpdir, filename="profiler")
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()])
    trainer.fit(model)

    assert profiler._output_file is None
def test_timer_zero_duration_stop(tmpdir, interval):
    """Test that the timer stops training immediately after the first check occurs."""
    model = BoringModel()
    duration = timedelta(0)
    timer = Timer(duration=duration, interval=interval)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[timer],
    )
    trainer.fit(model)
    if interval == "step":
        # timer triggers stop on step end
        assert trainer.global_step == 1
        assert trainer.current_epoch == 0
    else:
        # timer triggers stop on epoch end
        assert trainer.global_step == len(trainer.train_dataloader)
        assert trainer.current_epoch == 0
def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
    """Test that the test progress bar updates with the correct amount."""
    model = BoringModel()
    progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_test_batches=test_batches,
        callbacks=[progress_bar],
        logger=False,
        checkpoint_callback=False,
    )
    trainer.test(model)
    progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])
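# `MockedUpdateProgressBars` is assumed to be defined elsewhere in this module.
# A minimal sketch of what it might look like: a `ProgressBar` that wraps each
# bar's `update` method in a mock so the test can assert on the update deltas.
class MockedUpdateProgressBars(ProgressBar):
    def _mock_bar_update(self, bar):
        bar.update = Mock(wraps=bar.update)
        return bar

    def init_train_tqdm(self):
        return self._mock_bar_update(super().init_train_tqdm())

    def init_validation_tqdm(self):
        return self._mock_bar_update(super().init_validation_tqdm())

    def init_test_tqdm(self):
        return self._mock_bar_update(super().init_test_tqdm())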
def test_v1_7_0_old_on_train_batch_end(tmpdir):
    class OldSignature(Callback):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            ...

    class OldSignatureModel(BoringModel):
        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
            ...

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature(), fast_dev_run=True)
    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
        trainer.fit(model)

    model = OldSignatureModel()
    # no old-signature callback here, so the deprecation must come from the model hook
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=True)
    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
        trainer.fit(model)
def test_resume_callback_state_saved_by_type(tmpdir):
    """Test that a legacy checkpoint that didn't use a state key before can still be loaded."""
    model = BoringModel()
    callback = OldStatefulCallback(state=111)
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
    trainer.fit(model)
    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
    assert ckpt_path.exists()

    callback = OldStatefulCallback(state=222)
    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
    trainer.fit(model)
    assert callback.state == 111
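# `OldStatefulCallback` is assumed to be defined elsewhere in this module. A
# minimal sketch of what it might look like: a callback whose state is keyed by
# its type (the legacy behavior) rather than by a string state key.
class OldStatefulCallback(Callback):
    def __init__(self, state):
        self.state = state

    @property
    def state_key(self):
        return type(self)  # legacy: key by type instead of by string

    def on_save_checkpoint(self, *args):
        return {"state": self.state}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        self.state = callback_state["state"]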
def test_tpu_clip_grad_by_value(tmpdir):
    """Test if clip_gradients by value works on TPU. (It should not.)"""
    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=4,
        tpu_cores=1,
        limit_train_batches=10,
        limit_val_batches=10,
        gradient_clip_val=0.5,
        gradient_clip_algorithm='value',
    )

    model = BoringModel()
    with pytest.raises(AssertionError):
        tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
def test_model_16bit_tpu_index(tmpdir, tpu_core):
    """Make sure model trains on TPU."""
    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        precision=16,
        progress_bar_refresh_rate=0,
        max_epochs=2,
        tpu_cores=[tpu_core],
        limit_train_batches=4,
        limit_val_batches=2,
    )

    model = BoringModel()
    tpipes.run_model_test(trainer_options, model, on_gpu=False)
    assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'
    assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
def test_device_stats_monitor_no_logger(tmpdir):
    """Test DeviceStatsMonitor with no logger in Trainer."""
    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[device_stats],
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )

    with pytest.raises(MisconfigurationException, match="Trainer that has no logger."):
        trainer.fit(model)
def test_gpu_stats_monitor_no_logger(tmpdir):
    """Test GPUStatsMonitor with no logger in Trainer."""
    model = BoringModel()
    gpu_stats = GPUStatsMonitor()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[gpu_stats],
        max_epochs=1,
        gpus=1,
        logger=False,
    )

    with pytest.raises(MisconfigurationException, match='Trainer that has no logger.'):
        trainer.fit(model)
def test_simple_profiler_with_nonexisting_log_dir(tmpdir):
    """Ensure the profiler dirpath defaults to `trainer.log_dir` and creates it when not present."""
    nonexisting_tmpdir = tmpdir / "nonexisting"

    profiler = SimpleProfiler(filename="profiler")
    assert profiler.dirpath is None

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=nonexisting_tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        profiler=profiler,
    )
    trainer.fit(model)

    expected = nonexisting_tmpdir / "lightning_logs" / "version_0"
    assert expected.exists()
    assert trainer.log_dir == expected
    assert profiler.dirpath == trainer.log_dir
    assert expected.join("fit-profiler.txt").exists()
def test_pytorch_profiler_trainer_validate(tmpdir):
    """Ensure that the profiler can be given to the trainer and that `validate` calls are properly recorded."""
    pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=2,
        profiler=pytorch_profiler,
    )
    trainer.validate(model)

    assert sum(e.name == 'validation_step' for e in pytorch_profiler.function_events)

    path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt"
    assert path.read_text("utf-8")
def test_lr_monitor_single_lr(tmpdir):
    """Test that learning rates are extracted and logged for a single lr scheduler."""
    tutils.reset_seed()

    model = BoringModel()
    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_monitor],
    )
    trainer.fit(model)

    assert lr_monitor.lrs, "No learning rates logged"
    assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
    assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs)
    assert list(lr_monitor.lrs) == ["lr-SGD"]
def test_hpc_max_ckpt_version(tmpdir):
    """Test that the CheckpointConnector is able to find the hpc checkpoint file with the highest version."""
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
    )
    trainer.fit(model)
    trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")

    assert trainer.checkpoint_connector.hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
    assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33
    assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None
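# For reference, the version lookup asserted above could be implemented roughly
# like this (a sketch under assumptions, not the actual CheckpointConnector
# code): scan the folder for `hpc_ckpt_{n}.ckpt` files and return the largest
# `n`, or None when the folder does not exist or holds no versioned checkpoints.
import re
from pathlib import Path
from typing import Optional


def _max_ckpt_version_in_folder(folder) -> Optional[int]:
    folder = Path(folder)
    if not folder.is_dir():
        return None
    versions = [int(m.group(1)) for f in folder.iterdir() if (m := re.fullmatch(r"hpc_ckpt_(\d+)\.ckpt", f.name))]
    return max(versions, default=None)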
def test_amp_without_apex(tmpdir):
    """Check that the amp backend is void even with `amp_backend='apex'` when precision=16 is not requested."""
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        amp_backend='native',
    )
    assert trainer.amp_backend is None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        amp_backend='apex',
    )
    assert trainer.amp_backend is None
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert trainer.dev_debugger.count_events('AMP') == 0
def test_eval_distributed_sampler_warning(tmpdir):
    """Test that a warning is raised when `DistributedSampler` is used with evaluation."""
    model = BoringModel()
    trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
    trainer._data_connector.attach_data(model)

    trainer.state.fn = TrainerFn.VALIDATING
    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
        trainer.reset_val_dataloader(model)

    trainer.state.fn = TrainerFn.TESTING
    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
        trainer.reset_test_dataloader(model)
def test_multi_gpu_model_ddp_spawn(tmpdir):
    tutils.set_random_main_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        strategy="ddp_spawn",
        enable_progress_bar=False,
    )

    model = BoringModel()
    tpipes.run_model_test(trainer_options, model)

    # test memory helper functions
    memory.get_memory_profile("min_max")
def test_accumulated_gradient_batches_with_resume_from_checkpoint(tmpdir):
    """This test validates that the accumulated gradient schedule is properly recomputed and reset on the trainer."""
    ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    model = BoringModel()
    trainer_kwargs = dict(
        max_epochs=1, accumulate_grad_batches={0: 2}, callbacks=ckpt, limit_train_batches=1, limit_val_batches=0
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    trainer_kwargs['max_epochs'] = 2
    trainer_kwargs['resume_from_checkpoint'] = ckpt.last_model_path
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
def test_multi_gpu_model_dp(tmpdir):
    tutils.set_random_master_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        accelerator='dp',
        progress_bar_refresh_rate=0,
    )

    model = BoringModel()
    tpipes.run_model_test(trainer_options, model)

    # test memory helper functions
    memory.get_memory_profile('min_max')
def test_amp_without_apex(bwd_mock, tmpdir):
    """Check that the amp backend is void even with `amp_backend='apex'` when precision=16 is not requested."""
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        amp_backend='native',
    )
    assert trainer.amp_backend is None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        amp_backend='apex',
    )
    assert trainer.amp_backend is None
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert not bwd_mock.called
def test_prediction_writer_hook_call_intervals(tmpdir):
    """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined interval."""
    DummyPredictionWriter.write_on_batch_end = Mock()
    DummyPredictionWriter.write_on_epoch_end = Mock()

    dataloader = DataLoader(RandomDataset(32, 64))
    model = BoringModel()
    cb = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    results = trainer.predict(model, dataloaders=dataloader)
    assert len(results) == 4
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 1

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    cb = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 1

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    cb = DummyPredictionWriter("batch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 0

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    cb = DummyPredictionWriter("epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 0
    assert cb.write_on_epoch_end.call_count == 1
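# `DummyPredictionWriter` is assumed to be defined elsewhere in this module. A
# minimal sketch of what it might look like: a `BasePredictionWriter` whose
# write hooks are no-ops (the test above replaces them with mocks anyway).
class DummyPredictionWriter(BasePredictionWriter):
    def write_on_batch_end(self, *args, **kwargs):
        pass

    def write_on_epoch_end(self, *args, **kwargs):
        pass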