Example #1
    def test_finalize_and_resume_file(self):
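        # finalize() should write simple_final.pth; resuming from that file must
        # restore the finalized weights, not the ones produced by the later pass.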
        with mock_env_with_temp() as d:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            checkpoint.finalize()
            original = deepcopy(self.trainer.model)
            pth_path = os.path.join(d, "simple_final.pth")
            self.assertTrue(PathManager.exists(pth_path))

            self._do_a_pass()

            after_a_pass = deepcopy(self.trainer.model)
            original_optimizer = deepcopy(self.trainer.optimizer)
            self.trainer.config.checkpoint.resume_file = pth_path

            with contextlib.redirect_stdout(StringIO()):
                checkpoint.load_state_dict()
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original.state_dict()))
            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    after_a_pass.state_dict()))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
            # Keys will not be the same since we just updated the model
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer,
                                         skip_keys=True))
Example #2
    def test_finalize_and_restore_from_it(self):
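        # restore() should leave the model untouched while no best checkpoint
        # exists; once best.ckpt is written, restore() must load it instead of
        # the latest save.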
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            original_model = deepcopy(self.trainer.model)
            self._do_a_pass()
            model_1500 = deepcopy(self.trainer.model)
            checkpoint.save(1500)

            swap = self.trainer.model
            self.trainer.model = original_model
            checkpoint.restore()
            # First test without best.ckpt
            self.assertTrue(
                compare_state_dicts(
                    self.trainer.model.state_dict(), original_model.state_dict()
                )
            )
            self.assertFalse(
                compare_state_dicts(
                    self.trainer.model.state_dict(), model_1500.state_dict()
                )
            )

            self.trainer.model = swap

            self._do_a_pass()
            model_2000 = deepcopy(self.trainer.model)
            checkpoint.save(2000, update_best=True)

            self._do_a_pass()
            model_2500 = deepcopy(self.trainer.model)
            checkpoint.save(2500)

            checkpoint.restore()

            self.assertFalse(
                compare_state_dicts(
                    self.trainer.model.state_dict(), original_model.state_dict()
                )
            )
            self.assertFalse(
                compare_state_dicts(
                    self.trainer.model.state_dict(), model_1500.state_dict()
                )
            )
            self.assertTrue(
                compare_state_dicts(
                    self.trainer.model.state_dict(), model_2000.state_dict()
                )
            )
            self.assertFalse(
                compare_state_dicts(
                    self.trainer.model.state_dict(), model_2500.state_dict()
                )
            )
Example #3
    def __init__(self, config, trainer):
        """
        Attr:
            config(mmf_typings.DictConfig): Config for the callback
            trainer(Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        self._checkpoint = Checkpoint(trainer)
        self.checkpoint_interval = self.config.training.checkpoint_interval
Example #4
    def load_extras(self):
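        # Set up the checkpoint, meter and early-stopping helpers from the
        # training config, restore any saved state, then configure the optional
        # LR scheduler and TensorBoard logger.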
        self.writer.write("Torch version is: " + torch.__version__)
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        self.training_config = self.config.training

        early_stop_criteria = self.training_config.early_stop.criteria
        early_stop_minimize = self.training_config.early_stop.minimize
        early_stop_enabled = self.training_config.early_stop.enabled
        early_stop_patience = self.training_config.early_stop.patience

        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval
        self.max_updates = self.training_config.max_updates
        self.should_clip_gradients = self.training_config.clip_gradients
        self.max_epochs = self.training_config.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            early_stop_criteria,
            patience=early_stop_patience,
            minimize=early_stop_minimize,
            should_stop=early_stop_enabled,
        )
        self.current_epoch = 0
        self.current_iteration = 0
        self.num_updates = 0

        self.checkpoint.load_state_dict()

        self.not_debug = self.training_config.logger_level != "debug"

        self.lr_scheduler = None

        if self.training_config.lr_scheduler is True:
            self.lr_scheduler = build_scheduler(self.optimizer, self.config)

        self.tb_writer = None

        if self.training_config.tensorboard:
            log_dir = self.writer.log_dir
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.current_iteration)
Example #5
    def test_pretrained_load(self):
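        # With resume_pretrained=True, loading a checkpoint into OnlyBase should
        # copy the saved base weights into base_test; the same mapping applies
        # when resuming from the zoo (load_pretrained_model is mocked).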
        with mock_env_with_temp() as d:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            original_model = deepcopy(self.trainer.model)
            # Test with zoo now
            ret_load_pretrained_zoo = {
                "config": self.config.model_config,
                "checkpoint": deepcopy(self.trainer.model.state_dict()),
                "full_config": self.config,
            }

            checkpoint.save(2000)
            self.trainer.config.checkpoint.resume_file = os.path.join(d, "current.ckpt")
            self.trainer.config.checkpoint.resume_pretrained = True
            self.trainer.model = OnlyBase()
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(
                    self.trainer.model.base_test.state_dict(),
                    original_model.base.state_dict(),
                )
            )

            with patch(
                "mmf.utils.checkpoint.load_pretrained_model",
                return_value=ret_load_pretrained_zoo,
            ):
                self.trainer.config.checkpoint.resume_zoo = "random"
                self.trainer.config.checkpoint.resume_file = None
                self.trainer.model = OnlyBase()
                checkpoint.load_state_dict()

                self.assertTrue(
                    compare_state_dicts(
                        self.trainer.model.base_test.state_dict(),
                        original_model.base.state_dict(),
                    )
                )
Example #6
    def test_zoo_load(self):
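        # Resuming from the zoo (load_pretrained_model mocked) should restore
        # the earlier weights; with zoo_config_override=True the model is
        # rebuilt through SimpleModule.from_pretrained instead.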
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()

            original_model = deepcopy(self.trainer.model)
            ret_load_pretrained_zoo = {
                "config": self.config.model_config,
                "checkpoint": deepcopy(self.trainer.model.state_dict()),
                "full_config": self.config,
            }

            self._do_a_pass()

            with patch(
                "mmf.utils.checkpoint.load_pretrained_model",
                return_value=ret_load_pretrained_zoo,
            ):
                self.trainer.config.checkpoint.resume_zoo = "random"
                with contextlib.redirect_stdout(StringIO()):
                    checkpoint.load_state_dict()
                self.assertTrue(
                    compare_state_dicts(
                        self.trainer.model.state_dict(), original_model.state_dict()
                    )
                )

                # Now, test zoo override
                self.trainer.config.checkpoint.zoo_config_override = True
                SimpleModule.from_pretrained = Mock(
                    return_value=deepcopy(original_model)
                )
                registry.register_model("simple")(SimpleModule)
                with contextlib.redirect_stdout(StringIO()):
                    checkpoint.load_state_dict()
                self.assertTrue(
                    compare_state_dicts(
                        self.trainer.model.state_dict(), original_model.state_dict()
                    )
                )
Example #7
    def test_checkpoint_scaler_loading(self):
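        # reset.fp16_scaler=True should ignore the saved GradScaler state (the
        # fresh default scaler is kept); with the flag off, the scaler state is
        # restored from the checkpoint.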
        with mock_env_with_temp():
            original_scaler = deepcopy(self.trainer.scaler)

            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)

            self._do_a_fp16_pass()
            checkpoint.save(1000)
            self.trainer.config.checkpoint.resume = True
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = True
            self.trainer.config.checkpoint.reset.counts = True
            self.trainer.config.checkpoint.reset.fp16_scaler = True

            # Reset to make it the same as the default grad scaler
            self.trainer.scaler = torch.cuda.amp.GradScaler()
            checkpoint.load_state_dict()
            self.assertTrue(
                compare_state_dicts(
                    self.trainer.scaler.state_dict(), original_scaler.state_dict()
                )
            )

            self._do_a_fp16_pass()
            checkpoint.save(2000)
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = True
            self.trainer.config.checkpoint.reset.counts = True
            self.trainer.config.checkpoint.reset.fp16_scaler = False

            # Reset again to make it the same as the default grad scaler
            self.trainer.scaler = torch.cuda.amp.GradScaler()
            checkpoint.load_state_dict()
            self.assertFalse(
                compare_state_dicts(
                    self.trainer.scaler.state_dict(), original_scaler.state_dict()
                )
            )
Example #8
    def test_max_to_keep(self):
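        # Only a fixed number of checkpoints is kept: every save past the limit
        # must create the newest model_<n>.ckpt and delete the oldest one.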
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)

            ckpt_paths = []
            for indx in [2000, 3000, 4000, 5000, 6000]:
                self._do_a_pass()
                checkpoint.save(indx, update_best=False)

                ckpt_paths.append(
                    os.path.join(checkpoint.models_foldername,
                                 "model_%d.ckpt" % indx))
                self.assertTrue(os.path.exists(ckpt_paths[-1]))

            for indx, u in enumerate([7000, 8000, 9000, 10000, 11000]):
                self._do_a_pass()
                checkpoint.save(u, update_best=False)

                ckpt_paths.append(
                    os.path.join(checkpoint.models_foldername,
                                 "model_%d.ckpt" % u))
                self.assertTrue(os.path.exists(ckpt_paths[-1]))
                self.assertFalse(os.path.exists(ckpt_paths[indx]))
Example #9
    def test_resets(self):
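        # Exercise the reset flags: reset.all clears both the optimizer state
        # and the counters, reset.optimizer only the optimizer, reset.counts
        # only the counters, and resume_best loads best.ckpt instead of the
        # latest checkpoint.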
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()

            original_optimizer = deepcopy(self.trainer.optimizer)
            original_model = deepcopy(self.trainer.model)

            self.trainer.current_epoch = 3
            checkpoint.save(2000, update_best=True)
            self.trainer.current_epoch = 4
            # Test reset all
            self.trainer.config.checkpoint.resume = True
            self.trainer.config.checkpoint.reset.all = True
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer,
                                         skip_keys=True))
            self.assertEqual(self.trainer.num_updates, 0)
            self.assertEqual(self.trainer.current_iteration, 0)
            self.assertEqual(self.trainer.current_epoch, 4)

            # Test reset_optimizer
            self._init_early_stopping(checkpoint)
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = True
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer,
                                         skip_keys=True))

            self.assertEqual(self.trainer.num_updates, 2000)
            self.assertEqual(self.trainer.current_iteration, 2000)
            self.assertEqual(self.trainer.current_epoch, 3)

            self._init_early_stopping(checkpoint)
            # Test reset_counts
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = False
            self.trainer.config.checkpoint.reset.counts = True
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer,
                                         skip_keys=True))
            self.assertEqual(self.trainer.num_updates, 0)
            self.assertEqual(self.trainer.current_iteration, 0)
            self.assertEqual(self.trainer.current_epoch, 2)

            # Test with resume_best
            self._do_a_pass()
            checkpoint.save(3000)
            self._init_early_stopping(checkpoint)
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.resume_best = True
            self.trainer.config.checkpoint.reset.optimizer = True
            self.trainer.config.checkpoint.reset.counts = False
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer,
                                         skip_keys=True))
            self.assertEqual(self.trainer.num_updates, 1000)
            self.assertEqual(self.trainer.current_iteration, 1000)
            self.assertEqual(self.trainer.current_epoch, 3)
Example #10
    def test_save_and_load_state_dict(self):
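        # save() should write model_<n>.ckpt and current.ckpt (plus best.ckpt
        # when update_best=True); resume loads current.ckpt, resume_best loads
        # best.ckpt, and loading still works once the model is wrapped in
        # DataParallel.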
        with mock_env_with_temp() as d:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            # Test normal case
            checkpoint.save(1500)

            self.assertTrue(
                PathManager.exists(os.path.join(d, "models",
                                                "model_1500.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d,
                                                            "current.ckpt")))
            self.assertFalse(PathManager.exists(os.path.join(d, "best.ckpt")))
            os.remove(os.path.join(d, "models", "model_1500.ckpt"))
            os.remove(os.path.join(d, "current.ckpt"))

            best_model = deepcopy(self.trainer.model)
            best_optimizer = deepcopy(self.trainer.optimizer)
            # Test with update_best
            checkpoint.save(2000, update_best=True)

            self.assertTrue(
                PathManager.exists(os.path.join(d, "models",
                                                "model_2000.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d, "best.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d,
                                                            "current.ckpt")))

            self._do_a_pass()
            checkpoint.save(2500)

            # Test resume
            self.trainer.config.checkpoint.resume = True

            current_model = deepcopy(self.trainer.model)
            current_optimizer = deepcopy(self.trainer.optimizer)
            checkpoint.load_state_dict()

            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    best_model.state_dict()))
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    current_model.state_dict()))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer,
                                         skip_keys=True))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer,
                                         skip_keys=True))

            base_0_weight_current = self.trainer.model.base[0].weight.data.clone()

            # Test resume_best
            self.trainer.config.checkpoint.resume = True
            self.trainer.config.checkpoint.resume_best = True

            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    best_model.state_dict()))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer,
                                         skip_keys=True))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer,
                                         skip_keys=True))
            base_0_weight_best = self.trainer.model.base[0].weight.data.clone()

            self.trainer.config.checkpoint.resume_best = False
            # Test distributed settings
            self.trainer.model = torch.nn.DataParallel(self.trainer.model)
            checkpoint.load_state_dict()

            weight_to_be_tested = self.trainer.model.module.base[0].weight
            weight_device = weight_to_be_tested.device

            self.assertTrue(
                torch.equal(weight_to_be_tested,
                            base_0_weight_current.to(weight_device)))
            self.assertFalse(
                torch.equal(weight_to_be_tested,
                            base_0_weight_best.to(weight_device)))
Example #11
    def test_save_config(self):
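        # Constructing a Checkpoint should dump the trainer config to
        # config.yaml inside the checkpoint folder.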
        with mock_env_with_temp() as d:
            Checkpoint(self.trainer)
            config = load_yaml(os.path.join(d, "config.yaml"))
            self.assertTrue(config == self.config)
            self.assertTrue(config == self.trainer.config)