Example #1
    def test_ema_hook(self):
        runner = default_runner.Detectron2GoRunner()
        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL_EMA.ENABLED = True
        # DECAY = 0.0 makes the EMA state track the latest model weights exactly
        cfg.MODEL_EMA.DECAY = 0.0

        model = TestArch()
        model_ema.may_build_model_ema(cfg, model)
        self.assertTrue(hasattr(model, "ema_state"))

        ema_hook = model_ema.EMAHook(cfg, model)
        # Simulate one training iteration; after_step() updates ema_state
        # from the (newly set) model weights.
        ema_hook.before_train()
        ema_hook.before_step()
        model.set_const_weights(2.0)
        ema_hook.after_step()
        ema_hook.after_train()

        ema_checkpointers = model_ema.may_get_ema_checkpointer(cfg, model)
        self.assertEqual(len(ema_checkpointers), 1)

        # Applying the saved ema_state to a fresh model should reproduce the
        # trained model's weights, since decay is 0.0.
        out_model = TestArch()
        ema_checkpointers["ema_state"].apply_to(out_model)
        self.assertTrue(_compare_state_dict(out_model, model))
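
The test above relies on the standard exponential-moving-average update, ema = decay * ema + (1 - decay) * current; with DECAY = 0.0 the EMA state becomes an exact copy of the latest weights, which is why the final state-dict comparison passes. A minimal sketch of that update rule in plain PyTorch (illustrative only, not d2go's actual implementation):

    import torch

    def ema_update(ema_params, model_params, decay):
        # ema = decay * ema + (1 - decay) * current; decay = 0.0 simply
        # copies the current weights, as exercised by the test above.
        with torch.no_grad():
            for e, p in zip(ema_params, model_params):
                e.mul_(decay).add_(p, alpha=1.0 - decay)
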
Example #2
    def build_model(self, cfg, eval_only=False):
        # build_model might modify the cfg; clone it first
        cfg = cfg.clone()

        model = build_model(cfg)
        model_ema.may_build_model_ema(cfg, model)

        if cfg.MODEL.FROZEN_LAYER_REG_EXP:
            set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
            model = freeze_matched_bn(model, cfg.MODEL.FROZEN_LAYER_REG_EXP)

        if cfg.QUANTIZATION.QAT.ENABLED:
            # Disable fake_quant and observer so that the model trains normally
            # before QAT is turned on (controlled by QUANTIZATION.QAT.START_ITER).
            model = setup_qat_model(
                cfg, model, enable_fake_quant=eval_only, enable_observer=False
            )

        if eval_only:
            checkpointer = self.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
            checkpointer.load(cfg.MODEL.WEIGHTS)
            model.eval()

            if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
                model_ema.apply_model_ema(model)

        # Note: the _visualize_model API is experimental
        if comm.is_main_process():
            if hasattr(model, "_visualize_model"):
                logger.info("Adding model visualization ...")
                tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
                model._visualize_model(tbx_writer)

        return model
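
For reference, a hedged sketch of how this builder might be driven for EMA-weight evaluation; the driver code is hypothetical, the checkpoint path is a placeholder, and the config fields mirror the reads in build_model above:

    # Hypothetical driver code; field names mirror the reads above.
    runner = default_runner.Detectron2GoRunner()
    cfg = runner.get_default_cfg()
    cfg.MODEL_EMA.ENABLED = True
    cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
    cfg.MODEL.WEIGHTS = "/path/to/checkpoint.pth"  # placeholder
    # eval_only=True loads the checkpoint, swaps in EMA weights, calls eval()
    model = runner.build_model(cfg, eval_only=True)
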
Example #3
    def build_model(self, cfg, eval_only=False):
        # build_model might modify the cfg; clone it first
        cfg = cfg.clone()

        # silicon_qat_build_model_context is deprecated
        with silicon_qat_build_model_context(cfg):
            model = build_model(cfg)
            model_ema.may_build_model_ema(cfg, model)

        if cfg.MODEL.FROZEN_LAYER_REG_EXP:
            set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)

        if cfg.QUANTIZATION.QAT.ENABLED:
            # Disable fake_quant and observer so that the model trains normally
            # before QAT is turned on (controlled by QUANTIZATION.QAT.START_ITER).
            model = setup_qat_model(cfg,
                                    model,
                                    enable_fake_quant=eval_only,
                                    enable_observer=False)

        if eval_only:
            checkpointer = self.build_checkpointer(cfg,
                                                   model,
                                                   save_dir=cfg.OUTPUT_DIR)
            checkpointer.load(cfg.MODEL.WEIGHTS)
            model.eval()

            if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
                model_ema.apply_model_ema(model)

        return model
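
The comment in the QAT branch defers fake-quant until QUANTIZATION.QAT.START_ITER. A hedged sketch of the underlying toggle pattern using stock torch.ao.quantization helpers (d2go's own training hook may differ):

    import torch.ao.quantization as tq

    def maybe_start_qat(model, cur_iter, start_iter):
        # Train in float until start_iter, then enable fake-quant and
        # observers so training continues with simulated quantization.
        if cur_iter == start_iter:
            model.apply(tq.enable_fake_quant)
            model.apply(tq.enable_observer)
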
Example #4
    def _build_model(self, cfg, eval_only=False):
        # build_model might modify the cfg; clone it first
        cfg = cfg.clone()

        model = build_model(cfg)
        model_ema.may_build_model_ema(cfg, model)

        if cfg.MODEL.FROZEN_LAYER_REG_EXP:
            set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
            model = freeze_matched_bn(model, cfg.MODEL.FROZEN_LAYER_REG_EXP)

        if cfg.QUANTIZATION.QAT.ENABLED:
            # Disable fake_quant and observer so that the model trains normally
            # before QAT is turned on (controlled by QUANTIZATION.QAT.START_ITER).
            if hasattr(model, "get_rand_input"):
                imsize = cfg.INPUT.MAX_SIZE_TRAIN
                rand_input = model.get_rand_input(imsize)
                example_inputs = (rand_input, {})
                model = setup_qat_model(
                    cfg,
                    model,
                    enable_fake_quant=eval_only,
                    enable_observer=True,
                )
                # One forward pass so the enabled observers record activation ranges
                model(*example_inputs)
            else:
                model = setup_qat_model(
                    cfg,
                    model,
                    enable_fake_quant=eval_only,
                    enable_observer=False,
                )

        if eval_only:
            checkpointer = self.build_checkpointer(cfg,
                                                   model,
                                                   save_dir=cfg.OUTPUT_DIR)
            checkpointer.load(cfg.MODEL.WEIGHTS)
            model.eval()

            if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
                model_ema.apply_model_ema(model)

        return model
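
Unlike the earlier variants, Example #4 runs one forward pass with observers enabled when the model can supply a random input, letting the observers record activation ranges up front. A hedged sketch of that calibration idea with stock torch.ao.quantization (the model and sample data are placeholders):

    import torch
    import torch.ao.quantization as tq

    def calibrate(model, sample_inputs):
        # Observers accumulate activation statistics during forward passes;
        # fake-quant stays off so the outputs are unaffected.
        model.apply(tq.enable_observer)
        model.apply(tq.disable_fake_quant)
        with torch.no_grad():
            for x in sample_inputs:
                model(x)
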