Exemple #1
0
    def test_emastate_saveload(self):
        """Round-trip an EMAState through state_dict() / load_state_dict().

        A state captured from one model is serialized, loaded into a freshly
        constructed EMAState and applied to a second model; afterwards the
        two models must compare equal.
        """
        src_model = TestArch()
        src_state = model_ema.EMAState.FromModel(src_model)

        dst_model = TestArch()
        # Precondition: the two models must differ before the restored
        # state is applied, otherwise the final check proves nothing.
        already_equal = _compare_state_dict(src_model, dst_model)
        self.assertFalse(already_equal)

        serialized = src_state.state_dict()
        restored_state = model_ema.EMAState()
        restored_state.load_state_dict(serialized)
        restored_state.apply_to(dst_model)

        now_equal = _compare_state_dict(src_model, dst_model)
        self.assertTrue(now_equal)
Exemple #2
0
    def test_ema_updater_decay(self):
        """Check EMAUpdater with decay=0.7 against hand-computed averages.

        After every update the EMA model obtained from the state is compared
        with a TestArch built from the expected values, which are tracked
        here as a float and as an int truncated on every step.
        """
        ema_state = model_ema.EMAState()
        ema_updater = model_ema.EMAUpdater(ema_state, decay=0.7)
        ema_updater.init_state(TestArch(1.0))

        # Expected running averages, seeded to match TestArch(1.0) above.
        expected_float = 1.0
        expected_int = 1
        for step in range(3):
            incoming = float(step)
            ema_updater.update(TestArch(incoming))
            ema_model = ema_state.get_ema_model(TestArch())

            # Literal weights 0.7 / 0.3 are kept as-is so the expected
            # values are computed exactly as before.
            expected_float = expected_float * 0.7 + incoming * 0.3
            expected_int = int(expected_int * 0.7 + incoming * 0.3)

            reference = TestArch(expected_float, expected_int)
            self.assertTrue(_compare_state_dict(ema_model, reference))
Exemple #3
0
    def test_ema_updater(self):
        """Exercise EMAUpdater at the two extreme decay values.

        With decay 0.0 the state applied after each update must match the
        model just passed to update(); with decay 1.0 it must keep matching
        the model the updater was initialised with.
        """
        init_model = TestArch()
        ema_state = model_ema.EMAState()
        target = TestArch()

        # (decay, expect_initial): whether the applied state should equal
        # the initial model (True) or the most recent model (False).
        scenarios = ((0.0, False), (1.0, True))
        for decay, expect_initial in scenarios:
            ema_updater = model_ema.EMAUpdater(ema_state, decay=decay)
            ema_updater.init_state(init_model)
            for _ in range(3):
                latest = TestArch()
                ema_updater.update(latest)
                ema_state.apply_to(target)
                if expect_initial:
                    # weight decay == 1.0, always use init model
                    expected = init_model
                else:
                    # weight decay == 0.0, always use new model
                    expected = latest
                self.assertTrue(_compare_state_dict(target, expected))