Example #1
    def test_checkpoint_scaler_loading(self):
        with mock_env_with_temp():
            original_scaler = deepcopy(self.trainer.scaler)

            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)

            self._do_a_fp16_pass()
            checkpoint.save(1000)
            self.trainer.config.checkpoint.resume = True
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = True
            self.trainer.config.checkpoint.reset.counts = True
            self.trainer.config.checkpoint.reset.fp16_scaler = True

            # Replace with a fresh default GradScaler; since reset.fp16_scaler
            # is True, load_state_dict() skips the saved scaler state, so the
            # scaler should still match the original default one
            self.trainer.scaler = torch.cuda.amp.GradScaler()
            checkpoint.load_state_dict()
            self.assertTrue(
                compare_state_dicts(self.trainer.scaler.state_dict(),
                                    original_scaler.state_dict()))

            self._do_a_fp16_pass()
            checkpoint.save(2000)
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = True
            self.trainer.config.checkpoint.reset.counts = True
            self.trainer.config.checkpoint.reset.fp16_scaler = False

            # Reset to a fresh default scaler again; with reset.fp16_scaler
            # False, the scaler state saved after the fp16 pass is restored
            # and should no longer match the original default
            self.trainer.scaler = torch.cuda.amp.GradScaler()
            checkpoint.load_state_dict()
            self.assertFalse(
                compare_state_dicts(self.trainer.scaler.state_dict(),
                                    original_scaler.state_dict()))
Example #2
    def __init__(self, config, trainer):
        """
        Args:
            config (multimodelity_typings.DictConfig): Config for the callback.
            trainer (Type[BaseTrainer]): Trainer object.
        """
        super().__init__(config, trainer)

        self._checkpoint = Checkpoint(trainer)
        self.checkpoint_interval = self.config.training.checkpoint_interval
Example #3
class CheckpointCallback(Callback):
    """Callback for executing different checkpoint requirements.
    """
    def __init__(self, config, trainer):
        """
        Args:
            config (multimodelity_typings.DictConfig): Config for the callback.
            trainer (Type[BaseTrainer]): Trainer object.
        """
        super().__init__(config, trainer)

        self._checkpoint = Checkpoint(trainer)
        self.checkpoint_interval = self.config.training.checkpoint_interval

    @property
    def checkpoint(self):
        return self._checkpoint

    def on_init_start(self, **kwargs):
        self._checkpoint.load_state_dict()

    def on_update_end(self, **kwargs):
        if self.trainer.num_updates % self.checkpoint_interval == 0:
            logger.info("Checkpoint time. Saving a checkpoint.")
            # Consolidate the state dict of sharded optimizers
            consolidate_optim_state_dict(self.trainer.optimizer)
            self._checkpoint.save(
                self.trainer.num_updates,
                self.trainer.current_iteration,
                update_best=False,
            )

    def on_train_end(self, **kwargs):
        self._checkpoint.restore()
        self._checkpoint.finalize()
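
A minimal sketch of how these hooks might be driven by a training loop; the
trainer.step() call and max_updates name are illustrative stand-ins, not part
of the multimodelity API shown above:

    callback = CheckpointCallback(config, trainer)

    callback.on_init_start()      # resume from a checkpoint if configured
    for _ in range(max_updates):  # max_updates: hypothetical update budget
        trainer.step()            # hypothetical single training update
        callback.on_update_end()  # saves every `checkpoint_interval` updates
    callback.on_train_end()       # restore the best checkpoint, then finalize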
Example #4
    def test_finalize_and_resume_file(self):
        with mock_env_with_temp() as d:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            checkpoint.finalize()
            original = deepcopy(self.trainer.model)
            pth_path = os.path.join(d, "simple_final.pth")
            self.assertTrue(PathManager.exists(pth_path))

            self._do_a_pass()

            after_a_pass = deepcopy(self.trainer.model)
            original_optimizer = deepcopy(self.trainer.optimizer)
            self.trainer.config.checkpoint.resume_file = pth_path

            with contextlib.redirect_stdout(StringIO()):
                checkpoint.load_state_dict()
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original.state_dict()))
            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    after_a_pass.state_dict()))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
Example #5
    def test_pretrained_load(self):
        with mock_env_with_temp() as d:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            original_model = deepcopy(self.trainer.model)
            # Prepare the mocked zoo return value for the zoo test below
            ret_load_pretrained_zoo = {
                "config": self.config.model_config,
                "checkpoint": deepcopy(self.trainer.model.state_dict()),
                "full_config": self.config,
            }

            checkpoint.save(2000)
            self.trainer.config.checkpoint.resume_file = os.path.join(
                d, "current.ckpt")
            self.trainer.config.checkpoint.resume_pretrained = True
            self.trainer.model = OnlyBase()
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(
                    self.trainer.model.base_test.state_dict(),
                    original_model.base.state_dict(),
                ))

            with patch(
                    "multimodelity.utils.checkpoint.load_pretrained_model",
                    return_value=ret_load_pretrained_zoo,
            ):
                self.trainer.config.checkpoint.resume_zoo = "random"
                self.trainer.config.checkpoint.resume_file = None
                self.trainer.model = OnlyBase()
                checkpoint.load_state_dict()

                self.assertTrue(
                    compare_state_dicts(
                        self.trainer.model.base_test.state_dict(),
                        original_model.base.state_dict(),
                    ))
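
The base_test vs. base assertions above suggest that resume_pretrained remaps
saved module prefixes when loading into a new model. A rough sketch of such a
remapping; the function name and mapping are assumptions, not the actual
multimodelity implementation:

    def remap_pretrained_keys(state_dict, mapping):
        # With mapping = {"base": "base_test"}, weights saved under "base.*"
        # end up under "base_test.*" in the returned dict.
        remapped = {}
        for key, value in state_dict.items():
            prefix, _, rest = key.partition(".")
            if prefix in mapping:
                remapped[f"{mapping[prefix]}.{rest}"] = value
        return remapped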
Example #6
    def test_zoo_load(self):
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()

            original_model = deepcopy(self.trainer.model)
            ret_load_pretrained_zoo = {
                "config": self.config.model_config,
                "checkpoint": deepcopy(self.trainer.model.state_dict()),
                "full_config": self.config,
            }

            self._do_a_pass()

            with patch(
                    "multimodelity.utils.checkpoint.load_pretrained_model",
                    return_value=ret_load_pretrained_zoo,
            ):
                self.trainer.config.checkpoint.resume_zoo = "random"
                with contextlib.redirect_stdout(StringIO()):
                    checkpoint.load_state_dict()
                self.assertTrue(
                    compare_state_dicts(self.trainer.model.state_dict(),
                                        original_model.state_dict()))

                # Now, test zoo override
                self.trainer.config.checkpoint.zoo_config_override = True
                SimpleModule.from_pretrained = Mock(
                    return_value=deepcopy(original_model))
                registry.register_model("simple")(SimpleModule)
                with contextlib.redirect_stdout(StringIO()):
                    checkpoint.load_state_dict()
                self.assertTrue(
                    compare_state_dicts(self.trainer.model.state_dict(),
                                        original_model.state_dict()))
Example #7
    def test_max_to_keep(self):
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)

            ckpt_paths = []
            for indx in [2000, 3000, 4000, 5000, 6000]:
                self._do_a_pass()
                checkpoint.save(indx, update_best=False)

                ckpt_paths.append(
                    os.path.join(checkpoint.models_foldername,
                                 "model_%d.ckpt" % indx))
                self.assertTrue(os.path.exists(ckpt_paths[-1]))

            for indx, u in enumerate([7000, 8000, 9000, 10000, 11000]):
                self._do_a_pass()
                checkpoint.save(u, update_best=False)

                ckpt_paths.append(
                    os.path.join(checkpoint.models_foldername,
                                 "model_%d.ckpt" % u))
                self.assertTrue(os.path.exists(ckpt_paths[-1]))
                self.assertFalse(os.path.exists(ckpt_paths[indx]))
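
The assertions above (saving the k-th checkpoint deletes the (k-5)-th file)
imply a rolling window of five checkpoints. A minimal sketch of that pruning
behavior, assuming a max_to_keep setting of 5; the class and method names are
illustrative, not the actual Checkpoint internals:

    import os
    from collections import deque

    class RollingCheckpointWindow:
        def __init__(self, max_to_keep=5):
            self.max_to_keep = max_to_keep
            self.saved_paths = deque()

        def register_save(self, path):
            # Track a freshly written checkpoint and drop the oldest file
            # once the window is full.
            self.saved_paths.append(path)
            if len(self.saved_paths) > self.max_to_keep:
                os.remove(self.saved_paths.popleft())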
Example #8
    def test_resets(self):
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()

            original_optimizer = deepcopy(self.trainer.optimizer)
            original_model = deepcopy(self.trainer.model)
            original_scaler = deepcopy(self.trainer.scaler)

            self.trainer.current_epoch = 3
            checkpoint.save(2000, update_best=True)
            self.trainer.current_epoch = 4
            # Test reset all
            self.trainer.config.checkpoint.resume = True
            self.trainer.config.checkpoint.reset.all = True
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))

            self.assertTrue(
                compare_state_dicts(self.trainer.scaler.state_dict(),
                                    original_scaler.state_dict()))

            self.assertEqual(self.trainer.num_updates, 0)
            self.assertEqual(self.trainer.current_iteration, 0)
            self.assertEqual(self.trainer.current_epoch, 4)

            # Test reset_optimizer
            self._init_early_stopping(checkpoint)
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = True
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))

            self.assertEqual(self.trainer.num_updates, 2000)
            self.assertEqual(self.trainer.current_iteration, 2000)
            self.assertEqual(self.trainer.current_epoch, 3)

            self._init_early_stopping(checkpoint)
            # Test reset_counts
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = False
            self.trainer.config.checkpoint.reset.counts = True
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
            self.assertEqual(self.trainer.num_updates, 0)
            self.assertEqual(self.trainer.current_iteration, 0)
            self.assertEqual(self.trainer.current_epoch, 2)

            # Test with resume_best
            self._do_a_pass()
            checkpoint.save(3000)
            self._init_early_stopping(checkpoint)
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.resume_best = True
            self.trainer.config.checkpoint.reset.optimizer = True
            self.trainer.config.checkpoint.reset.counts = False
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))

            self.assertEqual(self.trainer.num_updates, 1000)
            self.assertEqual(self.trainer.current_iteration, 1000)
            self.assertEqual(self.trainer.current_epoch, 3)
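
Collecting the flags exercised here and in Examples #1, #4, #5, #6, and #10,
the checkpoint config has roughly the following shape; the keys are exactly
those set in the tests, while the default values shown are assumptions:

    checkpoint_config = {
        "resume": False,             # resume from current.ckpt
        "resume_best": False,        # prefer best.ckpt over current.ckpt
        "resume_file": None,         # explicit path to a checkpoint file
        "resume_zoo": None,          # model-zoo key for pretrained weights
        "resume_pretrained": False,  # remap pretrained keys while loading
        "zoo_config_override": False,
        "reset": {
            "all": False,          # keep model weights, reset everything else
            "optimizer": False,    # do not restore optimizer state
            "counts": False,       # zero num_updates and current_iteration
            "fp16_scaler": False,  # do not restore the GradScaler state
        },
    }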
Example #9
    def test_finalize_and_restore_from_it(self):
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            original_model = deepcopy(self.trainer.model)
            self._do_a_pass()
            model_1500 = deepcopy(self.trainer.model)
            checkpoint.save(1500)

            swap = self.trainer.model
            self.trainer.model = original_model
            checkpoint.restore()
            # First test without best.ckpt
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))
            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    model_1500.state_dict()))

            self.trainer.model = swap

            self._do_a_pass()
            model_2000 = deepcopy(self.trainer.model)
            checkpoint.save(2000, update_best=True)

            self._do_a_pass()
            model_2500 = deepcopy(self.trainer.model)
            checkpoint.save(2500)

            checkpoint.restore()

            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))
            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    model_1500.state_dict()))
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    model_2000.state_dict()))
            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    model_2500.state_dict()))
Example #10
    def test_save_and_load_state_dict(self):
        with mock_env_with_temp() as d:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            # Test normal case
            checkpoint.save(1500)

            self.assertTrue(
                PathManager.exists(os.path.join(d, "models",
                                                "model_1500.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d,
                                                            "current.ckpt")))
            self.assertFalse(PathManager.exists(os.path.join(d, "best.ckpt")))
            os.remove(os.path.join(d, "models", "model_1500.ckpt"))
            os.remove(os.path.join(d, "current.ckpt"))

            best_model = deepcopy(self.trainer.model)
            best_optimizer = deepcopy(self.trainer.optimizer)
            # Test with update_best
            checkpoint.save(2000, update_best=True)

            self.assertTrue(
                PathManager.exists(os.path.join(d, "models",
                                                "model_2000.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d, "best.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d,
                                                            "current.ckpt")))

            self._do_a_pass()
            checkpoint.save(2500)

            # Test resume
            self.trainer.config.checkpoint.resume = True

            current_model = deepcopy(self.trainer.model)
            current_optimizer = deepcopy(self.trainer.optimizer)
            checkpoint.load_state_dict()

            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    best_model.state_dict()))
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    current_model.state_dict()))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer))

            base_0_weight_current = self.trainer.model.base[
                0].weight.data.clone()

            # Test resume_best
            self.trainer.config.checkpoint.resume = True
            self.trainer.config.checkpoint.resume_best = True

            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    best_model.state_dict()))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer))
            base_0_weight_best = self.trainer.model.base[0].weight.data.clone()

            self.trainer.config.checkpoint.resume_best = False
            # Test distributed settings
            self.trainer.model = torch.nn.DataParallel(self.trainer.model)
            checkpoint.load_state_dict()

            weight_to_be_tested = self.trainer.model.module.base[0].weight
            weight_device = weight_to_be_tested.device

            self.assertTrue(
                torch.equal(weight_to_be_tested,
                            base_0_weight_current.to(weight_device)))
            self.assertFalse(
                torch.equal(weight_to_be_tested,
                            base_0_weight_best.to(weight_device)))
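
For reference, the on-disk layout these assertions imply for the checkpoint
directory (together with Examples #4, #7, and #11):

    save_dir/
    ├── config.yaml               # written when Checkpoint is constructed
    ├── current.ckpt              # latest checkpoint, rewritten on each save()
    ├── best.ckpt                 # written only when update_best=True
    ├── simple_final.pth          # written by finalize() (name from Example #4)
    └── models/
        └── model_<updates>.ckpt  # one per save(); oldest pruned (Example #7)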
Example #11
    def test_save_config(self):
        with mock_env_with_temp() as d:
            Checkpoint(self.trainer)
            config = load_yaml(os.path.join(d, "config.yaml"))
            self.assertEqual(config, self.config)
            self.assertEqual(config, self.trainer.config)