def test_with_time(self):
    """
    Checks that ``keep_serialized_model_every_num_seconds`` causes a
    checkpoint to survive when enough time has passed between epochs.

    Thirty epochs are saved with a keep-limit of 10 and a time threshold
    of 1 second.  Before epochs 5, 18 and 26 the test sleeps for 2 seconds,
    so those epochs are expected to be retained in addition to the ten
    most recent ones.
    """
    epochs_total = 30
    keep_count = 10
    slow_epochs = [5, 18, 26]

    # Expected survivors: the last `keep_count` epochs plus every epoch
    # that was preceded by a pause, in ascending order without duplicates.
    most_recent = range(epochs_total - keep_count, epochs_total)
    expected = sorted(set(most_recent).union(slow_epochs))

    checkpointer = Checkpointer(
        serialization_dir=self.TEST_DIR,
        num_serialized_models_to_keep=keep_count,
        keep_serialized_model_every_num_seconds=1,
    )

    for epoch in range(epochs_total):
        if epoch in slow_epochs:
            # Sleep longer than the 1-second threshold configured above.
            time.sleep(2)
        fake_trainer = FakeTrainer(
            model_state={"epoch": epoch},
            training_states={"epoch": epoch},
        )
        checkpointer.save_checkpoint(
            epoch=epoch,
            trainer=fake_trainer,
            is_best_so_far=False,
        )

    # NOTE(review): presumably returns the epochs found for the saved model
    # and training-state files -- confirm against retrieve_and_delete_saved.
    saved_models, saved_training = self.retrieve_and_delete_saved()
    assert saved_models == saved_training == expected
def test_keep_zero(self):
    """
    Checks that with ``num_serialized_models_to_keep=0`` the epoch-1 model
    and training-state files are absent from the serialization directory
    after ten checkpoints have been saved (each flagged as best so far).
    """
    checkpointer = Checkpointer(
        serialization_dir=self.TEST_DIR,
        num_serialized_models_to_keep=0,
    )

    for epoch in range(10):
        fake_trainer = FakeTrainer(
            model_state={"epoch": epoch},
            training_states={"epoch": epoch},
        )
        checkpointer.save_checkpoint(
            epoch=epoch,
            trainer=fake_trainer,
            is_best_so_far=True,
        )

    remaining = os.listdir(self.TEST_DIR)
    # Neither per-epoch file for epoch 1 may be left behind.
    for stale_name in ("model_state_epoch_1.th", "training_state_epoch_1.th"):
        assert stale_name not in remaining
def test_default(self):
    """
    Checks that a ``Checkpointer`` built with only a serialization
    directory keeps just the last 2 checkpoints out of 30 saved epochs.
    """
    epochs_total = 30
    expected_kept = 2
    # The two most recent epochs, in ascending order.
    expected = list(range(epochs_total - expected_kept, epochs_total))

    checkpointer = Checkpointer(serialization_dir=self.TEST_DIR)

    for epoch in range(epochs_total):
        fake_trainer = FakeTrainer(
            model_state={"epoch": epoch},
            training_states={"epoch": epoch},
        )
        checkpointer.save_checkpoint(
            epoch=epoch,
            trainer=fake_trainer,
            is_best_so_far=False,
        )

    # NOTE(review): presumably returns the epochs found for the saved model
    # and training-state files -- confirm against retrieve_and_delete_saved.
    saved_models, saved_training = self.retrieve_and_delete_saved()
    assert saved_models == saved_training == expected