Example #1
    def test_finalize_and_resume_file(self):
        with mock_env_with_temp() as d:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            # Finalizing the checkpoint writes "simple_final.pth" into the temp dir.
            checkpoint.finalize()
            original = deepcopy(self.trainer.model)
            pth_path = os.path.join(d, "simple_final.pth")
            self.assertTrue(PathManager.exists(pth_path))

            # Run another pass so the live model drifts away from the saved one.
            self._do_a_pass()

            # Snapshot the drifted model and the current optimizer, then point
            # the config at the finalized file and resume from it.
            after_a_pass = deepcopy(self.trainer.model)
            original_optimizer = deepcopy(self.trainer.optimizer)
            self.trainer.config.checkpoint.resume_file = pth_path

            with contextlib.redirect_stdout(StringIO()):
                checkpoint.load_state_dict()
            # Resuming restores the model saved at finalize time: it matches the
            # "original" copy but not the post-second-pass copy, while the
            # optimizer still matches the copy taken just before resuming.
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original.state_dict()))
            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    after_a_pass.state_dict()))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
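
The test relies on helpers such as compare_state_dicts and mock_env_with_temp that are defined elsewhere in the test suite. The sketch below is only an assumption of what they could look like: the names come from the example above, but the bodies are illustrative, not the project's actual implementations.

import contextlib
import tempfile

import torch


def compare_state_dicts(a, b):
    # Hypothetical sketch: two state dicts match when they share the same keys
    # and every corresponding tensor is element-wise equal.
    if a.keys() != b.keys():
        return False
    return all(torch.equal(a[key], b[key]) for key in a)


@contextlib.contextmanager
def mock_env_with_temp():
    # Hypothetical sketch: yields a temporary directory that the test uses as
    # the checkpoint save location ("d" in the test above).
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield tmp_dir
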
Example #2
class CheckpointCallback(Callback):
    """Callback for executing different checkpoint requirements."""

    def __init__(self, config, trainer):
        """
        Args:
            config (multimodelity_typings.DictConfig): Config for the callback
            trainer (Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        self._checkpoint = Checkpoint(trainer)
        self.checkpoint_interval = self.config.training.checkpoint_interval

    @property
    def checkpoint(self):
        return self._checkpoint

    def on_init_start(self, **kwargs):
        # Load any existing checkpoint state before training starts.
        self._checkpoint.load_state_dict()

    def on_update_end(self, **kwargs):
        # Save a checkpoint every `checkpoint_interval` updates.
        if self.trainer.num_updates % self.checkpoint_interval == 0:
            logger.info("Checkpoint time. Saving a checkpoint.")
            # Consolidate the state dict of sharded optimizers
            consolidate_optim_state_dict(self.trainer.optimizer)
            self._checkpoint.save(
                self.trainer.num_updates,
                self.trainer.current_iteration,
                update_best=False,
            )

    def on_train_end(self, **kwargs):
        # Restore the best saved checkpoint (if any), then write the final model.
        self._checkpoint.restore()
        self._checkpoint.finalize()
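
CheckpointCallback is driven by the trainer, which calls the hook methods at the matching points of the training loop. The sketch below shows that wiring with a toy driver; everything except the on_* hook names is an illustrative assumption, not the framework's real trainer.

class ToyTrainer:
    # Hypothetical driver: dispatches lifecycle hooks to registered callbacks.
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.num_updates = 0
        self.current_iteration = 0

    def _dispatch(self, hook_name, **kwargs):
        # Invoke the named hook on every registered callback.
        for callback in self.callbacks:
            getattr(callback, hook_name)(**kwargs)

    def train(self, max_updates):
        self._dispatch("on_init_start")      # resume from a checkpoint if configured
        while self.num_updates < max_updates:
            self.num_updates += 1
            self.current_iteration += 1
            # ... forward / backward / optimizer step would happen here ...
            self._dispatch("on_update_end")  # periodic checkpoint save
        self._dispatch("on_train_end")       # restore best checkpoint, save final model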