Esempio n. 1
0
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix,
                                  expected_files):
    """Check which files ``ModelCheckpoint`` leaves behind in ``tmpdir``.

    The callback is driven by hand: for every simulated validation loss
    (one per epoch) ``on_validation_end`` is called with a save function
    that merely touches the target path. The resulting directory listing
    must have the same size as ``expected_files`` and contain each of its
    names.
    """
    def touch_checkpoint(filepath, *args):
        # stand-in for the real save function: only touch the target path
        with open(filepath, 'a'):
            pass

    callback = ModelCheckpoint(
        tmpdir,
        save_top_k=save_top_k,
        save_last=save_last,
        prefix=file_prefix,
        verbose=1,
    )
    callback.save_function = touch_checkpoint
    trainer = Trainer()

    # replay the callback invocation once per epoch with a simulated loss
    simulated_losses = [10, 9, 2.8, 5, 2.5]
    for epoch, val_loss in enumerate(simulated_losses):
        trainer.current_epoch = epoch
        trainer.callback_metrics = {'val_loss': val_loss}
        callback.on_validation_end(trainer, trainer.get_model())

    saved_files = set(os.listdir(tmpdir))

    assert len(saved_files) == len(expected_files), \
        "Should save %i models when save_top_k=%i" % (len(expected_files), save_top_k)

    # each expected checkpoint name has to be present
    for expected_name in expected_files:
        assert expected_name in saved_files
Esempio n. 2
0
def test_transfer_batch_hook():
    """Check that the trainer routes a custom batch through the model hook.

    A model overriding ``transfer_batch_to_device`` is handed to the trainer
    through a mocked ``get_model``; after ``transfer_batch_to_gpu`` the hook
    must have been called and both tensors of the batch must be on
    ``cuda:0``.
    """
    class CustomBatch:
        # plain container the default transfer logic is not asked to handle here
        def __init__(self, data):
            self.samples, self.targets = data[0], data[1]

    class CurrentTestModel(EvalModelTemplate):

        hook_called = False

        def transfer_batch_to_device(self, data, device):
            # record on the instance that the hook ran
            self.hook_called = True
            if not isinstance(data, CustomBatch):
                return super().transfer_batch_to_device(data, device)
            data.samples = data.samples.to(device)
            data.targets = data.targets.to(device)
            return data

    model = CurrentTestModel()
    samples = torch.zeros(5, 28)
    targets = torch.ones(5, 1, dtype=torch.long)
    cpu_batch = CustomBatch((samples, targets))

    trainer = Trainer()
    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
    trainer.get_model = MagicMock(return_value=model)
    moved_batch = trainer.transfer_batch_to_gpu(cpu_batch, 0)
    target_device = torch.device('cuda', 0)

    assert model.hook_called
    assert moved_batch.samples.device == moved_batch.targets.device == target_device
Esempio n. 3
0
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix,
                                  expected_files):
    """Check which files ``ModelCheckpoint`` leaves behind in ``tmpdir``.

    The callback monitors ``'checkpoint_on'`` and is driven by hand: for
    every simulated loss (one per epoch) the value is published as a tensor
    through ``trainer.logger_connector.callback_metrics`` and
    ``on_validation_end`` is called with a save function that merely touches
    the target path. The directory listing must match ``expected_files`` in
    size and contain each of its names.
    """
    def touch_checkpoint(filepath, *args):
        # stand-in for the real save function: only touch the target path
        with open(filepath, 'a'):
            pass

    callback = ModelCheckpoint(
        tmpdir,
        monitor='checkpoint_on',
        save_top_k=save_top_k,
        save_last=save_last,
        prefix=file_prefix,
        verbose=1,
    )
    callback.save_function = touch_checkpoint
    trainer = Trainer()

    # replay the callback invocation once per epoch with a simulated loss
    simulated_losses = [10, 9, 2.8, 5, 2.5]
    for epoch, value in enumerate(simulated_losses):
        trainer.current_epoch = epoch
        metrics = {'checkpoint_on': torch.tensor(value)}
        trainer.logger_connector.callback_metrics = metrics
        callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(tmpdir))

    assert len(file_lists) == len(expected_files), (
        f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}"
    )

    # each expected checkpoint name has to be present
    for expected_name in expected_files:
        assert expected_name in file_lists
def test_dm_transfer_batch_to_device(tmpdir):
    """Check that a datamodule's ``transfer_batch_to_device`` hook is honoured.

    The datamodule's overridden hook is attached to the model, the trainer
    gets a ``GPUBackend`` and the model through a mocked ``get_model``, and
    a custom batch is moved with ``batch_to_device``. Afterwards the hook
    must have been called and both tensors must be on ``cuda:0``.
    """
    class CustomBatch:
        # plain container holding the two tensors of a batch
        def __init__(self, data):
            self.samples, self.targets = data[0], data[1]

    class CurrentTestDM(LightningDataModule):

        hook_called = False

        def transfer_batch_to_device(self, data, device):
            # record on the instance that the hook ran
            self.hook_called = True
            if not isinstance(data, CustomBatch):
                return super().transfer_batch_to_device(data, device)
            data.samples = data.samples.to(device)
            data.targets = data.targets.to(device)
            return data

    model = EvalModelTemplate()
    dm = CurrentTestDM()
    samples = torch.zeros(5, 28)
    targets = torch.ones(5, 1, dtype=torch.long)
    cpu_batch = CustomBatch((samples, targets))

    trainer = Trainer(gpus=1)
    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
    trainer.get_model = MagicMock(return_value=model)
    if is_overridden('transfer_batch_to_device', dm):
        # let the model delegate to the datamodule's implementation
        model.transfer_batch_to_device = dm.transfer_batch_to_device

    trainer.accelerator_backend = GPUBackend(trainer)
    source_device = torch.device('cuda:0')
    moved_batch = trainer.accelerator_backend.batch_to_device(cpu_batch, source_device)
    target_device = torch.device('cuda', 0)

    assert dm.hook_called
    assert moved_batch.samples.device == moved_batch.targets.device == target_device
