예제 #1
0
    def test_monitor_checkpoint(self, tmpdir, test_engine, optimizer_state):
        """Check MonitorCheckpoint behaviour on a ``val_loss`` sequence.

        Every even-indexed epoch reports a new best (lower) ``val_loss`` and
        must produce a checkpoint verified by ``check_checkpoint``; every
        odd-indexed epoch reports a worse value (100) and must not leave a
        checkpoint file behind. With ``max_saves=3`` exactly three ``.pth``
        files must remain at the end.
        """
        ckpt_dir = Path(tmpdir.join("path/to/monitor_checkpoints/"))
        checkpoint = MonitorCheckpoint(
            dir_path=ckpt_dir,
            max_saves=3,
            monitor='val_loss',
            optimizer_state=optimizer_state,
        )
        checkpoint.attach(test_engine)
        checkpoint.start(test_engine.state)

        # Losses 29, 27, ..., 1 on even indices, interleaved with a
        # non-improving spike on every odd index (30 epochs in total).
        spike = 100
        losses = [spike if idx % 2 else loss
                  for idx, loss in enumerate(reversed(range(30)))]

        for epoch, loss in enumerate(losses, start=1):
            checkpoint_step_epoch(checkpoint, test_engine, epoch, loss)
            if loss == spike:
                # A worse value must not produce a checkpoint file.
                skipped = ckpt_dir / f'model-{epoch:03d}-{loss:.6f}.pth'
                assert not skipped.exists()
            else:
                assert check_checkpoint(ckpt_dir, test_engine, epoch, loss,
                                        optimizer_state=optimizer_state)

        # Only max_saves (3) checkpoint files are left on disk.
        assert len(list(ckpt_dir.glob('*.pth'))) == 3
예제 #2
0
    def test_checkpoint_exceptions(self, tmpdir, test_engine, recwarn):
        """Check argument validation and warnings of Checkpoint classes.

        Covers the following cases:

        * a negative ``max_saves`` raises ``ValueError``;
        * reusing an existing directory emits exactly one warning with a
          fixed message;
        * ``monitor='qwerty'`` raises ``ValueError`` at construction;
        * a ``train_loss`` MonitorCheckpoint raises ``ValueError`` from
          ``epoch_complete``.
        """
        ckpt_dir = Path(tmpdir.join("path/to/exception_checkpoints/"))

        # A negative max_saves is rejected.
        with pytest.raises(ValueError):
            Checkpoint(dir_path=ckpt_dir, max_saves=-3)

        # Creating a Checkpoint in an already existing directory warns once.
        ckpt_dir.mkdir(parents=True)
        Checkpoint(dir_path=ckpt_dir)
        assert len(recwarn) == 1
        expected_message = f"Directory '{ckpt_dir}' already exists"
        assert str(recwarn.pop().message) == expected_message

        # The monitor name 'qwerty' is rejected when constructing.
        with pytest.raises(ValueError):
            MonitorCheckpoint(dir_path=ckpt_dir, monitor='qwerty')

        # NOTE(review): presumably raises because test_engine.state carries
        # no usable 'train_loss' value at epoch end — confirm against the
        # MonitorCheckpoint implementation.
        monitor_ckpt = MonitorCheckpoint(dir_path=ckpt_dir,
                                         monitor='train_loss')
        monitor_ckpt.attach(test_engine)
        with pytest.raises(ValueError):
            monitor_ckpt.epoch_complete(test_engine.state)