Esempio n. 1
0
    def test_ema_hook(self):
        """EMAHook with DECAY=0 should leave the EMA state equal to the model.

        Builds an EMA-enabled CPU config and attaches EMA state to a TestArch.
        It then drives EMAHook through one train loop with a single step, during
        which the weights are set to a constant. Finally it applies the EMA
        checkpointer to a fresh TestArch and compares state dicts with the
        trained model.
        """
        go_runner = default_runner.Detectron2GoRunner()
        cfg = go_runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        ema_cfg = cfg.MODEL_EMA
        ema_cfg.ENABLED = True
        # Zero decay: the EMA copy should simply take the newest model weights.
        ema_cfg.DECAY = 0.0

        trained = TestArch()
        model_ema.may_build_model_ema(cfg, trained)
        self.assertTrue(hasattr(trained, "ema_state"))

        # Simulate a full (one-step) training loop; the weights change to a
        # known constant between before_step and after_step.
        hook = model_ema.EMAHook(cfg, trained)
        hook.before_train()
        hook.before_step()
        trained.set_const_weights(2.0)
        hook.after_step()
        hook.after_train()

        checkpointers = model_ema.may_get_ema_checkpointer(cfg, trained)
        self.assertEqual(len(checkpointers), 1)

        restored = TestArch()
        checkpointers["ema_state"].apply_to(restored)
        self.assertTrue(_compare_state_dict(restored, trained))
Esempio n. 2
0
 def build_checkpointer(self, cfg, model, save_dir, **kwargs):
     """Build a QATCheckpointer for ``model`` that also carries EMA entries.

     The mapping returned by ``model_ema.may_get_ema_checkpointer(cfg, model)``
     is merged over ``kwargs``, so its keys win on collision. The merged result
     is forwarded as keyword arguments to ``QATCheckpointer``.
     """
     ckpt_kwargs = {
         **kwargs,
         **model_ema.may_get_ema_checkpointer(cfg, model),
     }
     return QATCheckpointer(model, save_dir=save_dir, **ckpt_kwargs)