def main():
    """CLI entry point for a 2-GPU DDP run driven by command-line arguments.

    Runs fit/test/fit_test (selected via ``--trainer_method``) and persists a
    result dict to ``<tmpdir>/ddp.result`` so a parent process can inspect it.
    """
    seed_everything(4321)

    parser = ArgumentParser(add_help=False)
    # Expose all Trainer constructor flags on the CLI.
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--trainer_method", default="fit")
    parser.add_argument("--tmpdir")
    parser.add_argument("--workdir")
    # Default to a 2-GPU DDP configuration unless overridden on the CLI.
    parser.set_defaults(gpus=2)
    parser.set_defaults(strategy="ddp")
    args = parser.parse_args()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer = Trainer.from_argparse_args(args)

    # Dispatch on the requested trainer method; `result` carries test metrics
    # (None for a plain fit).
    if args.trainer_method == "fit":
        trainer.fit(model, datamodule=dm)
        result = None
    elif args.trainer_method == "test":
        result = trainer.test(model, datamodule=dm)
    elif args.trainer_method == "fit_test":
        trainer.fit(model, datamodule=dm)
        result = trainer.test(model, datamodule=dm)
    else:
        raise ValueError(f"Unsupported: {args.trainer_method}")

    # Persist the outcome so the spawning process can verify completion.
    result_ext = {
        "status": "complete",
        "method": args.trainer_method,
        "result": result
    }
    file_path = os.path.join(args.tmpdir, "ddp.result")
    torch.save(result_ext, file_path)
def test_callbacks_references_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint sets references as expected."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_steps=1,
        logger=False,
        limit_val_batches=2,
        num_sanity_val_steps=0,
    )

    # Initial training run.
    first_ckpt = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    trainer = Trainer(**trainer_kwargs, callbacks=[first_ckpt])
    assert first_ckpt is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm)

    # Resumed training: a freshly constructed checkpoint callback should take
    # precedence over the one stored inside the last.ckpt file.
    second_ckpt = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    trainer = Trainer(**trainer_kwargs, callbacks=[second_ckpt])
    assert first_ckpt is not second_ckpt
    assert second_ckpt is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))
