def test_emastate_saveload(self):
    """An EMAState serialized via state_dict() restores the same weights.

    Builds an EMA state from one model, copies it into a fresh EMAState
    through state_dict()/load_state_dict(), applies it to a second model,
    and checks the second model then matches the first.
    """
    source_model = TestArch()
    saved_state = model_ema.EMAState.FromModel(source_model)

    target_model = TestArch()
    # Precondition: the two models start out different, so the final
    # comparison actually proves the restore did something.
    self.assertFalse(_compare_state_dict(source_model, target_model))

    restored_state = model_ema.EMAState()
    restored_state.load_state_dict(saved_state.state_dict())
    restored_state.apply_to(target_model)
    self.assertTrue(_compare_state_dict(source_model, target_model))
def test_ema_updater_decay(self):
    """EMAUpdater(decay=0.7) follows ema = 0.7 * ema + 0.3 * new.

    A float reference and an int reference (truncated with int() after
    every step) are tracked alongside the updater and compared against
    the EMA model after each update.
    """
    ema_state = model_ema.EMAState()
    ema_updater = model_ema.EMAUpdater(ema_state, decay=0.7)
    ema_updater.init_state(TestArch(1.0))

    expected_float = 1.0
    expected_int = 1
    for step in range(3):
        new_value = float(step)
        ema_updater.update(TestArch(new_value))
        ema_model = ema_state.get_ema_model(TestArch())

        # Keep the literal 0.3 (not 1 - 0.7) so the reference values are
        # bit-identical to the expected recurrence.
        expected_float = expected_float * 0.7 + new_value * 0.3
        expected_int = int(expected_int * 0.7 + new_value * 0.3)

        self.assertTrue(
            _compare_state_dict(ema_model, TestArch(expected_float, expected_int))
        )
def test_ema_updater(self):
    """EMAUpdater at the extreme decay values 0.0 and 1.0.

    With EMA decay 0.0 the applied EMA weights are expected to equal the
    most recently submitted model; with EMA decay 1.0 they are expected to
    stay equal to the model used in init_state(). Both phases share one
    EMAState and one target model.
    """
    init_model = TestArch()
    ema_state = model_ema.EMAState()

    target_model = TestArch()

    # (EMA decay, whether the expected result is the init model rather
    # than the latest submitted model)
    phases = ((0.0, False), (1.0, True))
    for decay, expect_init_model in phases:
        ema_updater = model_ema.EMAUpdater(ema_state, decay=decay)
        ema_updater.init_state(init_model)
        for _ in range(3):
            latest_model = TestArch()
            ema_updater.update(latest_model)
            ema_state.apply_to(target_model)
            reference = init_model if expect_init_model else latest_model
            self.assertTrue(_compare_state_dict(target_model, reference))