def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
    """Test ModelCheckpoint options."""

    def fake_save(filepath, *args):
        # Stand-in for real checkpointing: just touch the file.
        open(filepath, 'a').close()

    # simulated per-epoch validation losses
    simulated_losses = [10, 9, 2.8, 5, 2.5]

    callback = ModelCheckpoint(
        tmpdir, save_top_k=save_top_k, save_last=save_last, prefix=file_prefix, verbose=1
    )
    callback.save_function = fake_save

    trainer = Trainer()
    # drive the callback manually, the way the training loop would
    for epoch, loss in enumerate(simulated_losses):
        trainer.current_epoch = epoch
        trainer.callback_metrics = {'val_loss': loss}
        callback.on_validation_end(trainer, trainer.get_model())

    saved = set(os.listdir(tmpdir))
    assert len(saved) == len(expected_files), \
        "Should save %i models when save_top_k=%i" % (len(expected_files), save_top_k)

    # every expected checkpoint name must be present
    for expected in expected_files:
        assert expected in saved
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
    """Test ModelCheckpoint options."""

    def fake_save(filepath, *args):
        # Stand-in for real checkpointing: just touch the file.
        open(filepath, 'a').close()

    # simulated per-epoch validation losses
    losses = [10, 9, 2.8, 5, 2.5]

    callback = ModelCheckpoint(
        tmpdir,
        monitor='checkpoint_on',
        save_top_k=save_top_k,
        save_last=save_last,
        prefix=file_prefix,
        verbose=1,
    )
    callback.save_function = fake_save

    trainer = Trainer()
    # emulate the training loop driving the callback
    for epoch, loss in enumerate(losses):
        trainer.current_epoch = epoch
        trainer.logger_connector.callback_metrics = {'checkpoint_on': torch.tensor(loss)}
        callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(tmpdir))
    assert len(file_lists) == len(expected_files), (
        f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}"
    )

    # verify correct naming
    for fname in expected_files:
        assert fname in file_lists
def init_default_checkpoint_callback(self, checkpoint_callback):
    """Normalize ``checkpoint_callback``: ``True`` -> default ModelCheckpoint,
    ``False`` -> ``None``; any real callback gets its save routed through the trainer."""
    if checkpoint_callback is False:
        return None
    if checkpoint_callback is True:
        checkpoint_callback = ModelCheckpoint(filepath=None)
    if checkpoint_callback:
        # route actual saving through the trainer
        checkpoint_callback.save_function = self.trainer.save_checkpoint
    return checkpoint_callback
def test_v1_5_0_model_checkpoint_save_function():
    """Both setting and reading ``save_function`` must emit the v1.3 deprecation warning."""
    ckpt = ModelCheckpoint()
    warning = "Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"
    with pytest.deprecated_call(match=warning):
        ckpt.save_function = lambda *_, **__: None
    with pytest.deprecated_call(match=warning):
        _ = ckpt.save_function
def configure_checkpoint_callback(self, checkpoint_callback):
    """Resolve ``checkpoint_callback``: ``True`` -> default callback monitoring the
    appropriate loss key, ``False`` -> ``None``; wire saving through the trainer."""
    if checkpoint_callback is False:
        return None
    if checkpoint_callback is True:
        # monitor training loss when no validation step exists, val_loss otherwise
        has_val_step = self.is_overridden('validation_step')
        monitor_key = 'val_loss' if has_val_step else 'loss'
        checkpoint_callback = ModelCheckpoint(filepath=None, monitor=monitor_key)
    if checkpoint_callback:
        checkpoint_callback.save_function = self.save_checkpoint
    return checkpoint_callback
def configure_checkpoint_callback(self, checkpoint_callback):
    """
    Weight path set in this priority:
    Checkpoint_callback's path (if passed in).
    User provided weights_saved_path
    Otherwise use os.getcwd()
    """
    if checkpoint_callback is True:
        # monitor training loss when no validation step exists, val_loss otherwise
        has_val_step = self.is_overridden('validation_step')
        monitor_key = 'val_loss' if has_val_step else 'loss'
        checkpoint_callback = ModelCheckpoint(filepath=None, monitor=monitor_key)
    elif checkpoint_callback is False:
        checkpoint_callback = None

    if checkpoint_callback:
        checkpoint_callback.save_function = self.save_checkpoint

    # fall back to the default root dir when no save path was configured
    # (this runs on every path, including checkpoint_callback=False)
    if self.weights_save_path is None:
        self.weights_save_path = self.default_root_dir

    return checkpoint_callback
