Example 1
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix,
                                  expected_files):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath, *args):
        open(filepath, 'a').close()

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    checkpoint_callback = ModelCheckpoint(tmpdir,
                                          save_top_k=save_top_k,
                                          save_last=save_last,
                                          prefix=file_prefix,
                                          verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(tmpdir))

    assert len(file_lists) == len(expected_files), \
        "Should save %i models when save_top_k=%i" % (len(expected_files), save_top_k)

    # verify correct naming
    for fname in expected_files:
        assert fname in file_lists
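
Example 1's signature takes save_top_k, save_last, file_prefix, and expected_files as fixtures, so the test is presumably driven by a @pytest.mark.parametrize decorator that the excerpt omits. A minimal sketch of what such a decorator could look like, with illustrative parameter values (the filenames follow the 'epoch=N.ckpt' scheme used by Example 9 below, not necessarily the original values):

import pytest

# Hypothetical parametrization; the real decorator is not shown in the excerpt.
@pytest.mark.parametrize(
    "save_top_k, save_last, file_prefix, expected_files",
    [
        (-1, False, '', {'epoch=0.ckpt', 'epoch=1.ckpt', 'epoch=2.ckpt',
                         'epoch=3.ckpt', 'epoch=4.ckpt'}),        # keep all five
        (0, False, '', set()),                                     # save nothing
        (1, False, 'test_prefix_', {'test_prefix_epoch=4.ckpt'}),  # best only
    ],
)
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix,
                                  expected_files):
    ...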
Example 2
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix,
                                  expected_files):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath, *args):
        open(filepath, 'a').close()

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    checkpoint_callback = ModelCheckpoint(tmpdir,
                                          monitor='checkpoint_on',
                                          save_top_k=save_top_k,
                                          save_last=save_last,
                                          prefix=file_prefix,
                                          verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.logger_connector.callback_metrics = {
            'checkpoint_on': torch.tensor(loss)
        }
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(tmpdir))

    assert len(file_lists) == len(expected_files), (
        f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}"
    )

    # verify correct naming
    for fname in expected_files:
        assert fname in file_lists
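
Example 2 reads the monitored value from trainer.logger_connector.callback_metrics as a tensor and keeps only the save_top_k best scores. The bookkeeping behind that can be sketched with a small heap; this is an illustrative stand-in, not Lightning's actual implementation:

import heapq

def update_top_k(heap, score, filepath, k):
    """Track the k lowest scores seen so far; return an evicted filepath or None.

    The heap stores (-score, filepath) so the worst (highest) score pops first.
    k == -1 means keep every checkpoint.
    """
    heapq.heappush(heap, (-score, filepath))
    if k != -1 and len(heap) > k:
        _, evicted = heapq.heappop(heap)
        return evicted  # the caller would delete this checkpoint file
    return None

# mirrors the simulated losses above: with k=2 only 2.8 and 2.5 survive
heap = []
for epoch, loss in enumerate([10, 9, 2.8, 5, 2.5]):
    update_top_k(heap, loss, f"epoch={epoch}.ckpt", k=2)
assert sorted(path for _, path in heap) == ['epoch=2.ckpt', 'epoch=4.ckpt']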
Example 3
    def init_default_checkpoint_callback(self, checkpoint_callback):
        if checkpoint_callback is True:
            checkpoint_callback = ModelCheckpoint(filepath=None)
        elif checkpoint_callback is False:
            checkpoint_callback = None
        if checkpoint_callback:
            checkpoint_callback.save_function = self.trainer.save_checkpoint

        return checkpoint_callback
Example 4
def test_v1_5_0_model_checkpoint_save_function():
    model_ckpt = ModelCheckpoint()
    with pytest.deprecated_call(
            match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"):
        model_ckpt.save_function = lambda *_, **__: None
    with pytest.deprecated_call(
            match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"):
        _ = model_ckpt.save_function
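
pytest.deprecated_call passes as long as the property's getter and setter emit a DeprecationWarning (or subclass) matching the pattern. A minimal sketch of the kind of property this test exercises, assuming a plain warnings.warn-based implementation (illustrative only; not Lightning's actual code):

import warnings

class Checkpoint:
    """Illustrative stand-in, not the real ModelCheckpoint."""
    _save_function = None

    @property
    def save_function(self):
        warnings.warn(
            "Property `save_function` in `ModelCheckpoint` is deprecated in v1.3",
            DeprecationWarning)
        return self._save_function

    @save_function.setter
    def save_function(self, fn):
        warnings.warn(
            "Property `save_function` in `ModelCheckpoint` is deprecated in v1.3",
            DeprecationWarning)
        self._save_function = fn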
Example 5
    def configure_checkpoint_callback(self, checkpoint_callback):
        if checkpoint_callback is True:
            # when no val step is defined, use 'loss' otherwise 'val_loss'
            train_step_only = not self.is_overridden('validation_step')
            monitor_key = 'loss' if train_step_only else 'val_loss'
            checkpoint_callback = ModelCheckpoint(filepath=None,
                                                  monitor=monitor_key)
        elif checkpoint_callback is False:
            checkpoint_callback = None

        if checkpoint_callback:
            checkpoint_callback.save_function = self.save_checkpoint

        return checkpoint_callback
Example 6
    def configure_checkpoint_callback(self, checkpoint_callback):
        """
        Weight path set in this priority:
        Checkpoint_callback's path (if passed in).
        User provided weights_saved_path
        Otherwise use os.getcwd()
        """
        if checkpoint_callback is True:
            # when no val step is defined, use 'loss' otherwise 'val_loss'
            train_step_only = not self.is_overridden('validation_step')
            monitor_key = 'loss' if train_step_only else 'val_loss'
            checkpoint_callback = ModelCheckpoint(filepath=None,
                                                  monitor=monitor_key)
        elif checkpoint_callback is False:
            checkpoint_callback = None

        if checkpoint_callback:
            checkpoint_callback.save_function = self.save_checkpoint

        # if weights_save_path is still None here, fall back to default_root_dir
        if self.weights_save_path is None:
            self.weights_save_path = self.default_root_dir

        return checkpoint_callback
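
The priority described in the docstring reduces to a first-non-None fallback chain; a condensed sketch of that resolution (hypothetical helper, mirroring the logic above rather than quoting the library):

def resolve_weights_save_path(callback_path, weights_save_path, default_root_dir):
    """Return the first configured location, in the documented priority order."""
    return callback_path or weights_save_path or default_root_dir

assert resolve_weights_save_path(None, None, '/tmp/run') == '/tmp/run'
assert resolve_weights_save_path(None, '/data/w', '/tmp/run') == '/data/w'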
Example 7
def test_model_checkpoint_options(tmp_path):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath):
        open(filepath, 'a').close()

    hparams = tutils.get_hparams()
    _ = LightningTestModel(hparams)

    save_dir = tmp_path / "1"
    save_dir.mkdir()

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    # -----------------
    # CASE K=-1  (all)
    w = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
    w.save_function = mock_save_function
    for i, loss in enumerate(losses):
        w.on_epoch_end(i, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == len(losses), \
        "Should save all models when save_top_k=-1"

    # verify correct naming
    for i in range(len(losses)):
        assert f'_ckpt_epoch_{i}.ckpt' in file_lists

    save_dir = tmp_path / "2"
    save_dir.mkdir()

    # -----------------
    # CASE K=0 (none)
    w = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
    w.save_function = mock_save_function
    for i, loss in enumerate(losses):
        w.on_epoch_end(i, logs={'val_loss': loss})

    file_lists = os.listdir(save_dir)

    assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

    save_dir = tmp_path / "3"
    save_dir.mkdir()

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    w = ModelCheckpoint(save_dir,
                        save_top_k=1,
                        verbose=1,
                        prefix='test_prefix')
    w.save_function = mock_save_function
    for i, loss in enumerate(losses):
        w.on_epoch_end(i, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists

    save_dir = tmp_path / "4"
    save_dir.mkdir()

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted

    w = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f'{save_dir}/other_file.ckpt', 'a').close()
    w.save_function = mock_save_function
    for i, loss in enumerate(losses):
        w.on_epoch_end(i, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 2 models when save_top_k=2'
    assert '_ckpt_epoch_4.ckpt' in file_lists
    assert '_ckpt_epoch_2.ckpt' in file_lists
    assert 'other_file.ckpt' in file_lists

    save_dir = tmp_path / "5"
    save_dir.mkdir()

    # -----------------
    # CASE K=4 (keep the best 4 of the 5)
    # multiple checkpoints within same epoch

    w = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    w.save_function = mock_save_function
    for loss in losses:
        w.on_epoch_end(0, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 4, \
        'Should save all 4 models when save_top_k=4 within same epoch'

    save_dir = tmp_path / "6"
    save_dir.mkdir()

    # -----------------
    # CASE K=3 (keep the best 3: losses 2.8, 5, 2.5)
    # multiple checkpoints within same epoch

    w = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    w.save_function = mock_save_function
    for loss in losses:
        w.on_epoch_end(0, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
    assert '_ckpt_epoch_0_v2.ckpt' in file_lists
    assert '_ckpt_epoch_0_v1.ckpt' in file_lists
    assert '_ckpt_epoch_0.ckpt' in file_lists
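
The _v1/_v2 suffixes checked above come from the callback de-duplicating filenames when several checkpoints are written for the same epoch. A minimal sketch of that collision handling (hypothetical helper, not the library's implementation):

import os

def versioned_filepath(dirpath, base):
    """Append _v1, _v2, ... until the candidate name no longer exists."""
    candidate = os.path.join(dirpath, f"{base}.ckpt")
    version = 0
    while os.path.exists(candidate):
        version += 1
        candidate = os.path.join(dirpath, f"{base}_v{version}.ckpt")
    return candidate

# three saves in epoch 0 would yield _ckpt_epoch_0.ckpt, then _v1, then _v2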
Example 8
def test_v1_5_0_model_checkpoint_save_checkpoint():
    model_ckpt = ModelCheckpoint()
    model_ckpt.save_function = lambda *_, **__: None
    with pytest.deprecated_call(
            match="ModelCheckpoint.save_checkpoint` signature has changed"):
        model_ckpt.save_checkpoint(Trainer(), object())
Example 9
def test_model_checkpoint_options(tmpdir):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath):
        open(filepath, 'a').close()

    hparams = tutils.get_default_hparams()
    _ = LightningTestModel(hparams)

    save_dir = os.path.join(tmpdir, '1')
    os.mkdir(save_dir)

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    # -----------------
    # CASE K=-1  (all)
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == len(losses), \
        "Should save all models when save_top_k=-1"

    # verify correct naming
    for fname in {'epoch=0.ckpt', 'epoch=1.ckpt', 'epoch=2.ckpt',
                  'epoch=3.ckpt', 'epoch=4.ckpt'}:
        assert fname in file_lists

    save_dir = os.path.join(tmpdir, '2')
    os.mkdir(save_dir)

    # -----------------
    # CASE K=0 (none)
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = os.listdir(save_dir)

    assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

    save_dir = os.path.join(tmpdir, '3')
    os.mkdir(save_dir)

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    checkpoint_callback = ModelCheckpoint(save_dir,
                                          save_top_k=1,
                                          verbose=1,
                                          prefix='test_prefix_')
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_epoch=4.ckpt' in file_lists

    save_dir = os.path.join(tmpdir, '4')
    os.mkdir(save_dir)

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted

    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f"{save_dir}/other_file.ckpt", 'a').close()
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 2 models when save_top_k=2'
    for fname in {'epoch=4.ckpt', 'epoch=2.ckpt', 'other_file.ckpt'}:
        assert fname in file_lists

    save_dir = os.path.join(tmpdir, '5')
    os.mkdir(save_dir)

    # -----------------
    # CASE K=4 (keep the best 4 of the 5)
    # multiple checkpoints within same epoch

    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for loss in losses:
        trainer.current_epoch = 0
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 4, \
        'Should save all 4 models when save_top_k=4 within same epoch'

    save_dir = os.path.join(tmpdir, '6')
    os.mkdir(save_dir)

    # -----------------
    # CASE K=3 (keep the best 3: losses 2.8, 5, 2.5)
    # multiple checkpoints within same epoch

    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for loss in losses:
        trainer.current_epoch = 0
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
    # only the base name is deterministic here; the other two files carry
    # version suffixes whose exact format depends on the Lightning version
    assert 'epoch=0.ckpt' in file_lists
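
All of the on_validation_end examples above only poke a handful of trainer attributes before invoking the callback, so the same loop can be driven without constructing a real Trainer at all. A minimal stand-in with an attribute set inferred from what these tests touch (the callback may read more, depending on the Lightning version):

class FakeTrainer:
    """Exposes only what the checkpoint callback reads in these tests."""
    def __init__(self):
        self.current_epoch = 0
        self.global_step = 0
        self.callback_metrics = {}

    def get_model(self):
        return None  # the mock save_function ignores the model anyway

trainer = FakeTrainer()
for i, loss in enumerate([10, 9, 2.8, 5, 2.5]):
    trainer.current_epoch = i
    trainer.callback_metrics = {'val_loss': loss}
    # checkpoint_callback.on_validation_end(trainer, trainer.get_model())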