예제 #1
0
    def test_checkpoint_scaler_loading(self):
        """Verify that ``reset.fp16_scaler`` controls restoring the AMP grad scaler.

        Two save/load rounds are run. Each starts from a freshly constructed
        ``torch.cuda.amp.GradScaler``. With ``reset.fp16_scaler = True`` the
        scaler must still equal the one captured before any fp16 pass. With
        ``reset.fp16_scaler = False`` it must differ from it.
        """
        with mock_env_with_temp():
            initial_scaler = deepcopy(self.trainer.scaler)

            def set_reset_flags(fp16_scaler):
                # Writes the flags in the order: all, optimizer, counts, fp16_scaler.
                reset_cfg = self.trainer.config.checkpoint.reset
                reset_cfg.all = False
                reset_cfg.optimizer = True
                reset_cfg.counts = True
                reset_cfg.fp16_scaler = fp16_scaler

            def scaler_equals_initial():
                return compare_state_dicts(
                    self.trainer.scaler.state_dict(), initial_scaler.state_dict()
                )

            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)

            # Round 1: scaler reset requested -> the default scaler is kept.
            self._do_a_fp16_pass()
            checkpoint.save(1000)
            self.trainer.config.checkpoint.resume = True
            set_reset_flags(fp16_scaler=True)
            # Start from a default scaler so that any load would be observable.
            self.trainer.scaler = torch.cuda.amp.GradScaler()
            checkpoint.load_state_dict()
            self.assertTrue(scaler_equals_initial())

            # Round 2: no scaler reset -> the checkpointed scaler state is loaded.
            self._do_a_fp16_pass()
            checkpoint.save(2000)
            set_reset_flags(fp16_scaler=False)
            self.trainer.scaler = torch.cuda.amp.GradScaler()
            checkpoint.load_state_dict()
            self.assertFalse(scaler_equals_initial())
예제 #2
0
    def test_finalize_and_resume_file(self):
        """Resume from the ``simple_final.pth`` file written by ``finalize()``.

        After finalizing, the model is trained for one more pass and then
        reloaded from the finalized file via ``checkpoint.resume_file``. The
        model must match the snapshot taken at finalize time, not the one after
        the extra pass. The optimizer is compared to a copy taken just before
        loading, both with and without state-key matching.
        """
        with mock_env_with_temp() as tmp_dir:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            checkpoint.finalize()
            finalized_model = deepcopy(self.trainer.model)
            final_pth = os.path.join(tmp_dir, "simple_final.pth")
            self.assertTrue(PathManager.exists(final_pth))

            # Train further so the in-memory model diverges from the file.
            self._do_a_pass()

            model_after_extra_pass = deepcopy(self.trainer.model)
            optimizer_before_load = deepcopy(self.trainer.optimizer)
            self.trainer.config.checkpoint.resume_file = final_pth

            with contextlib.redirect_stdout(StringIO()):
                checkpoint.load_state_dict()

            loaded_weights = self.trainer.model.state_dict()
            self.assertTrue(
                compare_state_dicts(loaded_weights, finalized_model.state_dict())
            )
            self.assertFalse(
                compare_state_dicts(
                    loaded_weights, model_after_extra_pass.state_dict()
                )
            )

            current_optimizer = self.trainer.optimizer
            self.assertFalse(
                self._compare_optimizers(current_optimizer, optimizer_before_load)
            )
            # The state keys differ after the model update, so only the values
            # are expected to match.
            self.assertTrue(
                self._compare_optimizers(
                    current_optimizer, optimizer_before_load, skip_keys=True
                )
            )
