Example #1
    def test_keep_zero(self):
        """
        Tests that with keep_most_recent_by_count=0 no checkpoint files are
        left on disk.
        """
        checkpointer = Checkpointer(serialization_dir=self.TEST_DIR,
                                    keep_most_recent_by_count=0)
        for epochs_completed in range(5):
            state = {
                "epochs_completed": epochs_completed,
                "batches_in_epoch_completed": 0
            }
            checkpointer.maybe_save_checkpoint(
                FakeTrainer(model_state=state, training_state=state),
                epochs_completed, 0)
        files = os.listdir(self.TEST_DIR)
        assert not any("model_state_" in x for x in files)
        assert not any("training_state_" in x for x in files)
Example #2
    def test_with_time(self):
        """
        Tests time-based checkpointing: with save_completed_epochs=False and
        save_every_num_seconds=1, the checkpoints left on disk are exactly the
        ones written right after the artificial pauses (at most
        keep_most_recent_by_count=3 of them).
        """
        num_epochs = 30
        pauses = [5, 18, 26]
        target = [(e, 0) for e in pauses]
        checkpointer = Checkpointer(
            serialization_dir=self.TEST_DIR,
            save_completed_epochs=False,
            save_every_num_seconds=1,
            keep_most_recent_by_count=3,
        )
        for e in range(num_epochs):
            if e in pauses:
                time.sleep(2)
            state = {"epochs_completed": e, "batches_in_epoch_completed": 0}
            checkpointer.maybe_save_checkpoint(
                trainer=FakeTrainer(model_state=state, training_state=state),
                num_epochs_completed=e,
                num_batches_in_epoch_completed=0,
            )
        models, training = self.retrieve_and_delete_saved()
        assert models == training == target
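
retrieve_and_delete_saved is another helper on the test class that is not shown in these snippets. Judging from the assertions, it returns the (epochs_completed, batches_in_epoch_completed) pairs of the model-state and training-state checkpoints that were actually written, and cleans the files up afterwards. The sketch below assumes checkpoint files are named model_state_e{epoch}_b{batch}.* and training_state_e{epoch}_b{batch}.*; only the model_state_/training_state_ prefixes are confirmed by Example #1, the rest of the naming is an assumption.

    # Sketch of the helper as a method of the same test class; it relies on the
    # standard-library os and re modules being imported at module level.
    def retrieve_and_delete_saved(self):
        # The _e{epoch}_b{batch} part of this pattern is an assumption.
        pattern = re.compile(r"(model|training)_state_e(\d+)_b(\d+)")
        models, training = [], []
        for name in os.listdir(self.TEST_DIR):
            match = pattern.match(name)
            if match is None:
                continue
            pair = (int(match.group(2)), int(match.group(3)))
            (models if match.group(1) == "model" else training).append(pair)
            # Delete the file so later tests start from an empty directory.
            os.remove(os.path.join(self.TEST_DIR, name))
        return sorted(models), sorted(training)
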
Example #3
    def test_default(self):
        """
        Tests that the default behavior keeps just the last 2 checkpoints.
        """
        default_num_to_keep = 2
        num_epochs = 5
        target = [(e, 0)
                  for e in range(num_epochs - default_num_to_keep, num_epochs)]

        checkpointer = Checkpointer(serialization_dir=self.TEST_DIR)
        for epochs_completed in range(num_epochs):
            for batches_completed in [0, 5, 10]:
                state = {
                    "epochs_completed": epochs_completed,
                    "batches_in_epoch_completed": batches_completed,
                }
                checkpointer.maybe_save_checkpoint(
                    FakeTrainer(model_state=state, training_state=state),
                    epochs_completed,
                    batches_completed,
                )
        models, training = self.retrieve_and_delete_saved()
        assert models == training == target
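
All three examples are methods on a test class that supplies self.TEST_DIR and the helpers sketched above. A minimal scaffold that would let the snippets run under pytest might look like the following; the tempfile-based setup and the AllenNLP import path are assumptions (the real suite presumably uses its own test base class), so adjust them to match wherever Checkpointer actually comes from.

    import os
    import re
    import shutil
    import tempfile
    import time

    # Assumption: the Checkpointer under test is AllenNLP's.
    from allennlp.training.checkpointer import Checkpointer


    class TestCheckpointer:
        def setup_method(self):
            # Fresh serialization directory for every test.
            self.TEST_DIR = tempfile.mkdtemp(prefix="checkpointer_test_")

        def teardown_method(self):
            shutil.rmtree(self.TEST_DIR, ignore_errors=True)

        # test_keep_zero, test_with_time, test_default, and the
        # retrieve_and_delete_saved sketch above go here as methods;
        # FakeTrainer sits alongside the class at module level.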