def test_model_checkpoint_options(tmp_path):
    """Test ModelCheckpoint options."""

    def mock_save_function(filepath):
        # Stand-in for real checkpointing: just touch the file.
        open(filepath, 'a').close()

    hparams = tutils.get_hparams()
    _ = LightningTestModel(hparams)

    # simulated validation losses, one per epoch
    losses = [10, 9, 2.8, 5, 2.5]

    def drive(ckpt, save_dir, epoch_losses):
        """Feed (epoch, loss) pairs into the callback; return the saved file names."""
        ckpt.save_function = mock_save_function
        for epoch, loss in epoch_losses:
            ckpt.on_epoch_end(epoch, logs={'val_loss': loss})
        return set(os.listdir(save_dir))

    # -----------------
    # CASE K=-1 (all)
    save_dir = tmp_path / "1"
    save_dir.mkdir()
    saved = drive(ModelCheckpoint(save_dir, save_top_k=-1, verbose=1), save_dir, enumerate(losses))
    assert len(saved) == len(losses), "Should save all models when save_top_k=-1"
    # verify correct naming
    for epoch in range(len(losses)):
        assert f'_ckpt_epoch_{epoch}.ckpt' in saved

    # -----------------
    # CASE K=0 (none)
    save_dir = tmp_path / "2"
    save_dir.mkdir()
    saved = drive(ModelCheckpoint(save_dir, save_top_k=0, verbose=1), save_dir, enumerate(losses))
    assert len(saved) == 0, "Should save 0 models when save_top_k=0"

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    save_dir = tmp_path / "3"
    save_dir.mkdir()
    ckpt = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix')
    saved = drive(ckpt, save_dir, enumerate(losses))
    assert len(saved) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_ckpt_epoch_4.ckpt' in saved

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted
    save_dir = tmp_path / "4"
    save_dir.mkdir()
    ckpt = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f'{save_dir}/other_file.ckpt', 'a').close()
    saved = drive(ckpt, save_dir, enumerate(losses))
    assert len(saved) == 3, 'Should save 2 model when save_top_k=2'
    assert '_ckpt_epoch_4.ckpt' in saved
    assert '_ckpt_epoch_2.ckpt' in saved
    assert 'other_file.ckpt' in saved

    # -----------------
    # CASE K=4 (save all 4 models)
    # multiple checkpoints within same epoch
    save_dir = tmp_path / "5"
    save_dir.mkdir()
    ckpt = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    saved = drive(ckpt, save_dir, ((0, loss) for loss in losses))
    assert len(saved) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'

    # -----------------
    # CASE K=3 (save the 2nd, 3rd, 4th model)
    # multiple checkpoints within same epoch
    save_dir = tmp_path / "6"
    save_dir.mkdir()
    ckpt = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    saved = drive(ckpt, save_dir, ((0, loss) for loss in losses))
    assert len(saved) == 3, 'Should save 3 models when save_top_k=3'
    assert '_ckpt_epoch_0_v2.ckpt' in saved
    assert '_ckpt_epoch_0_v1.ckpt' in saved
    assert '_ckpt_epoch_0.ckpt' in saved
def test_v1_5_0_model_checkpoint_save_checkpoint():
    """Calling ``save_checkpoint`` with the old two-argument signature must warn."""
    ckpt = ModelCheckpoint()
    ckpt.save_function = lambda *_, **__: None
    expected = "ModelCheckpoint.save_checkpoint` signature has changed"
    with pytest.deprecated_call(match=expected):
        ckpt.save_checkpoint(Trainer(), object())
def test_model_checkpoint_options(tmpdir):
    """Test ModelCheckpoint options."""

    def mock_save_function(filepath):
        # Stand-in for real checkpointing: just touch the file.
        open(filepath, 'a').close()

    hparams = tutils.get_default_hparams()
    _ = LightningTestModel(hparams)

    # simulated validation losses, one per epoch
    losses = [10, 9, 2.8, 5, 2.5]

    def drive(callback, save_dir, epoch_losses):
        """Emulate the training loop: push each (epoch, loss) through the callback."""
        callback.save_function = mock_save_function
        trainer = Trainer()
        for epoch, loss in epoch_losses:
            trainer.current_epoch = epoch
            trainer.callback_metrics = {'val_loss': loss}
            callback.on_validation_end(trainer, trainer.get_model())
        return set(os.listdir(save_dir))

    # -----------------
    # CASE K=-1 (all)
    save_dir = os.path.join(tmpdir, '1')
    os.mkdir(save_dir)
    saved = drive(ModelCheckpoint(save_dir, save_top_k=-1, verbose=1), save_dir, enumerate(losses))
    assert len(saved) == len(losses), "Should save all models when save_top_k=-1"
    # verify correct naming
    for fname in {'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt', 'epoch=1.ckpt', 'epoch=0.ckpt'}:
        assert fname in saved

    # -----------------
    # CASE K=0 (none)
    save_dir = os.path.join(tmpdir, '2')
    os.mkdir(save_dir)
    saved = drive(ModelCheckpoint(save_dir, save_top_k=0, verbose=1), save_dir, enumerate(losses))
    assert len(saved) == 0, "Should save 0 models when save_top_k=0"

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    save_dir = os.path.join(tmpdir, '3')
    os.mkdir(save_dir)
    callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix_')
    saved = drive(callback, save_dir, enumerate(losses))
    assert len(saved) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_epoch=4.ckpt' in saved

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted
    save_dir = os.path.join(tmpdir, '4')
    os.mkdir(save_dir)
    callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f"{save_dir}/other_file.ckpt", 'a').close()
    saved = drive(callback, save_dir, enumerate(losses))
    assert len(saved) == 3, 'Should save 2 model when save_top_k=2'
    for fname in {'epoch=4.ckpt', 'epoch=2.ckpt', 'other_file.ckpt'}:
        assert fname in saved

    # -----------------
    # CASE K=4 (save all 4 base)
    # multiple checkpoints within same epoch
    save_dir = os.path.join(tmpdir, '5')
    os.mkdir(save_dir)
    callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    saved = drive(callback, save_dir, ((0, loss) for loss in losses))
    assert len(saved) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'

    # -----------------
    # CASE K=3 (save the 2nd, 3rd, 4th model)
    # multiple checkpoints within same epoch
    save_dir = os.path.join(tmpdir, '6')
    os.mkdir(save_dir)
    callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    saved = drive(callback, save_dir, ((0, loss) for loss in losses))
    assert len(saved) == 3, 'Should save 3 models when save_top_k=3'
    # NOTE(review): this set literal dedupes to a single name, so the versioned
    # checkpoints saved within the epoch are not checked individually -- confirm
    # the intended coverage against the callback's versioned naming scheme
    for fname in {'epoch=0.ckpt', 'epoch=0.ckpt', 'epoch=0.ckpt'}:
        assert fname in saved