def test_optimization(tmpdir):
    seed_everything(42)

    dm = ClassifDataModule(length=1024)
    model = ClassificationModel()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="hpu", devices=1)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(datamodule=dm)
    assert dm.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, datamodule=dm)
    assert dm.trainer is not None
    test_result = result[0]["test_acc"]
    assert test_result > 0.6

    # test saved model
    model_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(model_path)

    model = ClassificationModel.load_from_checkpoint(model_path)

    trainer = Trainer(default_root_dir=tmpdir, accelerator="hpu", devices=1)

    result = trainer.test(model, datamodule=dm)
    saved_result = result[0]["test_acc"]
    assert saved_result == test_result
def main():
    seed_everything(4321)

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--trainer_method", default="fit")
    parser.add_argument("--tmpdir")
    parser.add_argument("--workdir")
    parser.set_defaults(accelerator="gpu", devices=2)
    parser.set_defaults(strategy="ddp")
    args = parser.parse_args()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer = Trainer.from_argparse_args(args)

    if args.trainer_method == "fit":
        trainer.fit(model, datamodule=dm)
        result = None
    elif args.trainer_method == "test":
        result = trainer.test(model, datamodule=dm)
    elif args.trainer_method == "fit_test":
        trainer.fit(model, datamodule=dm)
        result = trainer.test(model, datamodule=dm)
    else:
        raise ValueError(f"Unsupported: {args.trainer_method}")

    result_ext = {"status": "complete", "method": args.trainer_method, "result": result}
    file_path = os.path.join(args.tmpdir, "ddp.result")
    torch.save(result_ext, file_path)
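# A typical invocation of this script, e.g. from a DDP test harness that later inspects the saved
# "ddp.result" file. The script file name and tmpdir below are hypothetical examples; the Trainer
# flags are accepted because `Trainer.add_argparse_args` added them to the parser, and the GPU/DDP
# defaults set above apply when the flags are omitted:
#
#   python ddp_script.py --trainer_method fit_test --tmpdir /tmp/ddp_test --accelerator gpu --devices 2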
def test_callbacks_references_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint sets references as expected."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    args = {
        "default_root_dir": tmpdir,
        "max_steps": 1,
        "logger": False,
        "limit_val_batches": 2,
        "num_sanity_val_steps": 0,
    }

    # initial training
    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    trainer = Trainer(**args, callbacks=[checkpoint])
    assert checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm)

    # resumed training
    new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    # pass in a new checkpoint object, which should take
    # precedence over the one in the last.ckpt file
    trainer = Trainer(**args, callbacks=[new_checkpoint])
    assert checkpoint is not new_checkpoint
    assert new_checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_main_port()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp_spawn",
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_model_prediction(pretrained_model, dataloader, min_acc=0.1)
def test_fit_csv_logger(tmpdir):
    dm = ClassifDataModule()
    model = ClassificationModel()
    logger = CSVLogger(save_dir=tmpdir)
    trainer = Trainer(default_root_dir=tmpdir, max_steps=10, logger=logger, log_every_n_steps=1)
    trainer.fit(model, datamodule=dm)
    metrics_file = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE)
    assert os.path.isfile(metrics_file)
def test_train_val_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False)

    # fit model
    trainer.fit(model, datamodule=dm)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0
def test_multi_gpu_early_stop_ddp_spawn(tmpdir):
    tutils.set_random_main_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[EarlyStopping(monitor="train_acc")],
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp_spawn",
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    tpipes.run_model_test(trainer_options, model, dm)
def test_callbacks_state_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint restores callbacks that persist state."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    callback_capture = CaptureCallbacksBeforeTraining()

    def get_trainer_args():
        checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
        trainer_args = dict(
            default_root_dir=tmpdir,
            limit_train_batches=1,
            limit_val_batches=2,
            max_epochs=1,
            logger=False,
            callbacks=[checkpoint, callback_capture],
        )
        assert checkpoint.best_model_path == ""
        assert checkpoint.best_model_score is None
        return trainer_args

    # initial training
    trainer = Trainer(**get_trainer_args())
    with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"):
        trainer.fit(model, datamodule=dm)

    callbacks_before_resume = deepcopy(trainer.callbacks)

    # resumed training
    trainer = Trainer(**get_trainer_args())
    with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"):
        trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))

    assert len(callbacks_before_resume) == len(callback_capture.callbacks)

    for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
        if isinstance(before, ModelCheckpoint):
            for attribute in (
                "best_model_path",
                "best_model_score",
                "best_k_models",
                "kth_best_model_path",
                "kth_value",
                "last_model_path",
            ):
                assert getattr(before, attribute) == getattr(after, attribute)
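# `CaptureCallbacksBeforeTraining` is not defined in this section. A minimal sketch of what the
# test above assumes (hypothetical reconstruction): a callback that snapshots the trainer's fully
# configured callback list just before training starts, via the deprecated
# `on_pretrain_routine_end` hook that the `pytest.deprecated_call` context manager above expects
# to trigger. Assumes `Callback` is imported from pytorch_lightning and `deepcopy` from copy.
class CaptureCallbacksBeforeTraining(Callback):
    callbacks = []

    def on_pretrain_routine_end(self, trainer, pl_module):
        # snapshot the callback objects (including restored state) before the first epoch runs
        self.callbacks = deepcopy(trainer.callbacks)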
def test_resume_early_stopping_from_checkpoint(tmpdir):
    """Prevent regressions to bugs:

    https://github.com/Lightning-AI/lightning/issues/1464
    https://github.com/Lightning-AI/lightning/issues/1463
    """
    seed_everything(42)
    model = ClassificationModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1)
    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback, checkpoint_callback],
        num_sanity_val_steps=0,
        max_epochs=4,
    )
    trainer.fit(model, datamodule=dm)

    assert len(early_stop_callback.saved_states) == 4

    checkpoint_filepath = checkpoint_callback.kth_best_model_path
    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1", so index back by one to get the state saved at that epoch
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
    assert len(early_stop_callback.saved_states) == 4

    es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
    assert checkpoint["callbacks"][es_name] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
    new_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        callbacks=[early_stop_callback],
    )

    with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
        new_trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_filepath)
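# `EarlyStoppingTestRestore` is defined elsewhere in the test suite. A minimal sketch of the
# behaviour the test above relies on (hypothetical reconstruction; the exact hooks and state
# format are assumptions): an `EarlyStopping` subclass that snapshots its state after every
# training epoch and, when constructed with an expected state, asserts on restore that the
# reloaded state matches it.
class EarlyStoppingTestRestore(EarlyStopping):
    def __init__(self, expected_state, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expected_state = expected_state
        # one snapshot of the callback state per finished training epoch
        self.saved_states = []

    def on_train_start(self, trainer, pl_module):
        if self.expected_state is not None:
            # the state restored from the checkpoint must match the state saved earlier
            assert self.state_dict() == self.expected_state

    def on_train_epoch_end(self, trainer, pl_module):
        super().on_train_epoch_end(trainer, pl_module)
        self.saved_states.append(self.state_dict().copy())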
def test_multi_cpu_model_ddp(tmpdir):
    """Make sure DDP works."""
    tutils.set_random_main_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        accelerator="cpu",
        devices=2,
        strategy="ddp_spawn",
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    tpipes.run_model_test(trainer_options, model, data=dm)
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    model.validation_step = None
    model.val_dataloader = None

    stopping = EarlyStopping(monitor="train_loss", min_delta=0.1, patience=0, check_on_train_epoch_end=True)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[stopping], overfit_batches=0.20, max_epochs=10)
    trainer.fit(model, datamodule=dm)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.current_epoch < trainer.max_epochs - 1
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    early_stop_callback = EarlyStopping(monitor="train_loss")
    early_stop_callback._run_early_stopping_check = Mock()
    expected_count = 4
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=expected_count,
        enable_checkpointing=False,
    )
    trainer.fit(model, datamodule=dm)

    assert trainer.early_stopping_callback == early_stop_callback
    assert trainer.early_stopping_callbacks == [early_stop_callback]
    assert early_stop_callback._run_early_stopping_check.call_count == expected_count
def test_evaluate(tmpdir, trainer_kwargs):
    tutils.set_random_main_port()
    seed_everything(1)
    dm = ClassifDataModule()
    model = CustomClassificationModelDP()
    trainer = Trainer(
        default_root_dir=tmpdir, max_epochs=2, limit_train_batches=10, limit_val_batches=10, **trainer_kwargs
    )

    trainer.fit(model, datamodule=dm)
    assert "ckpt" in trainer.checkpoint_callback.best_model_path

    old_weights = model.layer_0.weight.clone().detach().cpu()

    trainer.validate(datamodule=dm)
    trainer.test(datamodule=dm)

    # make sure weights didn't change
    new_weights = model.layer_0.weight.clone().detach().cpu()
    torch.testing.assert_close(old_weights, new_weights)
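# `CustomClassificationModelDP` is defined elsewhere. A rough sketch of what the DP tests in this
# file assume (hypothetical reconstruction; the metric attributes `valid_acc`/`test_acc`, the loss,
# and `import torch.nn.functional as F` are assumptions): a `ClassificationModel` variant that
# returns raw outputs from the step methods and computes metrics in the `*_step_end` hooks, so that
# results gathered from both DP replicas are reduced on the main device.
class CustomClassificationModelDP(ClassificationModel):
    def _step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # return per-replica outputs; DP concatenates them before the *_step_end hooks run
        return {"logits": logits, "y": y}

    def training_step(self, batch, batch_idx):
        out = self._step(batch, batch_idx)
        return F.cross_entropy(out["logits"], out["y"])

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)

    def validation_step_end(self, outputs):
        # compute and log the metric on the gathered outputs
        self.log("val_acc", self.valid_acc(outputs["logits"], outputs["y"]))

    def test_step_end(self, outputs):
        self.log("test_acc", self.test_acc(outputs["logits"], outputs["y"]))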
def test_full_loop(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False, deterministic=True)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(model, dm)
    assert dm.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, dm)
    assert dm.trainer is not None
    assert result[0]["test_acc"] > 0.6
def test_multi_gpu_early_stop_dp(tmpdir):
    """Make sure DP works with early stopping."""
    tutils.set_random_main_port()

    dm = ClassifDataModule()
    model = CustomClassificationModelDP()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[EarlyStopping(monitor="val_acc")],
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        accelerator="gpu",
        devices=[0, 1],
        strategy="dp",
    )

    tpipes.run_model_test(trainer_options, model, dm)
def test_lr_monitor_param_groups(tmpdir):
    """Test that learning rates are extracted and logged for parameter groups."""
    tutils.reset_seed()

    class CustomClassificationModel(ClassificationModel):
        def configure_optimizers(self):
            param_groups = [
                {"params": list(self.parameters())[:2], "lr": self.lr * 0.1},
                {"params": list(self.parameters())[2:], "lr": self.lr},
            ]
            optimizer = optim.Adam(param_groups)
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
            return [optimizer], [lr_scheduler]

    model = CustomClassificationModel()
    dm = ClassifDataModule()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_monitor],
    )
    trainer.fit(model, datamodule=dm)

    assert lr_monitor.lrs, "No learning rates logged"
    assert len(lr_monitor.lrs) == 2 * len(trainer.lr_scheduler_configs)
    assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"
def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer, key="test_acc", thr=0.45)
def test_trainer_properties_restore_ckpt_path(tmpdir):
    """Test that required trainer properties are set correctly when resuming from checkpoint in different phases."""

    class CustomClassifModel(ClassificationModel):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

    model = CustomClassifModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        limit_predict_batches=2,
        logger=False,
        callbacks=[checkpoint_callback],
        num_sanity_val_steps=0,
    )
    trainer = Trainer(**trainer_args)
    trainer.fit(model, datamodule=dm)

    resume_ckpt = str(tmpdir / "last.ckpt")
    state_dict = torch.load(resume_ckpt)

    trainer_args.update({"max_epochs": 3, "enable_checkpointing": False, "callbacks": []})

    class CustomClassifModel(CustomClassifModel):
        def _is_equal(self, a, b):
            if isinstance(a, torch.Tensor):
                return torch.all(torch.eq(a, b))

            if isinstance(a, Mapping):
                return all(self._is_equal(a.get(k, None), b.get(k, None)) for k in b.keys())

            return a == b

        def _check_optimizers(self):
            return all(
                self._is_equal(optimizer.state_dict(), state)
                for optimizer, state in zip(self.trainer.optimizers, state_dict["optimizer_states"])
            )

        def _check_schedulers(self):
            return all(
                self._is_equal(config.scheduler.state_dict(), state)
                for config, state in zip(self.trainer.lr_scheduler_configs, state_dict["lr_schedulers"])
            )

        def _check_model_state_dict(self):
            return all(
                self._is_equal(actual, expected)
                for actual, expected in zip(self.state_dict(), state_dict["state_dict"])
            )

        def _test_on_val_test_predict_tune_start(self):
            assert self.trainer.current_epoch == state_dict["epoch"]
            assert self.trainer.global_step == state_dict["global_step"]
            assert self._check_model_state_dict()

            # no optimizers and schedulers are loaded otherwise
            if self.trainer.state.fn != TrainerFn.TUNING:
                return

            assert not self._check_optimizers()
            assert not self._check_schedulers()

        def on_train_start(self):
            if self.trainer.state.fn == TrainerFn.TUNING:
                self._test_on_val_test_predict_tune_start()
            else:
                assert self.trainer.current_epoch == state_dict["epoch"] + 1
                assert self.trainer.global_step == state_dict["global_step"]
                assert self._check_model_state_dict()
                assert self._check_optimizers()
                assert self._check_schedulers()

        def on_validation_start(self):
            if self.trainer.state.fn == TrainerFn.VALIDATING:
                self._test_on_val_test_predict_tune_start()

        def on_test_start(self):
            self._test_on_val_test_predict_tune_start()

    for fn in ("fit", "validate", "test", "predict"):
        model = CustomClassifModel()
        dm = ClassifDataModule()
        # a boolean, not a one-element tuple: only enable batch-size tuning for the "tune" path
        trainer_args["auto_scale_batch_size"] = fn == "tune"
        trainer = Trainer(**trainer_args)
        trainer_fn = getattr(trainer, fn)
        trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1, accelerator="gpu", devices=2, strategy="dp", default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options["logger"] = logger
    trainer_options["callbacks"] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer._checkpoint_connector.hpc_save_path(tmpdir)
    trainer.save_checkpoint(hpc_save_path)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options["logger"] = new_logger
    trainer_options["callbacks"] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options["limit_train_batches"] = 0.5
    trainer_options["limit_val_batches"] = 0.2
    trainer_options["max_epochs"] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        def on_validation_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
            dataloader = dm.val_dataloader()
            tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)

    # new model
    model = CustomModel()

    # validate new model which should load hpc weights
    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()