def test_ema(self):
    model = DummyModule()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    state = deepcopy(model.state_dict())
    config = EMAConfig()
    ema = EMA(model, config)

    # set decay
    ema._set_decay(config.ema_decay)
    self.assertEqual(ema.get_decay(), config.ema_decay)

    # get model
    self.assertEqual(ema.get_model(), ema.model)

    # Since fp32 params is not used, it should be of size 0
    self.assertEqual(len(ema.fp32_params), 0)

    # EMA step
    x = torch.randn(32)
    y = model(x)
    loss = y.sum()
    loss.backward()
    optimizer.step()

    ema.step(model)

    ema_state_dict = ema.get_model().state_dict()

    for key, param in model.state_dict().items():
        prev_param = state[key]
        ema_param = ema_state_dict[key]

        if "version" in key:
            # Do not decay a model.version pytorch param
            continue
        self.assertTorchAllClose(
            ema_param,
            config.ema_decay * prev_param + (1 - config.ema_decay) * param,
        )

    # Since fp32 params is not used, it should be of size 0
    self.assertEqual(len(ema.fp32_params), 0)

    # Load EMA into model
    model2 = DummyModule()
    ema.reverse(model2)

    for key, param in model2.state_dict().items():
        ema_param = ema_state_dict[key]
        self.assertTrue(torch.allclose(ema_param, param))

    # Check that step_internal is called once
    with patch.object(ema, "_step_internal", return_value=None) as mock_method:
        ema.step(model)
        mock_method.assert_called_once_with(model, None)
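
# Reference sketch (ours, not part of the original test suite): the update rule
# that the assertTorchAllClose checks in test_ema encode. `decay` plays the role
# of config.ema_decay; only `torch` (already imported above) is assumed.
def _reference_ema_update(shadow, param, decay):
    # One EMA step: keep `decay` of the old shadow weight and mix in
    # (1 - decay) of the freshly optimized model weight.
    return decay * shadow + (1 - decay) * param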
def test_ema_fp16(self):
    model = DummyModule().half()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    state = deepcopy(model.state_dict())
    config = EMAConfig(ema_fp32=False)
    ema = EMA(model, config)

    # Since fp32 params is not used, it should be of size 0
    self.assertEqual(len(ema.fp32_params), 0)

    x = torch.randn(32)
    y = model(x.half())
    loss = y.sum()
    loss.backward()
    optimizer.step()

    ema.step(model)

    for key, param in model.state_dict().items():
        prev_param = state[key]
        ema_param = ema.get_model().state_dict()[key]

        if "version" in key:
            # Do not decay a model.version pytorch param
            continue

        # EMA update is done in fp16, and hence the EMA param must be
        # closer to the EMA update done in fp16 than in fp32.
        self.assertLessEqual(
            torch.norm(
                ema_param.float()
                - (
                    config.ema_decay * prev_param + (1 - config.ema_decay) * param
                ).float()
            ),
            torch.norm(
                ema_param.float()
                - (
                    config.ema_decay * prev_param.float()
                    + (1 - config.ema_decay) * param.float()
                )
                .half()
                .float()
            ),
        )
        self.assertTorchAllClose(
            ema_param,
            config.ema_decay * prev_param + (1 - config.ema_decay) * param,
        )

    # Since fp32 params is not used, it should be of size 0
    self.assertEqual(len(ema.fp32_params), 0)
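
# Sketch (ours) of the precision argument behind the norm comparison in
# test_ema_fp16: with ema_fp32=False the blend is computed directly in half
# precision, so the stored EMA matches fp16 arithmetic rather than an fp32
# blend rounded back to fp16. The helper name is hypothetical.
def _fp16_vs_fp32_blend(shadow_fp16, param_fp16, decay):
    # Blend entirely in half precision (what the EMA stores in this mode) ...
    in_fp16 = decay * shadow_fp16 + (1 - decay) * param_fp16
    # ... versus blending in fp32 and rounding the result back to fp16.
    in_fp32 = (decay * shadow_fp16.float() + (1 - decay) * param_fp16.float()).half()
    # The two can differ by rounding error, which is what the test measures.
    return in_fp16, in_fp32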
def _test_ema_start_update(self, updates):
    model = DummyModule()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    state = deepcopy(model.state_dict())
    config = EMAConfig(ema_start_update=1)
    ema = EMA(model, config)

    # EMA step
    x = torch.randn(32)
    y = model(x)
    loss = y.sum()
    loss.backward()
    optimizer.step()

    ema.step(model, updates=updates)
    ema_state_dict = ema.get_model().state_dict()

    self.assertEqual(ema.get_decay(), 0 if updates == 0 else config.ema_decay)

    for key, param in model.state_dict().items():
        ema_param = ema_state_dict[key]
        prev_param = state[key]

        if "version" in key:
            # Do not decay a model.version pytorch param
            continue
        if updates == 0:
            self.assertTorchAllClose(
                ema_param,
                param,
            )
        else:
            self.assertTorchAllClose(
                ema_param,
                config.ema_decay * prev_param + (1 - config.ema_decay) * param,
            )

    # Check that step_internal is called once
    with patch.object(ema, "_step_internal", return_value=None) as mock_method:
        ema.step(model, updates=updates)
        mock_method.assert_called_once_with(model, updates)
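
# Sketch (ours) of the gating that _test_ema_start_update exercises: before
# ema_start_update is reached the decay is forced to 0, so the EMA is a plain
# copy of the model; from that update onward the configured decay applies.
def _effective_decay(base_decay, start_update, updates):
    # With updates=0 and start_update=1 this returns 0.0, matching the
    # assertEqual on ema.get_decay() above.
    if updates is not None and updates < start_update:
        return 0.0
    return base_decay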
def test_ema_fp32(self):
    # CPU no longer supports Linear in half precision
    dtype = torch.half if torch.cuda.is_available() else torch.float

    model = DummyModule().to(dtype)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    state = deepcopy(model.state_dict())
    config = EMAConfig(ema_fp32=True)
    ema = EMA(model, config)

    x = torch.randn(32)
    y = model(x.to(dtype))
    loss = y.sum()
    loss.backward()
    optimizer.step()

    ema.step(model)

    for key, param in model.state_dict().items():
        prev_param = state[key]
        ema_param = ema.get_model().state_dict()[key]

        if "version" in key:
            # Do not decay a model.version pytorch param
            continue

        self.assertIn(key, ema.fp32_params)

        # EMA update is done in fp32, and hence the EMA param must be
        # closer to the EMA update done in fp32 than in fp16.
        self.assertLessEqual(
            torch.norm(
                ema_param.float()
                - (
                    config.ema_decay * prev_param.float()
                    + (1 - config.ema_decay) * param.float()
                )
                .to(dtype)
                .float()
            ),
            torch.norm(
                ema_param.float()
                - (
                    config.ema_decay * prev_param + (1 - config.ema_decay) * param
                ).float()
            ),
        )
        self.assertTorchAllClose(
            ema_param,
            (
                config.ema_decay * prev_param.float()
                + (1 - config.ema_decay) * param.float()
            ).to(dtype),
        )
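
# Sketch (ours, not fairseq's internals) of the fp32-shadow pattern that
# ema_fp32=True enables and test_ema_fp32 verifies: keep a full-precision copy
# of every EMA weight, accumulate there, and cast back to the model dtype only
# when reading the EMA model out. The helper name and signature are assumptions.
def _fp32_shadow_step(fp32_params, model_state, decay, dtype):
    ema_state = {}
    for key, param in model_state.items():
        # Accumulate in fp32 to avoid compounding half-precision rounding...
        fp32_params[key] = decay * fp32_params[key] + (1 - decay) * param.float()
        # ...and expose the EMA weight in the model's dtype.
        ema_state[key] = fp32_params[key].to(dtype)
    return ema_state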