def test_callbacks_state_resume_from_checkpoint(tmpdir):
    """Test that resuming from a checkpoint restores callbacks that persist state."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    capture_cb = CaptureCallbacksBeforeTraining()

    def make_trainer_kwargs():
        # A fresh ModelCheckpoint per trainer so no state leaks between runs.
        ckpt_cb = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
        assert ckpt_cb.best_model_path == ""
        assert ckpt_cb.best_model_score is None
        return dict(
            default_root_dir=tmpdir,
            max_steps=1,
            logger=False,
            callbacks=[ckpt_cb, capture_cb],
            limit_val_batches=2,
        )

    # Initial training run.
    trainer = Trainer(**make_trainer_kwargs())
    trainer.fit(model, datamodule=dm)
    callbacks_before_resume = deepcopy(trainer.callbacks)

    # Resumed training run.
    trainer = Trainer(**make_trainer_kwargs(), resume_from_checkpoint=str(tmpdir / "last.ckpt"))
    trainer.fit(model, datamodule=dm)

    assert len(callbacks_before_resume) == len(capture_cb.callbacks)
    for before, after in zip(callbacks_before_resume, capture_cb.callbacks):
        if isinstance(before, ModelCheckpoint):
            assert before.best_model_path == after.best_model_path
            assert before.best_model_score == after.best_model_score
def test_datamodule_parameter(tmpdir):
    """Test that the datamodule parameter works."""
    seed_everything(1)

    datamodule = ClassifDataModule()
    model = ClassificationModel()
    lr_before = model.lr

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
    lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)
    lr_after = lr_finder.suggestion()
    model.lr = lr_after

    assert lr_before != lr_after, "Learning rate was not altered after running learning rate finder"
def test_full_loop(tmpdir):
    """Run fit, validate and test back-to-back on the classification datamodule."""
    reset_seed()

    datamodule = ClassifDataModule()
    model = ClassificationModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        enable_model_summary=False,
        deterministic=True,
    )

    # fit model
    trainer.fit(model, datamodule)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert datamodule.trainer is not None

    # validate
    result = trainer.validate(model, datamodule)
    assert datamodule.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, datamodule)
    assert datamodule.trainer is not None
    assert result[0]["test_acc"] > 0.6
def test_full_loop(tmpdir):
    """Fit, validate and test the classification model end to end."""
    reset_seed()

    dm = ClassifDataModule()
    classifier = ClassificationModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, weights_summary=None, deterministic=True)

    # fit model
    fit_result = trainer.fit(classifier, dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert dm.trainer is not None
    assert fit_result

    # validate
    eval_result = trainer.validate(datamodule=dm)
    assert dm.trainer is not None
    assert eval_result[0]['val_acc'] > 0.7

    # test
    eval_result = trainer.test(datamodule=dm)
    assert dm.trainer is not None
    assert eval_result[0]['test_acc'] > 0.6
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_main_port()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)
    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp_spawn",
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    best_path = trainer.checkpoint_callback.best_model_path
    pretrained_model = ClassificationModel.load_from_checkpoint(best_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    # Normalize to a list so single- and multi-loader setups share one path.
    dataloaders = dm.test_dataloader()
    dataloaders = dataloaders if isinstance(dataloaders, list) else [dataloaders]
    for dataloader in dataloaders:
        tpipes.run_model_prediction(pretrained_model, dataloader, min_acc=0.1)
def test_train_val_loop_only(tmpdir):
    """Train without any validation hooks and check the resulting train loss."""
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    # Strip out every validation hook so only the training loop runs.
    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False)
    trainer.fit(model, datamodule=dm)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""
    model = ClassificationModel()
    dm = ClassifDataModule()

    # Remove validation entirely so the callback must monitor a train metric.
    model.validation_step = None
    model.val_dataloader = None

    early_stop = EarlyStopping(monitor='train_loss', min_delta=0.1, patience=0)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[early_stop], overfit_batches=0.20, max_epochs=10)
    trainer.fit(model, datamodule=dm)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    # Early stopping must have fired before the final epoch.
    assert trainer.current_epoch < trainer.max_epochs - 1
def test_callbacks_state_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint restores callbacks that persist state."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    callback_capture = CaptureCallbacksBeforeTraining()

    def get_trainer_args():
        # Build a fresh ModelCheckpoint per trainer so state is never shared
        # between the initial and resumed runs.
        checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
        trainer_args = dict(
            default_root_dir=tmpdir,
            limit_train_batches=1,
            limit_val_batches=2,
            max_epochs=1,
            logger=False,
            callbacks=[checkpoint, callback_capture],
        )
        # A brand-new checkpoint callback starts with empty best-model state.
        assert checkpoint.best_model_path == ""
        assert checkpoint.best_model_score is None
        return trainer_args

    # initial training
    trainer = Trainer(**get_trainer_args())
    with pytest.deprecated_call(
            match=
            "`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"
    ):
        trainer.fit(model, datamodule=dm)
    callbacks_before_resume = deepcopy(trainer.callbacks)

    # resumed training
    trainer = Trainer(**get_trainer_args())
    with pytest.deprecated_call(
            match=
            "`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"
    ):
        trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))

    assert len(callbacks_before_resume) == len(callback_capture.callbacks)
    for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
        if isinstance(before, ModelCheckpoint):
            # Every persisted ModelCheckpoint attribute must survive the resume.
            for attribute in (
                    "best_model_path",
                    "best_model_score",
                    "best_k_models",
                    "kth_best_model_path",
                    "kth_value",
                    "last_model_path",
            ):
                assert getattr(before, attribute) == getattr(after, attribute)
def test_train_val_loop_only(tmpdir):
    """Train with validation hooks removed and verify the train loss metric."""
    reset_seed()

    datamodule = ClassifDataModule()
    classifier = ClassificationModel()

    # Disable all validation hooks so fit only runs the training loop.
    classifier.validation_step = None
    classifier.validation_step_end = None
    classifier.validation_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
    )

    # fit model
    trainer.fit(classifier, datamodule=datamodule)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics['train_loss'] < 1.0
def test_fit_csv_logger(tmpdir):
    """CSVLogger should produce a metrics file on disk after fitting."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    csv_logger = CSVLogger(save_dir=tmpdir)
    trainer = Trainer(default_root_dir=tmpdir, max_steps=10, logger=csv_logger, log_every_n_steps=1)
    trainer.fit(model, datamodule=dm)
    expected_metrics_file = os.path.join(csv_logger.log_dir, ExperimentWriter.NAME_METRICS_FILE)
    assert os.path.isfile(expected_metrics_file)
