def test_train_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None
    model.test_step = None
    model.test_step_end = None
    model.test_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False)

    # fit model
    trainer.fit(model, datamodule=dm)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0

def test_multi_gpu_early_stop_dp(tmpdir):
    """Make sure DP works with early stopping."""
    tutils.set_random_master_port()

    dm = ClassifDataModule()
    model = CustomClassificationModelDP()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[EarlyStopping(monitor="val_acc")],
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        accelerator="dp",
    )

    tpipes.run_model_test(trainer_options, model, dm)

def test_train_val_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, weights_summary=None)

    # fit model
    result = trainer.fit(model, datamodule=dm)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert result
    assert trainer.callback_metrics['train_loss'] < 1.0

def test_full_loop(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
        deterministic=True,
    )

    # fit model
    result = trainer.fit(model, dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert result

    # test
    result = trainer.test(datamodule=dm)
    assert result[0]['test_acc'] > 0.6

def run_checkpoint_test(tmpdir, save_full_weights):
    seed_everything(1)
    model = ModelParallelClassificationModel()
    dm = ClassifDataModule()
    ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=10,
        plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
        gpus=2,
        precision=16,
        accumulate_grad_batches=2,
        callbacks=[ck],
    )
    trainer.fit(model, datamodule=dm)

    results = trainer.test(model, datamodule=dm)
    assert results[0]['test_acc'] > 0.7

    saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm)
    assert saved_results[0]['test_acc'] > 0.7
    assert saved_results == results

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=10,
        plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
        gpus=2,
        precision=16,
        accumulate_grad_batches=2,
        callbacks=[ck],
        resume_from_checkpoint=ck.best_model_path,
    )
    results = trainer.test(model, datamodule=dm)
    assert results[0]['test_acc'] > 0.7

    dm.predict_dataloader = dm.test_dataloader
    results = trainer.predict(datamodule=dm)
    assert results[-1] > 0.7

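# For context: `ModelParallelClassificationModel` is referenced throughout this section but not
# defined here. A minimal sketch of such a DeepSpeed helper, assuming its defining trait is that
# layers are built inside `configure_sharded_model` so DeepSpeed stage 3 can shard parameters as
# they are created; the class name, layer sizes, and method bodies below are illustrative, not
# the exact helper (imports repeated so the sketch stands alone).
import torch
from torch import nn
from torch.nn import functional as F
from pytorch_lightning import LightningModule


class ModelParallelClassificationModelSketch(LightningModule):

    def __init__(self, lr=0.01, num_blocks=5):
        super().__init__()
        self.lr = lr
        self.num_blocks = num_blocks

    def configure_sharded_model(self):
        # building layers here (rather than in __init__) lets DeepSpeed stage 3
        # shard each parameter across devices at construction time
        self.model = nn.Sequential(
            *(nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) for _ in range(self.num_blocks)),
            nn.Linear(32, 3),
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
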
def test_resume_early_stopping_from_checkpoint(tmpdir):
    """Prevent regressions to bugs:

    https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
    https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
    """
    seed_everything(42)
    model = ClassificationModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1)
    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback, checkpoint_callback],
        num_sanity_val_steps=0,
        max_epochs=4,
    )
    trainer.fit(model, datamodule=dm)

    assert len(early_stop_callback.saved_states) == 4

    checkpoint_filepath = checkpoint_callback.kth_best_model_path
    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1"
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
    new_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        resume_from_checkpoint=checkpoint_filepath,
        callbacks=[early_stop_callback],
    )

    with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
        new_trainer.fit(model)

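# `EarlyStoppingTestRestore` is referenced above but not defined in this section. A minimal
# sketch of such a helper, assuming it extends `EarlyStopping`, caches its own state after every
# training epoch, and asserts the reloaded state on resume; the state-capture API shown here
# (the stateful-callback `state_dict` hook) is an assumption and varies between Lightning versions.
class EarlyStoppingTestRestoreSketch(EarlyStopping):

    def __init__(self, expected_state=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expected_state = expected_state
        # cache the state for each epoch so the test can index into it later
        self.saved_states = []

    def on_train_start(self, trainer, pl_module):
        if self.expected_state:
            # on resume, the trainer must have restored exactly the persisted callback state
            assert self.state_dict() == self.expected_state

    def on_train_epoch_end(self, trainer, pl_module):
        super().on_train_epoch_end(trainer, pl_module)
        self.saved_states.append(dict(self.state_dict()))
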
def test_try_resume_from_non_existing_checkpoint(tmpdir):
    """Test that trying to resume from a non-existing `resume_from_checkpoint` fails without error."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=False,
        callbacks=[checkpoint_cb],
        limit_train_batches=2,
        limit_val_batches=2,
    )
    # Generate checkpoint `last.ckpt`
    trainer.fit(model, datamodule=dm)

    # `True` if resume/restore succeeds, else `False`
    assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu)
    assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)

def test_lr_monitor_param_groups(tmpdir):
    """Test that learning rates are extracted and logged for each param group of a single scheduler."""
    tutils.reset_seed()

    class CustomClassificationModel(ClassificationModel):

        def configure_optimizers(self):
            param_groups = [
                {'params': list(self.parameters())[:2], 'lr': self.lr * 0.1},
                {'params': list(self.parameters())[2:], 'lr': self.lr},
            ]
            optimizer = optim.Adam(param_groups)
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
            return [optimizer], [lr_scheduler]

    model = CustomClassificationModel()
    dm = ClassifDataModule()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_monitor],
    )
    trainer.fit(model, datamodule=dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    assert lr_monitor.lrs, 'No learning rates logged'
    assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of param groups'
    assert lr_monitor.lr_sch_names == ['lr-Adam']
    assert list(lr_monitor.lrs.keys()) == ['lr-Adam/pg1', 'lr-Adam/pg2'], \
        'Names of learning rates not set correctly'

def test_optimization(tmpdir):
    seed_everything(42)

    dm = ClassifDataModule(length=1024)
    model = IPUClassificationModel()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, ipus=2)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(datamodule=dm)
    assert dm.trainer is not None
    assert result[0]['val_acc'] > 0.7

    # test
    result = trainer.test(model, datamodule=dm)
    assert dm.trainer is not None
    test_result = result[0]['test_acc']
    assert test_result > 0.6

    # test saved model
    model_path = os.path.join(tmpdir, 'model.pt')
    trainer.save_checkpoint(model_path)

    model = IPUClassificationModel.load_from_checkpoint(model_path)

    trainer = Trainer(default_root_dir=tmpdir, ipus=2)
    result = trainer.test(model, datamodule=dm)
    saved_result = result[0]['test_acc']
    assert saved_result == test_result

def test_evaluate(tmpdir, trainer_kwargs):
    tutils.set_random_main_port()
    seed_everything(1)
    dm = ClassifDataModule()
    model = CustomClassificationModelDP()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=10,
        limit_val_batches=10,
        **trainer_kwargs,
    )

    trainer.fit(model, datamodule=dm)
    assert "ckpt" in trainer.checkpoint_callback.best_model_path

    old_weights = model.layer_0.weight.clone().detach().cpu()

    trainer.validate(datamodule=dm)
    trainer.test(datamodule=dm)

    # make sure weights didn't change
    new_weights = model.layer_0.weight.clone().detach().cpu()
    torch.testing.assert_allclose(old_weights, new_weights)

def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation is defined."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    model.validation_step = None
    model.val_dataloader = None

    stopping = EarlyStopping(monitor='train_loss', min_delta=0.1, patience=0, check_on_train_epoch_end=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[stopping],
        overfit_batches=0.20,
        max_epochs=10,
    )
    trainer.fit(model, datamodule=dm)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.current_epoch < trainer.max_epochs - 1

def test_callbacks_state_resume_from_checkpoint(tmpdir):
    """Test that resuming from a checkpoint restores callbacks that persist state."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    callback_capture = CaptureCallbacksBeforeTraining()

    def get_trainer_args():
        checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
        trainer_args = dict(
            default_root_dir=tmpdir,
            max_steps=1,
            logger=False,
            callbacks=[checkpoint, callback_capture],
        )
        assert checkpoint.best_model_path == ""
        assert checkpoint.best_model_score is None
        return trainer_args

    # initial training
    trainer = Trainer(**get_trainer_args())
    trainer.fit(model, datamodule=dm)
    callbacks_before_resume = deepcopy(trainer.callbacks)

    # resumed training
    trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt"))
    trainer.fit(model, datamodule=dm)

    assert len(callbacks_before_resume) == len(callback_capture.callbacks)

    for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
        if isinstance(before, ModelCheckpoint):
            assert before.best_model_path == after.best_model_path
            assert before.best_model_score == after.best_model_score

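# `CaptureCallbacksBeforeTraining` is referenced above but not defined in this section. A
# plausible minimal version, assuming all it needs to do is snapshot the trainer's (restored)
# callbacks right before training begins; the class name and the hook used here are assumptions.
class CaptureCallbacksBeforeTrainingSketch(Callback):

    callbacks = []

    def on_pretrain_routine_end(self, trainer, pl_module):
        # deepcopy so later training steps cannot mutate the captured callback state
        self.callbacks = deepcopy(trainer.callbacks)
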
def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer):
    """Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works."""
    seed_everything(42)

    class VerificationCallback(Callback):

        def __init__(self):
            self.on_train_batch_start_called = False

        def on_train_batch_start(self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
            deepspeed_engine = trainer.strategy.model
            assert trainer.global_step == deepspeed_engine.global_steps
            self.on_train_batch_start_called = True

    model = ModelParallelClassificationModel()
    dm = ClassifDataModule()
    verification_callback = VerificationCallback()
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        # TODO: this test fails with max_epochs >1 as there are leftover batches per epoch.
        # there's divergence in how Lightning handles the last batch of the epoch with how DeepSpeed does it.
        # we step the optimizers on the last batch but DeepSpeed keeps the accumulation for the next epoch
        max_epochs=1,
        strategy=DeepSpeedStrategy(stage=2, offload_optimizer=offload_optimizer),
        accelerator="gpu",
        devices=2,
        limit_train_batches=5,
        limit_val_batches=2,
        precision=16,
        accumulate_grad_batches=2,
        callbacks=[verification_callback],
    )
    assert trainer.limit_train_batches % trainer.accumulate_grad_batches != 0, "leftover batches should be tested"
    trainer.fit(model, datamodule=dm)
    assert verification_callback.on_train_batch_start_called

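# The TODO in the test above comes down to simple arithmetic: with `limit_train_batches=5` and
# `accumulate_grad_batches=2`, optimizer steps fire after batches 2 and 4, leaving batch 5 as a
# leftover that Lightning flushes at epoch end while DeepSpeed would carry its gradients into
# the next epoch. A quick self-contained illustration of the numbers the test relies on:
def _illustrate_leftover_batches(limit_train_batches=5, accumulate_grad_batches=2):
    step_batches = [b for b in range(1, limit_train_batches + 1) if b % accumulate_grad_batches == 0]
    leftover = limit_train_batches % accumulate_grad_batches
    assert step_batches == [2, 4]
    assert leftover == 1  # the leftover batch guarded by the test's modulo assertion
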
def test_lr_monitor_param_groups(tmpdir):
    """Test that learning rates are extracted and logged for each param group of a single scheduler."""
    tutils.reset_seed()

    class CustomClassificationModel(ClassificationModel):

        def configure_optimizers(self):
            param_groups = [
                {"params": list(self.parameters())[:2], "lr": self.lr * 0.1},
                {"params": list(self.parameters())[2:], "lr": self.lr},
            ]
            optimizer = optim.Adam(param_groups)
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
            return [optimizer], [lr_scheduler]

    model = CustomClassificationModel()
    dm = ClassifDataModule()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_monitor],
    )
    trainer.fit(model, datamodule=dm)

    assert lr_monitor.lrs, "No learning rates logged"
    assert len(lr_monitor.lrs) == 2 * len(trainer.lr_scheduler_configs)
    assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"

def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.reset_seed()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer, key='test_acc', thr=0.45)

def _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer):
    """Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works."""
    seed_everything(42)

    class VerificationCallback(Callback):

        def __init__(self):
            self.on_train_batch_start_called = False

        def on_train_batch_start(
            self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
        ) -> None:
            deepspeed_engine = trainer.training_type_plugin.model
            assert trainer.global_step == deepspeed_engine.global_steps
            self.on_train_batch_start_called = True

    model = ModelParallelClassificationModel()
    dm = ClassifDataModule()
    verification_callback = VerificationCallback()
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=5,
        plugins=[DeepSpeedPlugin(stage=2, offload_optimizer=offload_optimizer)],
        gpus=2,
        limit_val_batches=2,
        precision=16,
        accumulate_grad_batches=2,
        callbacks=[verification_callback],
    )
    trainer.fit(model, datamodule=dm)
    assert verification_callback.on_train_batch_start_called

def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir):
    """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the
    optimizer state and scheduler states cannot be restored."""
    dm = ClassifDataModule()
    model = BoringModel()
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    trainer.save_checkpoint(checkpoint_path)

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        strategy=DeepSpeedStrategy(stage=3, load_full_weights=True),
        gpus=1,
        precision=16,
    )
    with pytest.warns(
        UserWarning,
        match="A single checkpoint file has been given. This means optimizer states cannot be restored. "
        "If you'd like to restore these states, you must "
        "provide a path to the originally saved DeepSpeed checkpoint.",
    ):
        trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_path)

def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_master_port()

    class CustomClassificationModelDP(ClassificationModel):

        def _step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            return {'logits': logits, 'y': y}

        def training_step(self, batch, batch_idx):
            _, y = batch
            out = self._step(batch, batch_idx)
            loss = F.cross_entropy(out['logits'], y)
            return loss

        def validation_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def test_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def validation_step_end(self, outputs):
            self.log('val_acc', self.valid_acc(outputs['logits'], outputs['y']))

    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='dp',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)
    pretrained_model.cpu()

    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_prediction(pretrained_model, dataloader)

def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1, accelerator="gpu", devices=2, strategy="dp", default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options["logger"] = logger
    trainer_options["callbacks"] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer._checkpoint_connector.hpc_save_path(tmpdir)
    trainer.save_checkpoint(hpc_save_path)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options["logger"] = new_logger
    trainer_options["callbacks"] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options["limit_train_batches"] = 0.5
    trainer_options["limit_val_batches"] = 0.2
    trainer_options["max_epochs"] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):

        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        def on_validation_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
            dataloader = dm.val_dataloader()
            tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)

    # new model
    model = CustomModel()

    # validate new model which should load hpc weights
    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()

def test_trainer_properties_restore_ckpt_path(tmpdir):
    """Test that required trainer properties are set correctly when resuming from checkpoint in different phases."""

    class CustomClassifModel(ClassificationModel):

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

    model = CustomClassifModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=False,
        callbacks=[checkpoint_callback],
        num_sanity_val_steps=0,
    )
    trainer = Trainer(**trainer_args)
    trainer.fit(model, datamodule=dm)

    resume_ckpt = str(tmpdir / "last.ckpt")
    state_dict = torch.load(resume_ckpt)

    trainer_args.update({"max_epochs": 3, "enable_checkpointing": False, "callbacks": []})

    class CustomClassifModel(CustomClassifModel):

        def _is_equal(self, a, b):
            if isinstance(a, torch.Tensor):
                return torch.all(torch.eq(a, b))

            if isinstance(a, Mapping):
                return all(self._is_equal(a.get(k, None), b.get(k, None)) for k in b.keys())

            return a == b

        def _check_optimizers(self):
            return all(
                self._is_equal(self.trainer.optimizers[i].state_dict(), state_dict["optimizer_states"][i])
                for i in range(len(self.trainer.optimizers))
            )

        def _check_schedulers(self):
            return all(
                self._is_equal(self.trainer.lr_schedulers[i]["scheduler"].state_dict(), state_dict["lr_schedulers"][i])
                for i in range(len(self.trainer.lr_schedulers))
            )

        def _check_model_state_dict(self):
            for k in self.state_dict():
                yield self._is_equal(self.state_dict()[k], state_dict["state_dict"][k])

        def _test_on_val_test_predict_tune_start(self):
            assert self.trainer.current_epoch == state_dict["epoch"]
            assert self.trainer.global_step == state_dict["global_step"]
            assert all(self._check_model_state_dict())

            # no optimizers or schedulers are loaded otherwise
            if self.trainer.state.fn != TrainerFn.TUNING:
                return

            assert not self._check_optimizers()
            assert not self._check_schedulers()

        def on_train_start(self):
            if self.trainer.state.fn == TrainerFn.TUNING:
                self._test_on_val_test_predict_tune_start()
            else:
                assert self.trainer.current_epoch == state_dict["epoch"]
                assert self.trainer.global_step == state_dict["global_step"]
                assert all(self._check_model_state_dict())
                assert self._check_optimizers()
                assert self._check_schedulers()

        def on_validation_start(self):
            if self.trainer.state.fn == TrainerFn.VALIDATING:
                self._test_on_val_test_predict_tune_start()

        def on_test_start(self):
            self._test_on_val_test_predict_tune_start()

    for fn in ("fit", "validate", "test", "predict"):
        model = CustomClassifModel()
        dm = ClassifDataModule()
        trainer_args["auto_scale_batch_size"] = fn == "tune"
        trainer = Trainer(**trainer_args)
        trainer_fn = getattr(trainer, fn)
        trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)

def test_deepspeed_multigpu_stage_3_resume_training(tmpdir):
    """Test to ensure with Stage 3 and single GPU that we can resume training."""
    initial_model = ModelParallelClassificationModel()
    dm = ClassifDataModule()

    ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
    initial_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        strategy=DeepSpeedStrategy(stage=3),
        accelerator="gpu",
        devices=1,
        precision=16,
        callbacks=[ck],
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    initial_trainer.fit(initial_model, datamodule=dm)

    class TestCallback(Callback):

        def on_train_batch_start(
            self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
        ) -> None:
            original_deepspeed_strategy = initial_trainer.strategy
            current_deepspeed_strategy = trainer.strategy

            assert isinstance(original_deepspeed_strategy, DeepSpeedStrategy)
            assert isinstance(current_deepspeed_strategy, DeepSpeedStrategy)

            # assert optimizer states are correctly loaded
            original_optimizer_dict = original_deepspeed_strategy.deepspeed_engine.optimizer.state_dict()
            current_optimizer_dict = current_deepspeed_strategy.deepspeed_engine.optimizer.state_dict()
            for orig_tensor, current_tensor in zip(
                original_optimizer_dict["fp32_flat_groups"], current_optimizer_dict["fp32_flat_groups"]
            ):
                assert torch.all(orig_tensor.eq(current_tensor))

            # assert model state is loaded correctly
            for current_param, initial_param in zip(pl_module.parameters(), initial_model.parameters()):
                assert torch.equal(current_param.cpu(), initial_param.cpu())

            # assert epoch has correctly been restored
            assert trainer.current_epoch == 1

            # assert lr-scheduler states are loaded correctly
            original_lr_scheduler = initial_trainer.lr_scheduler_configs[0].scheduler
            current_lr_scheduler = trainer.lr_scheduler_configs[0].scheduler
            assert original_lr_scheduler.state_dict() == current_lr_scheduler.state_dict()

    model = ModelParallelClassificationModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        strategy=DeepSpeedStrategy(stage=3),
        accelerator="gpu",
        devices=1,
        precision=16,
        callbacks=TestCallback(),
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(model, datamodule=dm, ckpt_path=ck.best_model_path)

def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1, gpus=2, accelerator='dp', default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['callbacks'] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    trainer.fit(model, datamodule=dm)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['callbacks'] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options['limit_train_batches'] = 0.5
    trainer_options['limit_val_batches'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):

        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        # set the epoch start hook so we can predict before the model does the full training
        def on_train_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

            # if model and state loaded correctly, predictions will be good even though we
            # haven't trained with the new loaded model
            new_trainer._running_stage = RunningStage.EVALUATING

            dataloader = self.train_dataloader()
            tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader)
            self.on_train_start_called = True

    # new model
    model = CustomModel()

    # fit new model which should load hpc weights
    new_trainer.fit(model, datamodule=dm)
    assert model.on_train_start_called

    # test freeze on gpu
    model.freeze()
    model.unfreeze()