예제 #3
0
    def test_pretrained_load(self):
        """Check that pretrained weights land in ``OnlyBase.base_test``.

        The trained ``base`` submodule is loaded into a fresh ``OnlyBase``
        model twice:

        1. From ``current.ckpt`` with ``resume_pretrained`` enabled.
        2. Through the model zoo, with
           ``mmf.utils.checkpoint.load_pretrained_model`` patched to return the
           trained state dict.
        """
        with mock_env_with_temp() as tmp_dir:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            trained_model = deepcopy(self.trainer.model)
            # Payload the patched zoo loader hands back in scenario 2.
            zoo_payload = {
                "config": self.config.model_config,
                "checkpoint": deepcopy(self.trainer.model.state_dict()),
                "full_config": self.config,
            }

            def assert_base_was_loaded():
                self.assertTrue(
                    compare_state_dicts(
                        self.trainer.model.base_test.state_dict(),
                        trained_model.base.state_dict(),
                    )
                )

            # Scenario 1: pretrained load from the checkpoint file on disk.
            checkpoint.save(2000)
            ckpt_path = os.path.join(tmp_dir, "current.ckpt")
            self.trainer.config.checkpoint.resume_file = ckpt_path
            self.trainer.config.checkpoint.resume_pretrained = True
            self.trainer.model = OnlyBase()
            checkpoint.load_state_dict()
            assert_base_was_loaded()

            # Scenario 2: pretrained load through the (patched) model zoo.
            zoo_patch = patch(
                "mmf.utils.checkpoint.load_pretrained_model",
                return_value=zoo_payload,
            )
            with zoo_patch:
                self.trainer.config.checkpoint.resume_zoo = "random"
                self.trainer.config.checkpoint.resume_file = None
                self.trainer.model = OnlyBase()
                checkpoint.load_state_dict()
                assert_base_was_loaded()
예제 #4
0
    def test_zoo_load(self):
        """Load weights from the model zoo, with and without ``zoo_config_override``.

        ``mmf.utils.checkpoint.load_pretrained_model`` is patched to return the
        state dict captured after the first training pass. A second pass then
        makes the live model drift, so a successful zoo load must bring it back
        to that snapshot.

        The override path is exercised with ``SimpleModule.from_pretrained``
        stubbed to return a copy of the snapshot model. The stub is installed
        with ``patch.object`` so it is removed when the block exits.
        """
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()

            original_model = deepcopy(self.trainer.model)
            ret_load_pretrained_zoo = {
                "config": self.config.model_config,
                "checkpoint": deepcopy(self.trainer.model.state_dict()),
                "full_config": self.config,
            }

            # Drift away from the snapshot so that loading it is observable.
            self._do_a_pass()

            with patch(
                "mmf.utils.checkpoint.load_pretrained_model",
                return_value=ret_load_pretrained_zoo,
            ):
                self.trainer.config.checkpoint.resume_zoo = "random"
                with contextlib.redirect_stdout(StringIO()):
                    checkpoint.load_state_dict()
                self.assertTrue(
                    compare_state_dicts(
                        self.trainer.model.state_dict(), original_model.state_dict()
                    )
                )

                # Now, test zoo override
                self.trainer.config.checkpoint.zoo_config_override = True
                # Assigning SimpleModule.from_pretrained directly would leave the
                # Mock on the shared class for every later test. patch.object
                # restores (or removes) the attribute when the block exits.
                with patch.object(
                    SimpleModule,
                    "from_pretrained",
                    new=Mock(return_value=deepcopy(original_model)),
                ):
                    registry.register_model("simple")(SimpleModule)
                    with contextlib.redirect_stdout(StringIO()):
                        checkpoint.load_state_dict()
                self.assertTrue(
                    compare_state_dicts(
                        self.trainer.model.state_dict(), original_model.state_dict()
                    )
                )
예제 #5
0
    def _compare_optimizers(self, a, b):
        state_dict_a = a.state_dict()
        state_dict_b = b.state_dict()
        state_a = state_dict_a["state"]
        state_b = state_dict_b["state"]

        same = True
        same = same and list(state_a.keys()) == list(state_b.keys())
        same = same and state_dict_a["param_groups"] == state_dict_b["param_groups"]

        for item1, item2 in zip(state_a.values(), state_b.values()):
            same = same and compare_state_dicts(item1, item2)

        return same