def test_optimization(tmpdir):
    """Fit/validate/test on HPU, then reload from a saved checkpoint and compare."""
    seed_everything(42)

    dm = ClassifDataModule(length=1024)
    model = ClassificationModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="hpu", devices=1)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(datamodule=dm)
    assert dm.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, datamodule=dm)
    assert dm.trainer is not None
    test_result = result[0]["test_acc"]
    assert test_result > 0.6

    # test saved model: reloading the checkpoint must reproduce the accuracy
    model_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(model_path)
    reloaded = ClassificationModel.load_from_checkpoint(model_path)
    trainer = Trainer(default_root_dir=tmpdir, accelerator="hpu", devices=1)
    result = trainer.test(reloaded, datamodule=dm)
    saved_result = result[0]["test_acc"]
    assert saved_result == test_result
def test_suggestion_parameters_work(tmpdir):
    """Test that default skipping does not alter results in basic case."""
    datamodule = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=3)
    lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule)

    suggested_default = lr_finder.suggestion(skip_begin=10)  # default
    suggested_skewed = lr_finder.suggestion(skip_begin=150)  # way too high, should have an impact

    assert suggested_default != suggested_skewed, "Skipping parameter did not influence learning rate"
def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    best_path = trainer.checkpoint_callback.best_model_path
    pretrained_model = ClassificationModel.load_from_checkpoint(best_path)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer, key='test_acc', thr=0.45)
def test_multi_gpu_none_backend(tmpdir):
    """Make sure when using multiple GPUs the user can't use `distributed_backend = None`."""
    tutils.set_random_master_port()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer_options = dict(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        gpus=2,
    )
    tpipes.run_model_test(trainer_options, model, dm)
def test_multi_gpu_none_backend(tmpdir):
    """Make sure when using multiple GPUs the user can't use `accelerator = None`."""
    tutils.set_random_main_port()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        gpus=2,
    )
    tpipes.run_model_test(trainer_options, model, dm)
def test_multi_gpu_early_stop_ddp_spawn(tmpdir):
    """Early stopping with ddp_spawn on two GPUs should run to completion."""
    tutils.set_random_main_port()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[EarlyStopping(monitor="train_acc")],
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        strategy="ddp_spawn",
    )
    tpipes.run_model_test(trainer_options, model, dm)
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_master_port()
    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)
    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='dp',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    best_path = trainer.checkpoint_callback.best_model_path
    pretrained_model = ClassificationModel.load_from_checkpoint(best_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)
    pretrained_model.cpu()

    # Normalize to a list so single- and multi-loader setups share one path.
    dataloaders = model.test_dataloader()
    dataloaders = dataloaders if isinstance(dataloaders, list) else [dataloaders]
    for dataloader in dataloaders:
        tpipes.run_prediction_eval_model_template(pretrained_model, dataloader)
