class TrainerCallbackConfigMixin(ABC):
    """Mixin that wires checkpointing and early-stopping callbacks into the Trainer."""

    def __init__(self):
        # This is just a summary of the attributes used in this abstract
        # mixin; the proper values/initialisation must be done in the
        # child class.
        self.default_save_path = None
        self.save_checkpoint = None
        self.slurm_job_id = None

    def configure_checkpoint_callback(self):
        """Resolve ``self.checkpoint_callback`` and the weights save path.

        The weights path is set in this priority:
            1. The checkpoint callback's path (if a callback was passed in).
            2. User-provided ``weights_save_path``.
            3. Otherwise fall back to ``self.default_save_path``.
        """
        if self.checkpoint_callback is True:
            # init a default one
            if self.logger is not None:
                # loggers may expose the save dir publicly (`save_dir`) or via
                # the private `_save_dir`; fall back to the default save path
                save_dir = (getattr(self.logger, 'save_dir', None) or
                            getattr(self.logger, '_save_dir', None) or
                            self.default_save_path)
                ckpt_path = os.path.join(
                    save_dir,
                    self.logger.name,
                    f'version_{self.logger.version}',
                    "checkpoints"
                )
            else:
                ckpt_path = os.path.join(self.default_save_path, "checkpoints")

            self.checkpoint_callback = ModelCheckpoint(
                filepath=ckpt_path
            )
        elif self.checkpoint_callback is False:
            # checkpointing explicitly disabled
            self.checkpoint_callback = None

        if self.checkpoint_callback:
            # hand the callback the trainer's save function
            self.checkpoint_callback.save_function = self.save_checkpoint

            # if checkpoint callback used, then override the weights path
            self.weights_save_path = self.checkpoint_callback.filepath

            # link to the trainer
            self.checkpoint_callback.set_trainer(self)

        # if weights_save_path is still none here, fall back to the default
        if self.weights_save_path is None:
            self.weights_save_path = self.default_save_path

    def configure_early_stopping(self, early_stop_callback):
        """Resolve ``self.early_stop_callback`` and ``self.enable_early_stop``.

        Args:
            early_stop_callback:
                ``True``  -> default ``EarlyStopping`` on ``val_loss`` (strict, verbose).
                ``None``  -> default ``EarlyStopping`` on ``val_loss`` (non-strict, quiet).
                falsy (e.g. ``False``) -> early stopping disabled.
                anything else -> used as the callback instance directly.
        """
        if early_stop_callback is True:
            self.early_stop_callback = EarlyStopping(
                monitor='val_loss',
                patience=3,
                strict=True,
                verbose=True,
                mode='min'
            )
            self.enable_early_stop = True
        elif early_stop_callback is None:
            # default behaviour: enabled, but tolerant of a missing monitor key
            self.early_stop_callback = EarlyStopping(
                monitor='val_loss',
                patience=3,
                strict=False,
                verbose=False,
                mode='min'
            )
            self.enable_early_stop = True
        elif not early_stop_callback:
            self.early_stop_callback = None
            self.enable_early_stop = False
        else:
            # user supplied their own callback instance
            self.early_stop_callback = early_stop_callback
            self.enable_early_stop = True

        if self.early_stop_callback is not None:
            self.early_stop_callback.set_trainer(self)
def test_model_checkpoint_options(tmp_path):
    """Test ModelCheckpoint options (``save_top_k`` behaviour and file naming)."""

    def mock_save_function(filepath):
        # stand-in for the trainer's save function: just touch the file
        open(filepath, 'a').close()

    def emulate_training(checkpoint_callback, losses, same_epoch=False):
        # Wire the callback to a bare trainer and feed it one fake
        # validation result per loss value. With ``same_epoch`` all results
        # are reported for epoch 0 (multiple checkpoints within one epoch).
        checkpoint_callback.save_function = mock_save_function
        trainer = Trainer()
        checkpoint_callback.set_trainer(trainer)
        for i, loss in enumerate(losses):
            checkpoint_callback._trainer.current_epoch = 0 if same_epoch else i
            checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
            checkpoint_callback.on_validation_end()

    hparams = tutils.get_hparams()
    _ = LightningTestModel(hparams)

    # simulated losses; the best (2.5) occurs at epoch 4, second best (2.8) at epoch 2
    losses = [10, 9, 2.8, 5, 2.5]

    # -----------------
    # CASE K=-1 (all)
    save_dir = tmp_path / "1"
    save_dir.mkdir()
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
    emulate_training(checkpoint_callback, losses)

    file_lists = set(os.listdir(save_dir))
    assert len(file_lists) == len(
        losses), "Should save all models when save_top_k=-1"
    # verify correct naming
    for i in range(0, len(losses)):
        assert f'_ckpt_epoch_{i}.ckpt' in file_lists

    # -----------------
    # CASE K=0 (none)
    save_dir = tmp_path / "2"
    save_dir.mkdir()
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
    emulate_training(checkpoint_callback, losses)

    file_lists = os.listdir(save_dir)
    assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    save_dir = tmp_path / "3"
    save_dir.mkdir()
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1,
                                          prefix='test_prefix')
    emulate_training(checkpoint_callback, losses)

    file_lists = set(os.listdir(save_dir))
    assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted
    save_dir = tmp_path / "4"
    save_dir.mkdir()
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f'{save_dir}/other_file.ckpt', 'a').close()
    emulate_training(checkpoint_callback, losses)

    file_lists = set(os.listdir(save_dir))
    assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2'
    assert '_ckpt_epoch_4.ckpt' in file_lists
    assert '_ckpt_epoch_2.ckpt' in file_lists
    assert 'other_file.ckpt' in file_lists

    # -----------------
    # CASE K=4 (save all 4 models)
    # multiple checkpoints within same epoch
    save_dir = tmp_path / "5"
    save_dir.mkdir()
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    emulate_training(checkpoint_callback, losses, same_epoch=True)

    file_lists = set(os.listdir(save_dir))
    assert len(
        file_lists
    ) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'

    # -----------------
    # CASE K=3 (save the 2nd, 3rd, 4th model)
    # multiple checkpoints within same epoch
    save_dir = tmp_path / "6"
    save_dir.mkdir()
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    emulate_training(checkpoint_callback, losses, same_epoch=True)

    file_lists = set(os.listdir(save_dir))
    assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
    assert '_ckpt_epoch_0_v2.ckpt' in file_lists
    assert '_ckpt_epoch_0_v1.ckpt' in file_lists
    assert '_ckpt_epoch_0.ckpt' in file_lists