예제 #6
0
    def test_finalize_and_restore_from_it(self):
        """Exercise ``checkpoint.restore()`` before and after a best.ckpt exists.

        The first ``restore()`` runs after only a plain ``save(1500)``, i.e.
        without best.ckpt. The test expects the untrained model to be left
        as-is. After saving 2000 with ``update_best=True`` and then 2500, the
        test expects ``restore()`` to bring the model back to the 2000 snapshot.
        """
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            original_model = deepcopy(self.trainer.model)
            self._do_a_pass()
            model_1500 = deepcopy(self.trainer.model)
            checkpoint.save(1500)

            def matches(reference):
                # Compare the trainer's current model against a snapshot.
                return compare_state_dicts(
                    self.trainer.model.state_dict(), reference.state_dict()
                )

            swap = self.trainer.model
            # Install a *copy* of the untrained model. Installing
            # `original_model` itself would make the first assertTrue below
            # compare an object with itself, so it could never fail even if
            # restore() overwrote the weights.
            self.trainer.model = deepcopy(original_model)
            checkpoint.restore()
            # First test without best.ckpt
            self.assertTrue(matches(original_model))
            self.assertFalse(matches(model_1500))

            self.trainer.model = swap

            self._do_a_pass()
            model_2000 = deepcopy(self.trainer.model)
            checkpoint.save(2000, update_best=True)

            self._do_a_pass()
            model_2500 = deepcopy(self.trainer.model)
            checkpoint.save(2500)

            checkpoint.restore()

            # With best.ckpt present, only the 2000 snapshot should match.
            self.assertFalse(matches(original_model))
            self.assertFalse(matches(model_1500))
            self.assertTrue(matches(model_2000))
            self.assertFalse(matches(model_2500))
예제 #7
0
    def test_resets(self):
        """Check how the ``checkpoint.reset.*`` flags affect ``load_state_dict``.

        A checkpoint is saved at 2000 (with ``current_epoch = 3`` and
        ``update_best=True``). It is then reloaded under four configurations:

        1. ``reset.all``
        2. ``reset.optimizer``
        3. ``reset.counts``
        4. ``resume_best`` combined with ``reset.optimizer``

        After each load the test asserts on model weights, optimizer state and
        the trainer counters (``num_updates``, ``current_iteration``,
        ``current_epoch``). The scenarios share state and must run in this order.
        """
        with mock_env_with_temp():
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()

            # Snapshots taken before saving; every scenario compares against these.
            original_optimizer = deepcopy(self.trainer.optimizer)
            original_model = deepcopy(self.trainer.model)

            self.trainer.current_epoch = 3
            checkpoint.save(2000, update_best=True)
            # Change the epoch after saving so a restored epoch is distinguishable.
            self.trainer.current_epoch = 4
            # Test reset all
            self.trainer.config.checkpoint.resume = True
            self.trainer.config.checkpoint.reset.all = True
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer,
                                         skip_keys=True))
            # Counters are expected to be zeroed, and the epoch keeps the
            # in-memory value 4 rather than the saved 3.
            self.assertEqual(self.trainer.num_updates, 0)
            self.assertEqual(self.trainer.current_iteration, 0)
            self.assertEqual(self.trainer.current_epoch, 4)

            # Test reset_optimizer
            self._init_early_stopping(checkpoint)
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = True
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer,
                                         skip_keys=True))

            # Counters are expected to come back from the checkpoint saved above.
            self.assertEqual(self.trainer.num_updates, 2000)
            self.assertEqual(self.trainer.current_iteration, 2000)
            self.assertEqual(self.trainer.current_epoch, 3)

            self._init_early_stopping(checkpoint)
            # Test reset_counts
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.reset.optimizer = False
            self.trainer.config.checkpoint.reset.counts = True
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer,
                                         skip_keys=True))
            self.assertEqual(self.trainer.num_updates, 0)
            self.assertEqual(self.trainer.current_iteration, 0)
            # NOTE(review): 2 is neither the saved epoch (3) nor the earlier
            # in-memory value (4). It presumably comes from
            # _init_early_stopping() above; confirm.
            self.assertEqual(self.trainer.current_epoch, 2)

            # Test with resume_best
            self._do_a_pass()
            checkpoint.save(3000)
            self._init_early_stopping(checkpoint)
            self.trainer.config.checkpoint.reset.all = False
            self.trainer.config.checkpoint.resume_best = True
            self.trainer.config.checkpoint.reset.optimizer = True
            self.trainer.config.checkpoint.reset.counts = False
            checkpoint.load_state_dict()

            # The best checkpoint is the one saved at 2000, so the model should
            # equal the pre-save snapshot even though save(3000) came later.
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    original_model.state_dict()))

            # The optimizer was not reloaded and has taken an extra pass since
            # the snapshot, so it differs even when ignoring keys.
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         original_optimizer,
                                         skip_keys=True))
            # NOTE(review): 1000 differs from the 2000 passed to save() for the
            # best checkpoint. Confirm how best.ckpt records its counters.
            self.assertEqual(self.trainer.num_updates, 1000)
            self.assertEqual(self.trainer.current_iteration, 1000)
            self.assertEqual(self.trainer.current_epoch, 3)