def test_resume_early_stopping_from_checkpoint(tmpdir):
    """Prevent regressions to bugs:

    https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
    https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
    """
    seed_everything(42)
    model = ClassificationModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1)
    # EarlyStoppingTestRestore records its own state each epoch; start with no
    # expected state on the first run.
    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback, checkpoint_callback],
        num_sanity_val_steps=0,
        max_epochs=4,
    )
    trainer.fit(model, datamodule=dm)

    # One saved callback state per completed epoch.
    assert len(early_stop_callback.saved_states) == 4

    checkpoint_filepath = checkpoint_callback.kth_best_model_path

    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1", hence the -1 when indexing saved states
    early_stop_callback_state = early_stop_callback.saved_states[
        checkpoint["epoch"] - 1]
    assert 4 == len(early_stop_callback.saved_states)
    # Callback states are keyed by class name plus its state-key kwargs.
    es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
    assert checkpoint["callbacks"][es_name] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state,
                                                   monitor="train_loss")
    new_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        resume_from_checkpoint=checkpoint_filepath,
        callbacks=[early_stop_callback],
    )

    # Resuming past max_epochs must raise rather than silently do nothing.
    with pytest.raises(MisconfigurationException,
                       match=r"You restored a checkpoint with current_epoch"):
        new_trainer.fit(model)
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    early_stop_callback = EarlyStopping(monitor='train_loss')

    expected_count = 4
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=expected_count,
    )
    trainer.fit(model, datamodule=dm)

    # The trainer exposes the callback and invokes it exactly once per epoch.
    assert trainer.early_stopping_callback == early_stop_callback
    assert trainer.early_stopping_callbacks == [early_stop_callback]
    assert len(trainer.dev_debugger.early_stopping_history) == expected_count
def test_multi_cpu_model_ddp(tmpdir):
    """Make sure DDP works."""
    tutils.set_random_master_port()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer_options = dict(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        gpus=None,
        num_processes=2,
        accelerator='ddp_cpu',
    )
    tpipes.run_model_test(trainer_options, model, data=dm, on_gpu=False)
def test_multi_cpu_model_ddp(tmpdir):
    """Make sure DDP works."""
    tutils.set_random_main_port()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        gpus=None,
        num_processes=2,
        strategy="ddp_spawn",
    )
    tpipes.run_model_test(trainer_options, model, data=dm, on_gpu=False)
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    early_stop_callback = EarlyStopping(monitor="train_loss")
    # Spy on the internal check so we can count invocations precisely.
    early_stop_callback._run_early_stopping_check = Mock()

    expected_count = 4
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=expected_count,
        enable_checkpointing=False,
    )
    trainer.fit(model, datamodule=dm)

    assert trainer.early_stopping_callback == early_stop_callback
    assert trainer.early_stopping_callbacks == [early_stop_callback]
    # Exactly one early-stopping check per epoch — no extra invocations.
    assert early_stop_callback._run_early_stopping_check.call_count == expected_count
def test_try_resume_from_non_existing_checkpoint(tmpdir):
    """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=False,
        callbacks=[checkpoint_cb],
        limit_train_batches=2,
        limit_val_batches=2,
    )

    # Generate checkpoint `last.ckpt` with BoringModel
    trainer.fit(model, datamodule=dm)

    # `True` if resume/restore successfully else `False`
    existing = str(tmpdir / "last.ckpt")
    missing = str(tmpdir / "last_non_existing.ckpt")
    assert trainer.checkpoint_connector.restore(existing, trainer.on_gpu)
    assert not trainer.checkpoint_connector.restore(missing, trainer.on_gpu)
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_master_port()

    class CustomClassificationModelDP(ClassificationModel):
        # DP variant: steps return raw logits/targets so per-GPU outputs can be
        # gathered and the metric computed in *_step_end on the main device.

        def _step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            return {'logits': logits, 'y': y}

        def training_step(self, batch, batch_idx):
            _, y = batch
            out = self._step(batch, batch_idx)
            loss = F.cross_entropy(out['logits'], y)
            return loss

        def validation_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def test_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def validation_step_end(self, outputs):
            # Accuracy is computed here, after DP has gathered the outputs.
            self.log('val_acc', self.valid_acc(outputs['logits'], outputs['y']))

    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='dp',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)
    pretrained_model.cpu()

    dataloaders = model.test_dataloader()
    # Normalize to a list: test_dataloader may return one loader or several.
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_prediction(pretrained_model, dataloader)