Esempio n. 5
0
def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1):
    """Fit a ``SwaTestModel`` with a ``SwaTestCallback`` attached.

    Args:
        tmpdir: directory used as the trainer's ``default_root_dir``.
        batchnorm: forwarded to ``SwaTestModel``.
        accelerator: forwarded to ``Trainer``.
        gpus: forwarded to ``Trainer``.
        num_processes: forwarded to ``Trainer``.

    The callback's call counters must be zero before training, and after
    ``fit`` the trainer must still report the model that was passed in.
    """
    model = SwaTestModel(batchnorm=batchnorm)
    swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1)

    # nothing may have been counted before training starts
    assert swa_callback.update_parameters_calls == 0
    assert swa_callback.transfer_weights_calls == 0

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=5,
        limit_train_batches=5,
        limit_val_batches=0,
        callbacks=[swa_callback],
        accumulate_grad_batches=2,
        accelerator=accelerator,
        gpus=gpus,
        num_processes=num_processes,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # check the model is the expected
    fitted_model = trainer.get_model()
    assert fitted_model == model
Esempio n. 6
0
def test_v1_4_0_deprecated_trainer_methods():
    """Check that ``Trainer.get_model`` warns about removal but still works.

    The deprecated call is made only inside the ``pytest.deprecated_call``
    context, so its warning is always captured there; the value it returned
    is then compared with ``trainer.lightning_module``.
    """
    with pytest.deprecated_call(match='will be removed in v1.4'):
        trainer = Trainer()
        # keep the result so the deprecated method need not be called again
        # outside the context, where its warning would go uncaptured
        model = trainer.get_model()
    assert model == trainer.lightning_module
Esempio n. 7
0
def test_model_checkpoint_options(tmp_path):
    """Test ModelCheckpoint options.

    Every case builds a ``ModelCheckpoint`` on its own sub-directory of
    ``tmp_path``, replays the same simulated validation losses through
    ``on_validation_end`` with a save function that merely touches the
    target path, and then checks the files left in that sub-directory.
    """
    def mock_save_function(filepath):
        # stand-in for the real save function: only touch the target path
        with open(filepath, 'a'):
            pass

    def make_save_dir(name):
        """Create a fresh sub-directory of ``tmp_path`` and return it."""
        case_dir = tmp_path / name
        case_dir.mkdir()
        return case_dir

    def emulate_training(checkpoint_callback, save_dir, same_epoch=False):
        """Replay ``losses`` through the callback; return the files in ``save_dir``.

        With ``same_epoch=True`` every call reports epoch 0, otherwise the
        epoch is the index of the loss.
        """
        checkpoint_callback.save_function = mock_save_function
        trainer = Trainer()
        # emulate callback's calls during the training
        for i, loss in enumerate(losses):
            trainer.current_epoch = 0 if same_epoch else i
            trainer.callback_metrics = {'val_loss': loss}
            checkpoint_callback.on_validation_end(trainer, trainer.get_model())
        return set(os.listdir(save_dir))

    hparams = tutils.get_hparams()
    _ = LightningTestModel(hparams)

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    # -----------------
    # CASE K=-1  (all)
    save_dir = make_save_dir("1")
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
    file_lists = emulate_training(checkpoint_callback, save_dir)

    assert len(file_lists) == len(
        losses), "Should save all models when save_top_k=-1"

    # verify correct naming
    for i in range(len(losses)):
        assert f"_ckpt_epoch_{i}.ckpt" in file_lists

    # -----------------
    # CASE K=0 (none)
    save_dir = make_save_dir("2")
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
    file_lists = emulate_training(checkpoint_callback, save_dir)

    assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    save_dir = make_save_dir("3")
    checkpoint_callback = ModelCheckpoint(save_dir,
                                          save_top_k=1,
                                          verbose=1,
                                          prefix='test_prefix')
    file_lists = emulate_training(checkpoint_callback, save_dir)

    assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted
    save_dir = make_save_dir("4")
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    # the unrelated file is created after the callback is constructed,
    # in the same order as before the refactoring
    with open(f"{save_dir}/other_file.ckpt", 'a'):
        pass
    file_lists = emulate_training(checkpoint_callback, save_dir)

    # two checkpoints plus the unrelated file
    assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2'
    assert '_ckpt_epoch_4.ckpt' in file_lists
    assert '_ckpt_epoch_2.ckpt' in file_lists
    assert 'other_file.ckpt' in file_lists

    # -----------------
    # CASE K=4 (save all 4 models)
    # multiple checkpoints within same epoch
    save_dir = make_save_dir("5")
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    file_lists = emulate_training(checkpoint_callback, save_dir, same_epoch=True)

    assert len(
        file_lists
    ) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'

    # -----------------
    # CASE K=3 (save the 2nd, 3rd, 4th model)
    # multiple checkpoints within same epoch
    save_dir = make_save_dir("6")
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    file_lists = emulate_training(checkpoint_callback, save_dir, same_epoch=True)

    assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
    assert '_ckpt_epoch_0_v2.ckpt' in file_lists
    assert '_ckpt_epoch_0_v1.ckpt' in file_lists
    assert '_ckpt_epoch_0.ckpt' in file_lists