예제 #8
0
    def test_save_and_load_state_dict(self):
        """End-to-end check of checkpoint files, ``resume``, ``resume_best`` and DataParallel.

        The test covers, in order:

        1. ``save`` without ``update_best`` writes ``models/model_<n>.ckpt``
           and ``current.ckpt`` but no ``best.ckpt``.
        2. ``save`` with ``update_best=True`` also writes ``best.ckpt``.
        3. ``resume`` loads the latest state (the 2500 save).
        4. ``resume_best`` loads the best state (the 2000 save).
        5. Loading into a model wrapped in ``torch.nn.DataParallel`` restores
           the latest weights into ``model.module``.
        """
        with mock_env_with_temp() as d:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            # Test normal case
            checkpoint.save(1500)

            self.assertTrue(
                PathManager.exists(os.path.join(d, "models",
                                                "model_1500.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d,
                                                            "current.ckpt")))
            self.assertFalse(PathManager.exists(os.path.join(d, "best.ckpt")))
            # Remove these files so the update_best case below is checked on a
            # clean slate.
            os.remove(os.path.join(d, "models", "model_1500.ckpt"))
            os.remove(os.path.join(d, "current.ckpt"))

            # Snapshot of what will be saved as the best checkpoint.
            best_model = deepcopy(self.trainer.model)
            best_optimizer = deepcopy(self.trainer.optimizer)
            # Test with update_best
            checkpoint.save(2000, update_best=True)

            self.assertTrue(
                PathManager.exists(os.path.join(d, "models",
                                                "model_2000.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d, "best.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d,
                                                            "current.ckpt")))

            # One more pass so the latest checkpoint differs from the best one.
            self._do_a_pass()
            checkpoint.save(2500)

            # Test resume
            self.trainer.config.checkpoint.resume = True

            current_model = deepcopy(self.trainer.model)
            current_optimizer = deepcopy(self.trainer.optimizer)
            checkpoint.load_state_dict()

            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    best_model.state_dict()))
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    current_model.state_dict()))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer,
                                         skip_keys=True))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer,
                                         skip_keys=True))

            # Keep a copy of the latest weights for the DataParallel check below.
            base_0_weight_current = self.trainer.model.base[
                0].weight.data.clone()

            # Test resume_best
            self.trainer.config.checkpoint.resume = True
            self.trainer.config.checkpoint.resume_best = True

            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    best_model.state_dict()))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer,
                                         skip_keys=True))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer,
                                         skip_keys=True))
            base_0_weight_best = self.trainer.model.base[0].weight.data.clone()

            self.trainer.config.checkpoint.resume_best = False
            # Test distributed settings
            # Wrapping in DataParallel moves parameters under `.module`; the
            # load must still target the latest (non-best) checkpoint.
            self.trainer.model = torch.nn.DataParallel(self.trainer.model)
            checkpoint.load_state_dict()

            weight_to_be_tested = self.trainer.model.module.base[0].weight
            weight_device = weight_to_be_tested.device

            # Move the reference tensors to the same device before torch.equal.
            self.assertTrue(
                torch.equal(weight_to_be_tested,
                            base_0_weight_current.to(weight_device)))
            self.assertFalse(
                torch.equal(weight_to_be_tested,
                            base_0_weight_best.to(weight_device)))