def test_monitor_checkpoint(self, tmpdir, test_engine, optimizer_state):
    path = Path(tmpdir.join("path/to/monitor_checkpoints/"))
    checkpoint = MonitorCheckpoint(dir_path=path, max_saves=3,
                                   monitor='val_loss',
                                   optimizer_state=optimizer_state)
    checkpoint.attach(test_engine)
    checkpoint.start(test_engine.state)

    # Alternate improving and non-improving values of the monitored metric:
    # even indices decrease from 29 down to 1, odd indices are stuck at 100.
    decreasing_seq = list(range(30))[::-1]
    for i in range(1, len(decreasing_seq), 2):
        decreasing_seq[i] = 100

    for epoch, val_loss in enumerate(decreasing_seq, 1):
        # checkpoint_step_epoch and check_checkpoint are helpers from this
        # test module (a sketch of the former follows below these tests).
        checkpoint_step_epoch(checkpoint, test_engine, epoch, val_loss)
        expected_path = path / f'model-{epoch:03d}-{val_loss:.6f}.pth'
        if val_loss != 100:
            # The metric improved, so a checkpoint must have been saved.
            assert check_checkpoint(path, test_engine, epoch, val_loss,
                                    optimizer_state=optimizer_state)
        else:
            # The metric did not improve, so no checkpoint for this epoch.
            assert not expected_path.exists()

    # With max_saves=3, only three checkpoint files should remain on disk.
    assert len(list(path.glob('*.pth'))) == 3
def test_checkpoint_exceptions(self, tmpdir, test_engine, recwarn):
    path = Path(tmpdir.join("path/to/exception_checkpoints/"))

    # A negative max_saves is rejected at construction time.
    with pytest.raises(ValueError):
        Checkpoint(dir_path=path, max_saves=-3)

    # Pointing a Checkpoint at an existing directory only emits a warning.
    path.mkdir(parents=True)
    Checkpoint(dir_path=path)
    assert len(recwarn) == 1
    warn = recwarn.pop()
    assert f"Directory '{path}' already exists" == str(warn.message)

    # An invalid monitor name is rejected at construction time.
    with pytest.raises(ValueError):
        MonitorCheckpoint(dir_path=path, monitor='qwerty')

    # epoch_complete raises ValueError when the monitored metric
    # is unavailable in the engine state.
    checkpoint = MonitorCheckpoint(dir_path=path, monitor='train_loss')
    checkpoint.attach(test_engine)
    with pytest.raises(ValueError):
        checkpoint.epoch_complete(test_engine.state)
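# The tests above use two helpers defined elsewhere in this test module:
# checkpoint_step_epoch(checkpoint, engine, epoch, val_loss) and
# check_checkpoint(path, engine, epoch, val_loss, optimizer_state=...).
# Their real implementations are not shown here; the following is only a
# minimal sketch of what checkpoint_step_epoch is assumed to do, inferred from
# how it is called above. The state fields it touches (epoch, metrics) are
# assumptions, not the library's confirmed API.
def checkpoint_step_epoch_sketch(checkpoint, engine, epoch, val_loss):
    # Simulate the end of a training epoch: record the epoch number and the
    # monitored metric on the engine state, then fire the callback hook that
    # MonitorCheckpoint uses to decide whether to save a checkpoint.
    engine.state.epoch = epoch
    engine.state.metrics = {'val_loss': val_loss}
    checkpoint.epoch_complete(engine.state)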