def test_with_time(self):
    """
    Verify that ``keep_serialized_model_every_num_seconds`` retains extra checkpoints.

    The test saves 30 epochs while keeping at most 10 rotating checkpoints. Before a
    few chosen epochs it sleeps for longer than the 1-second threshold. The surviving
    checkpoints are expected to be the final 10 epochs plus each epoch that followed
    a pause.
    """
    keep_count = 10
    total_epochs = 30
    slow_epochs = [5, 18, 26]
    # Union of the rotating window (last ``keep_count`` epochs) and the paused epochs.
    expected = sorted(set(range(total_epochs - keep_count, total_epochs)) | set(slow_epochs))

    checkpointer = Checkpointer(
        serialization_dir=self.TEST_DIR,
        num_serialized_models_to_keep=keep_count,
        keep_serialized_model_every_num_seconds=1,
    )

    for epoch in range(total_epochs):
        if epoch in slow_epochs:
            # Pause longer than the 1-second threshold configured above.
            time.sleep(2)
        fake_trainer = FakeTrainer(model_state={"epoch": epoch}, training_states={"epoch": epoch})
        checkpointer.save_checkpoint(epoch=epoch, trainer=fake_trainer, is_best_so_far=False)

    saved_models, saved_training = self.retrieve_and_delete_saved()
    assert saved_models == saved_training == expected
 def test_keep_zero(self):
     """
     Tests that ``num_serialized_models_to_keep=0`` leaves no per-epoch checkpoints behind.

     Ten epochs are saved, each with ``is_best_so_far=True``. After all saves, neither
     the model-state file nor the training-state file of any epoch may remain in the
     serialization directory.
     """
     num_epochs = 10
     checkpointer = Checkpointer(
         serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=0
     )
     for e in range(num_epochs):
         checkpointer.save_checkpoint(
             epoch=e,
             trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}),
             is_best_so_far=True,
         )
     files = set(os.listdir(self.TEST_DIR))
     # Check every epoch rather than a single one. Otherwise a regression that kept
     # some checkpoint (e.g. the most recent) would go unnoticed.
     for e in range(num_epochs):
         assert "model_state_epoch_{}.th".format(e) not in files
         assert "training_state_epoch_{}.th".format(e) not in files
    def test_default(self):
        """
        Verify the default retention policy of ``Checkpointer``.

        When ``num_serialized_models_to_keep`` is not passed, only the checkpoints of
        the two most recent epochs should survive.
        """
        total_epochs = 30
        keep_count = 2  # number of checkpoints the default policy is expected to keep

        checkpointer = Checkpointer(serialization_dir=self.TEST_DIR)

        for epoch in range(total_epochs):
            fake_trainer = FakeTrainer(model_state={"epoch": epoch}, training_states={"epoch": epoch})
            checkpointer.save_checkpoint(epoch=epoch, trainer=fake_trainer, is_best_so_far=False)

        saved_models, saved_training = self.retrieve_and_delete_saved()
        expected = list(range(total_epochs - keep_count, total_epochs))
        assert saved_models == saved_training == expected