class TrainerCallbackConfigMixin(ABC):

    def __init__(self):
        # Summary of the attributes this mixin relies on; the real
        # values/initialisation are provided by the child class.
        self.default_save_path = None
        self.save_checkpoint = None
        self.slurm_job_id = None

    def configure_checkpoint_callback(self):
        """Resolve the checkpoint callback and the weights save path.

        Weight path priority:
        1. the checkpoint callback's own path (if a callback was passed in),
        2. the user-provided weights save path,
        3. otherwise the trainer's default save path.
        """
        if self.checkpoint_callback is True:
            # build a default callback; derive its path from the logger if any
            if self.logger is None:
                ckpt_path = os.path.join(self.default_save_path, "checkpoints")
            else:
                # loggers expose the save dir under different attribute names
                save_dir = (getattr(self.logger, 'save_dir', None) or
                            getattr(self.logger, '_save_dir', None) or
                            self.default_save_path)
                ckpt_path = os.path.join(
                    save_dir,
                    self.logger.name,
                    f'version_{self.logger.version}',
                    "checkpoints"
                )
            self.checkpoint_callback = ModelCheckpoint(filepath=ckpt_path)
        elif self.checkpoint_callback is False:
            # checkpointing explicitly disabled
            self.checkpoint_callback = None

        if self.checkpoint_callback:
            # hook the trainer's save function into the callback
            self.checkpoint_callback.save_function = self.save_checkpoint

            # a configured callback overrides the weights path
            self.weights_save_path = self.checkpoint_callback.filepath

            # give the callback a back-reference to the trainer
            self.checkpoint_callback.set_trainer(self)

        # fall back to the default save path if nothing set a weights path
        if self.weights_save_path is None:
            self.weights_save_path = self.default_save_path

    def configure_early_stopping(self, early_stop_callback):
        """Resolve the early-stopping callback from the user's argument.

        True/None build a default callback (strict+verbose only for True),
        any other falsy value disables early stopping, and any other object
        is used as the callback directly.
        """
        if early_stop_callback is True or early_stop_callback is None:
            # default callback; only be strict/verbose when explicitly asked
            explicit = early_stop_callback is True
            self.early_stop_callback = EarlyStopping(
                monitor='val_loss',
                patience=3,
                strict=explicit,
                verbose=explicit,
                mode='min'
            )
            self.enable_early_stop = True
        elif not early_stop_callback:
            # e.g. False: early stopping disabled
            self.early_stop_callback = None
            self.enable_early_stop = False
        else:
            # a user-supplied callback object
            self.early_stop_callback = early_stop_callback
            self.enable_early_stop = True

        if self.early_stop_callback is not None:
            self.early_stop_callback.set_trainer(self)
# Example 2
def test_model_checkpoint_options(tmp_path):
    """Test ModelCheckpoint's save_top_k behavior.

    Emulates validation-end calls for a fixed loss sequence and checks which
    checkpoint files end up on disk for various ``save_top_k`` settings
    (-1 = all, 0 = none, K = best K, including same-epoch versioning).
    """
    def mock_save_function(filepath):
        # stand-in for the real save function: just create an empty file
        open(filepath, 'a').close()

    hparams = tutils.get_hparams()
    _ = LightningTestModel(hparams)

    # simulated validation losses, one per emulated validation-end call
    losses = [10, 9, 2.8, 5, 2.5]

    def make_save_dir(name):
        """Create and return a fresh checkpoint directory under tmp_path."""
        save_dir = tmp_path / name
        save_dir.mkdir()
        return save_dir

    def run_emulated_training(checkpoint_callback, same_epoch=False):
        """Wire the callback to a fresh Trainer and replay `losses` as
        successive validation-end calls (all within epoch 0 if `same_epoch`).
        """
        checkpoint_callback.save_function = mock_save_function
        trainer = Trainer()
        checkpoint_callback.set_trainer(trainer)
        for epoch, loss in enumerate(losses):
            checkpoint_callback._trainer.current_epoch = 0 if same_epoch else epoch
            checkpoint_callback._trainer.callback_metrics = {'val_loss': loss}
            checkpoint_callback.on_validation_end()

    # -----------------
    # CASE K=-1  (all)
    save_dir = make_save_dir("1")
    run_emulated_training(ModelCheckpoint(save_dir, save_top_k=-1, verbose=1))
    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == len(
        losses), "Should save all models when save_top_k=-1"

    # verify correct naming
    for i in range(0, len(losses)):
        assert f'_ckpt_epoch_{i}.ckpt' in file_lists

    # -----------------
    # CASE K=0 (none)
    save_dir = make_save_dir("2")
    run_emulated_training(ModelCheckpoint(save_dir, save_top_k=0, verbose=1))
    file_lists = os.listdir(save_dir)

    assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    save_dir = make_save_dir("3")
    run_emulated_training(ModelCheckpoint(save_dir,
                                          save_top_k=1,
                                          verbose=1,
                                          prefix='test_prefix'))
    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted
    save_dir = make_save_dir("4")
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f'{save_dir}/other_file.ckpt', 'a').close()
    run_emulated_training(checkpoint_callback)
    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2'
    assert '_ckpt_epoch_4.ckpt' in file_lists
    assert '_ckpt_epoch_2.ckpt' in file_lists
    assert 'other_file.ckpt' in file_lists

    # -----------------
    # CASE K=4 (save all 4 models)
    # multiple checkpoints within same epoch
    save_dir = make_save_dir("5")
    run_emulated_training(ModelCheckpoint(save_dir, save_top_k=4, verbose=1),
                          same_epoch=True)
    file_lists = set(os.listdir(save_dir))

    assert len(
        file_lists
    ) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'

    # -----------------
    # CASE K=3 (save the 2nd, 3rd, 4th model)
    # multiple checkpoints within same epoch
    save_dir = make_save_dir("6")
    run_emulated_training(ModelCheckpoint(save_dir, save_top_k=3, verbose=1),
                          same_epoch=True)
    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
    assert '_ckpt_epoch_0_v2.ckpt' in file_lists
    assert '_ckpt_epoch_0_v1.ckpt' in file_lists
    assert '_ckpt_epoch_0.ckpt' in file_lists