def test_default_logger_callbacks_cpu_model(tmpdir):
    """Smoke-test a short CPU training run with the default logger/callbacks.

    Runs the model through ``tpipes.run_model_test_without_loggers`` with a
    tiny fraction of the data, then checks that ``freeze``/``unfreeze`` can be
    called on the trained model without raising.
    """
    options = {
        "default_root_dir": tmpdir,
        "max_epochs": 1,
        "gradient_clip_val": 1.0,
        "overfit_batches": 0.20,
        "enable_progress_bar": False,
        # keep the run fast: only 1% of train/val batches
        "limit_train_batches": 0.01,
        "limit_val_batches": 0.01,
    }

    boring_model = BoringModel()
    tpipes.run_model_test_without_loggers(options, boring_model, min_acc=0.01)

    # freeze/unfreeze must work on a CPU model after training
    boring_model.freeze()
    boring_model.unfreeze()
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly.

    Trains a model for one epoch under DataParallel, saves an HPC checkpoint,
    then fits a second trainer that should restore from that checkpoint. The
    restore is verified inside ``CustomModel.on_train_start``: the resumed
    epoch counter must have advanced, and predictions from the loaded weights
    must already be usable before any new training happens.
    """
    model = BoringModel()

    # requires 2 GPUs with the (legacy) 'dp' accelerator flag
    trainer_options = dict(max_epochs=1, gpus=2, accelerator='dp', default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['callbacks'] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    # NOTE(review): forcing the SLURM flag presumably makes hpc_save/load paths
    # active outside a real SLURM job — confirm against the checkpoint connector
    trainer.is_slurm_managing_tasks = True
    trainer.fit(model)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save an HPC-style checkpoint that the second trainer is expected to pick up
    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer; reuse the same logger version so the run resumes in place
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['callbacks'] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options['limit_train_batches'] = 0.5
    trainer_options['limit_val_batches'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(BoringModel):
        """BoringModel that asserts resumed state in ``on_train_start``."""

        def __init__(self):
            super().__init__()
            # flipped to True once the resume checks below actually ran
            self.on_train_start_called = False

        # set the epoch start hook so we can predict before the model does the full training
        def on_train_start(self):
            # resumed epoch must match the epoch recorded before hpc_save
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

            # if model and state loaded correctly, predictions will be good even though we
            # haven't trained with the new loaded model
            dp_model = new_trainer.model
            dp_model.eval()
            # NOTE(review): double .module unwrap assumes DataParallel wrapping a
            # LightningModule wrapper — verify against this PL version's DP plugin
            dp_model.module.module.running_stage = RunningStage.EVALUATING

            dataloader = self.train_dataloader()
            tpipes.run_prediction(self.trainer.lightning_module, dataloader)
            self.on_train_start_called = True

    # new model
    model = CustomModel()

    # fit new model which should load hpc weights
    new_trainer.fit(model)
    assert model.on_train_start_called

    # test freeze on gpu
    model.freeze()
    model.unfreeze()