def test_checkpoint_scaler_loading(self):
    with mock_env_with_temp():
        original_scaler = deepcopy(self.trainer.scaler)

        checkpoint = Checkpoint(self.trainer)
        self._init_early_stopping(checkpoint)

        self._do_a_fp16_pass()
        checkpoint.save(1000)
        self.trainer.config.checkpoint.resume = True
        self.trainer.config.checkpoint.reset.all = False
        self.trainer.config.checkpoint.reset.optimizer = True
        self.trainer.config.checkpoint.reset.counts = True
        self.trainer.config.checkpoint.reset.fp16_scaler = True

        # Reset so it is the same as the default grad scaler
        self.trainer.scaler = torch.cuda.amp.GradScaler()
        checkpoint.load_state_dict()

        self.assertTrue(
            compare_state_dicts(
                self.trainer.scaler.state_dict(), original_scaler.state_dict()
            )
        )

        self._do_a_fp16_pass()
        checkpoint.save(2000)
        self.trainer.config.checkpoint.reset.all = False
        self.trainer.config.checkpoint.reset.optimizer = True
        self.trainer.config.checkpoint.reset.counts = True
        self.trainer.config.checkpoint.reset.fp16_scaler = False

        # Reset again so it is the same as the default grad scaler
        self.trainer.scaler = torch.cuda.amp.GradScaler()
        checkpoint.load_state_dict()

        self.assertFalse(
            compare_state_dicts(
                self.trainer.scaler.state_dict(), original_scaler.state_dict()
            )
        )
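
# The scaler test above drives ``self._do_a_fp16_pass()``, which is defined
# elsewhere in this test class and not shown in this section. The standalone
# function below is a minimal, hypothetical sketch (illustrative name, not the
# real helper) of the usual torch.cuda.amp pattern such a pass is assumed to
# follow; scaling the loss and stepping through the scaler is what mutates the
# GradScaler state that the checkpoint test compares.
def _example_fp16_pass(model, optimizer, scaler, batch, target):
    import torch  # torch is already imported at the top of this file in practice

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(batch)
        loss = torch.nn.functional.mse_loss(output, target)
    # These three calls update the scaler's internal scale/growth trackers,
    # which is exactly the state captured by ``scaler.state_dict()``.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()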
def test_finalize_and_resume_file(self):
    with mock_env_with_temp() as d:
        checkpoint = Checkpoint(self.trainer)
        self._init_early_stopping(checkpoint)
        self._do_a_pass()
        checkpoint.finalize()
        original = deepcopy(self.trainer.model)

        pth_path = os.path.join(d, "simple_final.pth")
        self.assertTrue(PathManager.exists(pth_path))

        self._do_a_pass()

        after_a_pass = deepcopy(self.trainer.model)
        original_optimizer = deepcopy(self.trainer.optimizer)

        self.trainer.config.checkpoint.resume_file = pth_path
        with contextlib.redirect_stdout(StringIO()):
            checkpoint.load_state_dict()

        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.state_dict(), original.state_dict()
            )
        )
        self.assertFalse(
            compare_state_dicts(
                self.trainer.model.state_dict(), after_a_pass.state_dict()
            )
        )
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, original_optimizer)
        )
        # Keys will not be the same as we just updated the model
        self.assertTrue(
            self._compare_optimizers(
                self.trainer.optimizer, original_optimizer, skip_keys=True
            )
        )
def test_pretrained_load(self):
    with mock_env_with_temp() as d:
        checkpoint = Checkpoint(self.trainer)
        self._init_early_stopping(checkpoint)
        self._do_a_pass()
        original_model = deepcopy(self.trainer.model)

        # Test with zoo now
        ret_load_pretrained_zoo = {
            "config": self.config.model_config,
            "checkpoint": deepcopy(self.trainer.model.state_dict()),
            "full_config": self.config,
        }

        checkpoint.save(2000)
        self.trainer.config.checkpoint.resume_file = os.path.join(d, "current.ckpt")
        self.trainer.config.checkpoint.resume_pretrained = True
        self.trainer.model = OnlyBase()
        checkpoint.load_state_dict()

        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.base_test.state_dict(),
                original_model.base.state_dict(),
            )
        )

        with patch(
            "mmf.utils.checkpoint.load_pretrained_model",
            return_value=ret_load_pretrained_zoo,
        ):
            self.trainer.config.checkpoint.resume_zoo = "random"
            self.trainer.config.checkpoint.resume_file = None
            self.trainer.model = OnlyBase()
            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(
                    self.trainer.model.base_test.state_dict(),
                    original_model.base.state_dict(),
                )
            )
def test_zoo_load(self):
    with mock_env_with_temp():
        checkpoint = Checkpoint(self.trainer)
        self._init_early_stopping(checkpoint)
        self._do_a_pass()
        original_model = deepcopy(self.trainer.model)
        ret_load_pretrained_zoo = {
            "config": self.config.model_config,
            "checkpoint": deepcopy(self.trainer.model.state_dict()),
            "full_config": self.config,
        }

        self._do_a_pass()

        with patch(
            "mmf.utils.checkpoint.load_pretrained_model",
            return_value=ret_load_pretrained_zoo,
        ):
            self.trainer.config.checkpoint.resume_zoo = "random"
            with contextlib.redirect_stdout(StringIO()):
                checkpoint.load_state_dict()
            self.assertTrue(
                compare_state_dicts(
                    self.trainer.model.state_dict(), original_model.state_dict()
                )
            )

            # Now, test zoo override
            self.trainer.config.checkpoint.zoo_config_override = True
            SimpleModule.from_pretrained = Mock(return_value=deepcopy(original_model))
            registry.register_model("simple")(SimpleModule)
            with contextlib.redirect_stdout(StringIO()):
                checkpoint.load_state_dict()
            self.assertTrue(
                compare_state_dicts(
                    self.trainer.model.state_dict(), original_model.state_dict()
                )
            )
def _compare_optimizers(self, a, b, skip_keys=False):
    state_dict_a = a.state_dict()
    state_dict_b = b.state_dict()
    state_a = state_dict_a["state"]
    state_b = state_dict_b["state"]

    same = True
    if not skip_keys:
        # State keys change once the model's parameters are replaced, so
        # callers that only care about the values can skip this check.
        same = same and list(state_a.keys()) == list(state_b.keys())
    same = same and state_dict_a["param_groups"] == state_dict_b["param_groups"]

    for item1, item2 in zip(state_a.values(), state_b.values()):
        same = same and compare_state_dicts(item1, item2)

    return same
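
# ``_compare_optimizers`` delegates per-parameter comparison to
# ``compare_state_dicts``, which is imported from the shared test utilities and
# not shown in this section. The function below is only a simplified,
# hypothetical stand-in (not the real utility) illustrating the kind of
# recursive comparison assumed: nested dicts are walked key by key and tensors
# are compared element-wise.
def _example_compare_state_dicts(a, b):
    import torch  # torch is already imported at the top of this file in practice

    if isinstance(a, dict) and isinstance(b, dict):
        if a.keys() != b.keys():
            return False
        return all(_example_compare_state_dicts(a[key], b[key]) for key in a)
    if torch.is_tensor(a) and torch.is_tensor(b):
        return torch.equal(a, b)
    return a == b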
def test_finalize_and_restore_from_it(self):
    with mock_env_with_temp():
        checkpoint = Checkpoint(self.trainer)
        self._init_early_stopping(checkpoint)
        original_model = deepcopy(self.trainer.model)
        self._do_a_pass()
        model_1500 = deepcopy(self.trainer.model)
        checkpoint.save(1500)

        swap = self.trainer.model
        self.trainer.model = original_model
        checkpoint.restore()

        # First test without best.ckpt
        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.state_dict(), original_model.state_dict()
            )
        )
        self.assertFalse(
            compare_state_dicts(
                self.trainer.model.state_dict(), model_1500.state_dict()
            )
        )

        self.trainer.model = swap

        self._do_a_pass()
        model_2000 = deepcopy(self.trainer.model)
        checkpoint.save(2000, update_best=True)

        self._do_a_pass()
        model_2500 = deepcopy(self.trainer.model)
        checkpoint.save(2500)

        checkpoint.restore()

        self.assertFalse(
            compare_state_dicts(
                self.trainer.model.state_dict(), original_model.state_dict()
            )
        )
        self.assertFalse(
            compare_state_dicts(
                self.trainer.model.state_dict(), model_1500.state_dict()
            )
        )
        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.state_dict(), model_2000.state_dict()
            )
        )
        self.assertFalse(
            compare_state_dicts(
                self.trainer.model.state_dict(), model_2500.state_dict()
            )
        )
def test_resets(self):
    with mock_env_with_temp():
        checkpoint = Checkpoint(self.trainer)
        self._init_early_stopping(checkpoint)
        self._do_a_pass()

        original_optimizer = deepcopy(self.trainer.optimizer)
        original_model = deepcopy(self.trainer.model)

        self.trainer.current_epoch = 3
        checkpoint.save(2000, update_best=True)
        self.trainer.current_epoch = 4

        # Test reset all
        self.trainer.config.checkpoint.resume = True
        self.trainer.config.checkpoint.reset.all = True
        checkpoint.load_state_dict()

        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.state_dict(), original_model.state_dict()
            )
        )
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, original_optimizer)
        )
        self.assertTrue(
            self._compare_optimizers(
                self.trainer.optimizer, original_optimizer, skip_keys=True
            )
        )
        self.assertEqual(self.trainer.num_updates, 0)
        self.assertEqual(self.trainer.current_iteration, 0)
        self.assertEqual(self.trainer.current_epoch, 4)

        # Test reset_optimizer
        self._init_early_stopping(checkpoint)
        self.trainer.config.checkpoint.reset.all = False
        self.trainer.config.checkpoint.reset.optimizer = True
        checkpoint.load_state_dict()

        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.state_dict(), original_model.state_dict()
            )
        )
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, original_optimizer)
        )
        self.assertTrue(
            self._compare_optimizers(
                self.trainer.optimizer, original_optimizer, skip_keys=True
            )
        )
        self.assertEqual(self.trainer.num_updates, 2000)
        self.assertEqual(self.trainer.current_iteration, 2000)
        self.assertEqual(self.trainer.current_epoch, 3)

        self._init_early_stopping(checkpoint)

        # Test reset_counts
        self.trainer.config.checkpoint.reset.all = False
        self.trainer.config.checkpoint.reset.optimizer = False
        self.trainer.config.checkpoint.reset.counts = True
        checkpoint.load_state_dict()

        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.state_dict(), original_model.state_dict()
            )
        )
        self.assertTrue(
            self._compare_optimizers(
                self.trainer.optimizer, original_optimizer, skip_keys=True
            )
        )
        self.assertEqual(self.trainer.num_updates, 0)
        self.assertEqual(self.trainer.current_iteration, 0)
        self.assertEqual(self.trainer.current_epoch, 2)

        # Test with resume_best
        self._do_a_pass()
        checkpoint.save(3000)
        self._init_early_stopping(checkpoint)

        self.trainer.config.checkpoint.reset.all = False
        self.trainer.config.checkpoint.resume_best = True
        self.trainer.config.checkpoint.reset.optimizer = True
        self.trainer.config.checkpoint.reset.counts = False
        checkpoint.load_state_dict()

        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.state_dict(), original_model.state_dict()
            )
        )
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, original_optimizer)
        )
        self.assertFalse(
            self._compare_optimizers(
                self.trainer.optimizer, original_optimizer, skip_keys=True
            )
        )
        self.assertEqual(self.trainer.num_updates, 1000)
        self.assertEqual(self.trainer.current_iteration, 1000)
        self.assertEqual(self.trainer.current_epoch, 3)
def test_save_and_load_state_dict(self):
    with mock_env_with_temp() as d:
        checkpoint = Checkpoint(self.trainer)
        self._init_early_stopping(checkpoint)
        self._do_a_pass()

        # Test normal case
        checkpoint.save(1500)
        self.assertTrue(
            PathManager.exists(os.path.join(d, "models", "model_1500.ckpt"))
        )
        self.assertTrue(PathManager.exists(os.path.join(d, "current.ckpt")))
        self.assertFalse(PathManager.exists(os.path.join(d, "best.ckpt")))
        os.remove(os.path.join(d, "models", "model_1500.ckpt"))
        os.remove(os.path.join(d, "current.ckpt"))

        best_model = deepcopy(self.trainer.model)
        best_optimizer = deepcopy(self.trainer.optimizer)

        # Test with update_best
        checkpoint.save(2000, update_best=True)
        self.assertTrue(
            PathManager.exists(os.path.join(d, "models", "model_2000.ckpt"))
        )
        self.assertTrue(PathManager.exists(os.path.join(d, "best.ckpt")))
        self.assertTrue(PathManager.exists(os.path.join(d, "current.ckpt")))

        self._do_a_pass()
        checkpoint.save(2500)

        # Test resume
        self.trainer.config.checkpoint.resume = True

        current_model = deepcopy(self.trainer.model)
        current_optimizer = deepcopy(self.trainer.optimizer)
        checkpoint.load_state_dict()

        self.assertFalse(
            compare_state_dicts(
                self.trainer.model.state_dict(), best_model.state_dict()
            )
        )
        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.state_dict(), current_model.state_dict()
            )
        )
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, best_optimizer)
        )
        self.assertFalse(
            self._compare_optimizers(
                self.trainer.optimizer, best_optimizer, skip_keys=True
            )
        )
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, current_optimizer)
        )
        self.assertTrue(
            self._compare_optimizers(
                self.trainer.optimizer, current_optimizer, skip_keys=True
            )
        )

        base_0_weight_current = self.trainer.model.base[0].weight.data.clone()

        # Test resume_best
        self.trainer.config.checkpoint.resume = True
        self.trainer.config.checkpoint.resume_best = True
        checkpoint.load_state_dict()

        self.assertTrue(
            compare_state_dicts(
                self.trainer.model.state_dict(), best_model.state_dict()
            )
        )
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, best_optimizer)
        )
        self.assertTrue(
            self._compare_optimizers(
                self.trainer.optimizer, best_optimizer, skip_keys=True
            )
        )
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, current_optimizer)
        )
        self.assertFalse(
            self._compare_optimizers(
                self.trainer.optimizer, current_optimizer, skip_keys=True
            )
        )

        base_0_weight_best = self.trainer.model.base[0].weight.data.clone()

        self.trainer.config.checkpoint.resume_best = False

        # Test distributed settings
        self.trainer.model = torch.nn.DataParallel(self.trainer.model)
        checkpoint.load_state_dict()

        weight_to_be_tested = self.trainer.model.module.base[0].weight
        weight_device = weight_to_be_tested.device

        self.assertTrue(
            torch.equal(
                weight_to_be_tested, base_0_weight_current.to(weight_device)
            )
        )
        self.assertFalse(
            torch.equal(
                weight_to_be_tested, base_0_weight_best.to(weight_device)
            )
        )
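
# The "distributed settings" check at the end of ``test_save_and_load_state_dict``
# loads a checkpoint into a ``torch.nn.DataParallel``-wrapped model, whose
# parameter names gain a ``module.`` prefix. The helper below is an illustrative
# sketch only (hypothetical name, not MMF's actual Checkpoint logic) of how a
# loader can reconcile that prefix before calling ``load_state_dict``.
def _example_load_into_possibly_wrapped_model(model, state_dict):
    is_wrapped = hasattr(model, "module")  # DataParallel/DDP expose ``.module``
    has_prefix = any(key.startswith("module.") for key in state_dict)
    if is_wrapped and not has_prefix:
        state_dict = {"module." + key: value for key, value in state_dict.items()}
    elif not is_wrapped and has_prefix:
        state_dict = {key[len("module."):]: value for key, value in state_dict.items()}
    model.load_state_dict(state_dict)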