Example #1
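All examples below assume the imports of the original PyTorch Lightning test module. A minimal sketch, with module paths hedged (BoringModel and pl_load moved between Lightning releases; these match the ~1.1-era API that still exposed resume_from_checkpoint and enable_pl_optimizer):

import os
import os.path as osp

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.cloud_io import load as pl_load  # assumed location for this era
from tests.helpers.boring_model import BoringModel  # assumed; tests.base.boring_model in older releases
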
def test_checkpoint_repeated_strategy_tmpdir(tmpdir):
    """
    This test validates that the checkpoint callback works when provided in the callbacks list.
    """

    os.environ['PL_DEV_DEBUG'] = '1'

    # legacy ModelCheckpoint API: `filepath` combines directory and filename;
    # example #2 below uses the newer dirpath/filename pair
    checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                          filepath=os.path.join(
                                              tmpdir, "{epoch:02d}"))

    class ExtendedBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    model.validation_step_end = None
    model.validation_epoch_end = None
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      limit_train_batches=2,
                      limit_val_batches=2,
                      limit_test_batches=2,
                      callbacks=[checkpoint_callback])

    trainer.fit(model)
    assert sorted(os.listdir(tmpdir)) == sorted(
        ['epoch=00.ckpt', 'lightning_logs'])
    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
    assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])

    def get_last_checkpoint():
        ckpts = os.listdir(tmpdir)
        ckpts_map = {
            int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x)
            for x in ckpts if "epoch" in x
        }
        # epoch keys run 0..N-1, so the newest checkpoint has key len - 1
        return ckpts_map[len(ckpts_map) - 1]

    for idx in range(1, 5):

        # load from checkpoint
        chk = get_last_checkpoint()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(default_root_dir=tmpdir,
                             max_epochs=1,
                             limit_train_batches=2,
                             limit_val_batches=2,
                             limit_test_batches=2,
                             resume_from_checkpoint=chk)

        trainer.fit(model)
        trainer.test(model)
        assert sorted(os.listdir(tmpdir)) == sorted(
            ['epoch=00.ckpt', 'lightning_logs'])
        assert sorted(os.listdir(path_to_lightning_logs)) == sorted(
            [f'version_{i}' for i in range(idx + 1)])
def test_checkpoint_repeated_strategy(enable_pl_optimizer, tmpdir):
    """
    This test validates that the checkpoint callback works when provided in the callbacks list.
    """

    checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                          dirpath=tmpdir,
                                          filename="{epoch:02d}")

    class ExtendedBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    model.validation_step_end = None
    model.validation_epoch_end = None
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint_callback],
        enable_pl_optimizer=enable_pl_optimizer,
    )

    trainer.fit(model)
    assert os.listdir(tmpdir) == ['epoch=00.ckpt']

    def get_last_checkpoint():
        ckpts = os.listdir(tmpdir)
        ckpts_map = {
            int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x)
            for x in ckpts if "epoch" in x
        }
        # epoch keys run 0..N-1, so the newest checkpoint has key len - 1
        return ckpts_map[len(ckpts_map) - 1]

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(max_epochs=1,
                             limit_train_batches=2,
                             limit_val_batches=2,
                             limit_test_batches=2,
                             resume_from_checkpoint=chk,
                             enable_pl_optimizer=enable_pl_optimizer)
        trainer.fit(model)
        trainer.test(model)

        assert os.listdir(tmpdir) == ['epoch=00.ckpt']
Example #3
def test_checkpoint_repeated_strategy_extended(tmpdir):
    """
    This test validates that a checkpoint can be resumed several times without
    the trainer internally increasing its global step when nothing runs.
    """

    os.environ['PL_DEV_DEBUG'] = '1'

    class ExtendedBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    model.validation_step_end = None
    model.validation_epoch_end = None
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
    )

    assert not trainer.checkpoint_connector.has_trained
    assert trainer.current_epoch == 0
    trainer.fit(model)
    assert trainer.checkpoint_connector.has_trained is True
    assert trainer.global_step == 2
    assert trainer.current_epoch == 0
    trainer.test(model)
    assert trainer.current_epoch == 0
    assert os.listdir(osp.join(tmpdir, 'lightning_logs')) == ['version_0']

    def get_last_checkpoint():
        logs_dir = osp.join(tmpdir, 'lightning_logs')
        versions = os.listdir(logs_dir)
        versions.sort()

        last_version = versions[-1]
        ckpt_dir = osp.join(logs_dir, last_version, "checkpoints")

        ckpts = os.listdir(ckpt_dir)
        ckpts.sort()

        return osp.join(ckpt_dir, ckpts[-1])

    def assert_checkpoint_content():
        # after one epoch of two training batches, the saved checkpoint
        # records epoch 1 and global_step 2
        chk = pl_load(get_last_checkpoint())
        assert chk["epoch"] == 1
        assert chk["global_step"] == 2

    assert_checkpoint_content()

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        assert_checkpoint_content()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(default_root_dir=tmpdir,
                             max_epochs=1,
                             limit_train_batches=2,
                             limit_val_batches=2,
                             limit_test_batches=2,
                             resume_from_checkpoint=chk)
        assert not trainer.checkpoint_connector.has_trained
        assert trainer.global_step == 0
        trainer.test(model)
        assert trainer.global_step == 2
        trainer.fit(model)
        assert trainer.global_step == 2
        assert not trainer.checkpoint_connector.has_trained
        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
        assert sorted(os.listdir(lightning_logs_path)) == [
            f"version_{i}" for i in range(idx + 1)
        ]