def test_checkpoint_repeated_strategy_tmpdir(tmpdir):
    """Validate checkpointing when ``ModelCheckpoint`` is provided via the
    ``callbacks`` list and the trainer is resumed repeatedly.

    After the initial fit and after every resume cycle, ``tmpdir`` must hold
    exactly ``epoch=00.ckpt`` plus ``lightning_logs``, and each cycle must add
    one new ``version_N`` directory under ``lightning_logs``.
    """
    # Enable Lightning's internal debug bookkeeping for this run.
    # NOTE(review): mutates the process environment without cleanup — could
    # leak into other tests; confirm whether a fixture should restore it.
    os.environ['PL_DEV_DEBUG'] = '1'
    # NOTE(review): uses the ``filepath`` argument here while the sibling test
    # uses ``dirpath``/``filename`` — presumably deliberate coverage of the
    # older API; confirm before unifying.
    checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                          filepath=os.path.join(tmpdir, "{epoch:02d}"))

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            # Minimal validation step that reports the monitored "val_loss".
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    # Disable the aggregation hooks so only ``validation_step`` runs.
    model.validation_step_end = None
    model.validation_epoch_end = None
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      limit_train_batches=2,
                      limit_val_batches=2,
                      limit_test_batches=2,
                      callbacks=[checkpoint_callback])
    trainer.fit(model)
    # One checkpoint file and the logs directory, nothing else.
    assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
    assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])

    def get_last_checkpoint():
        # Map epoch number -> checkpoint path, then return the highest epoch.
        ckpts = os.listdir(tmpdir)
        ckpts_map = {
            int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x)
            for x in ckpts if "epoch" in x
        }
        num_ckpts = len(ckpts_map) - 1
        return ckpts_map[num_ckpts]

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(default_root_dir=tmpdir,
                             max_epochs=1,
                             limit_train_batches=2,
                             limit_val_batches=2,
                             limit_test_batches=2,
                             resume_from_checkpoint=chk)
        trainer.fit(model)
        trainer.test(model)
        # The checkpoint file is reused; only log versions accumulate.
        assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
        assert sorted(os.listdir(path_to_lightning_logs)) == sorted(
            [f'version_{i}' for i in range(idx + 1)])
def test_checkpoint_repeated_strategy(enable_pl_optimizer, tmpdir):
    """Validate checkpointing when ``ModelCheckpoint`` is provided via the
    ``callbacks`` list and the trainer is resumed repeatedly.

    After the initial fit and after every resume cycle, ``tmpdir`` must
    contain exactly the single ``epoch=00.ckpt`` file.
    """
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}")

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            # Minimal validation step that reports the monitored "val_loss".
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    # Disable the aggregation hooks so only ``validation_step`` runs.
    model.validation_step_end = None
    model.validation_epoch_end = None
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint_callback],
        enable_pl_optimizer=enable_pl_optimizer,
    )
    trainer.fit(model)
    assert os.listdir(tmpdir) == ['epoch=00.ckpt']

    def get_last_checkpoint():
        # Map epoch number -> checkpoint path, then return the highest epoch.
        ckpts = os.listdir(tmpdir)
        ckpts_map = {
            int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x)
            for x in ckpts if "epoch" in x
        }
        num_ckpts = len(ckpts_map) - 1
        return ckpts_map[num_ckpts]

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(max_epochs=1,
                             limit_train_batches=2,
                             limit_val_batches=2,
                             limit_test_batches=2,
                             resume_from_checkpoint=chk,
                             enable_pl_optimizer=enable_pl_optimizer)
        trainer.fit(model)
        trainer.test(model)
        # Fix: compare the listing directly instead of its str() repr — the
        # stringified comparison is fragile and hides the actual directory
        # contents in the failure message. Truth condition is unchanged.
        assert os.listdir(tmpdir) == ['epoch=00.ckpt']
def test_checkpoint_repeated_strategy_extended(tmpdir):
    """Validate that checkpointing can be invoked several times without the
    trainer's ``global_step`` increasing when no training actually runs.
    """
    # Enable Lightning's internal debug bookkeeping for this run.
    os.environ['PL_DEV_DEBUG'] = '1'

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            # Minimal validation step that reports a "val_loss".
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    # Disable the aggregation hooks so only ``validation_step`` runs.
    model.validation_step_end = None
    model.validation_epoch_end = None
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
    )
    # Nothing has run yet.
    assert trainer.checkpoint_connector.has_trained is not True
    assert trainer.current_epoch == 0
    trainer.fit(model)
    assert trainer.checkpoint_connector.has_trained is True
    # 1 epoch with limit_train_batches=2 -> 2 global steps.
    assert trainer.global_step == 2
    assert trainer.current_epoch == 0
    trainer.test(model)
    assert trainer.current_epoch == 0
    assert str(os.listdir(osp.join(tmpdir, 'lightning_logs'))) == "['version_0']"

    def get_last_checkpoint():
        # Return the newest checkpoint of the newest lightning_logs version.
        logs_dir = osp.join(tmpdir, 'lightning_logs')
        versions = os.listdir(logs_dir)
        versions.sort()
        last_version = versions[-1]
        ckpt_dir = osp.join(logs_dir, last_version, "checkpoints")
        ckpts = os.listdir(ckpt_dir)
        ckpts.sort()
        return osp.join(ckpt_dir, ckpts[-1])

    def assert_checkpoint_content():
        # The saved state must reflect the single completed epoch.
        chk = pl_load(get_last_checkpoint())
        assert chk["epoch"] == 1
        assert chk["global_step"] == 2

    assert_checkpoint_content()

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        assert_checkpoint_content()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(default_root_dir=tmpdir,
                             max_epochs=1,
                             limit_train_batches=2,
                             limit_val_batches=2,
                             limit_test_batches=2,
                             resume_from_checkpoint=chk)
        assert trainer.checkpoint_connector.has_trained is not True
        assert trainer.global_step == 0
        trainer.test(model)
        # global_step is 2 here without new training — presumably restored
        # from the resumed checkpoint; confirm against the connector logic.
        assert trainer.global_step == 2
        trainer.fit(model)
        # max_epochs was already reached, so fit performs no new steps.
        assert trainer.global_step == 2
        assert trainer.checkpoint_connector.has_trained is not True
        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
        # Each cycle adds exactly one new version directory.
        assert sorted(os.listdir(lightning_logs_path)) == [
            f"version_{i}" for i in range(idx + 1)
        ]