コード例 #1
0
ファイル: default_runner.py プロジェクト: Pandinosaurus/d2go
    def build_model(self, cfg, eval_only=False):
        """Construct the model described by ``cfg`` and apply optional setup steps.

        The config is cloned first, so the caller's ``cfg`` is not modified by
        the builders. The optional steps, each driven by the config, are:
        layer freezing (``MODEL.FROZEN_LAYER_REG_EXP``), QAT preparation
        (``QUANTIZATION.QAT.ENABLED``), weight loading for evaluation
        (``eval_only``) and, on the main process only, the experimental
        ``_visualize_model`` hook.

        Args:
            cfg: config node; cloned before use.
            eval_only: if True, fake-quant is enabled during QAT setup, weights
                are loaded from ``cfg.MODEL.WEIGHTS`` via
                ``self.build_checkpointer``, the model is switched to eval mode,
                and EMA weights are applied when both ``MODEL_EMA.ENABLED`` and
                ``MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY`` are set.

        Returns:
            The resulting model. It may not be the same object that
            ``build_model(cfg)`` created, since the results of
            ``freeze_matched_bn`` and ``setup_qat_model`` are reassigned to it.
        """
        # The builders may mutate the config, so operate on a private copy.
        local_cfg = cfg.clone()

        net = build_model(local_cfg)
        model_ema.may_build_model_ema(local_cfg, net)

        frozen_pattern = local_cfg.MODEL.FROZEN_LAYER_REG_EXP
        if frozen_pattern:
            # Stop gradients for matching parameters, then let
            # freeze_matched_bn handle matching BatchNorm layers.
            set_requires_grad(net, frozen_pattern, False)
            net = freeze_matched_bn(net, frozen_pattern)

        if local_cfg.QUANTIZATION.QAT.ENABLED:
            # Observers start disabled and fake-quant is only enabled for
            # eval_only, so training runs as a normal float model until QAT is
            # turned on (controlled by QUANTIZATION.QAT.START_ITER).
            net = setup_qat_model(
                local_cfg, net, enable_fake_quant=eval_only, enable_observer=False
            )

        if eval_only:
            ckpt = self.build_checkpointer(local_cfg, net, save_dir=local_cfg.OUTPUT_DIR)
            ckpt.load(local_cfg.MODEL.WEIGHTS)
            net.eval()

            ema_cfg = local_cfg.MODEL_EMA
            if ema_cfg.ENABLED and ema_cfg.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
                model_ema.apply_model_ema(net)

        # The _visualize_model hook is experimental; only the main process runs it.
        if comm.is_main_process() and hasattr(net, "_visualize_model"):
            logger.info("Adding model visualization ...")
            writer = _get_tbx_writer(get_tensorboard_log_dir(local_cfg.OUTPUT_DIR))
            net._visualize_model(writer)

        return net
コード例 #2
0
ファイル: default_runner.py プロジェクト: ananthsub/d2go
    def build_model(self, cfg, eval_only=False):
        """Construct the model described by ``cfg`` and apply optional setup steps.

        Args:
            cfg: config node; cloned first because building the model may
                modify it, so the caller's config stays untouched.
            eval_only: if True, fake-quant is enabled during QAT setup, weights
                are loaded from ``cfg.MODEL.WEIGHTS`` via
                ``self.build_checkpointer``, the model is switched to eval mode,
                and EMA weights are applied when both ``MODEL_EMA.ENABLED`` and
                ``MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY`` are set.

        Returns:
            The resulting model. It may not be the same object that
            ``build_model(cfg)`` created, since the result of
            ``setup_qat_model`` is reassigned to it.
        """
        # Work on a private copy: model building may mutate the config.
        local_cfg = cfg.clone()

        # NOTE: silicon_qat_build_model_context is deprecated.
        with silicon_qat_build_model_context(local_cfg):
            net = build_model(local_cfg)
            model_ema.may_build_model_ema(local_cfg, net)

        frozen_pattern = local_cfg.MODEL.FROZEN_LAYER_REG_EXP
        if frozen_pattern:
            set_requires_grad(net, frozen_pattern, False)

        if local_cfg.QUANTIZATION.QAT.ENABLED:
            # Observers start disabled and fake-quant is only enabled for
            # eval_only, so training runs as a normal float model until QAT is
            # turned on (controlled by QUANTIZATION.QAT.START_ITER).
            net = setup_qat_model(
                local_cfg, net, enable_fake_quant=eval_only, enable_observer=False
            )

        # Training path: nothing left to do.
        if not eval_only:
            return net

        ckpt = self.build_checkpointer(local_cfg, net, save_dir=local_cfg.OUTPUT_DIR)
        ckpt.load(local_cfg.MODEL.WEIGHTS)
        net.eval()

        ema_cfg = local_cfg.MODEL_EMA
        if ema_cfg.ENABLED and ema_cfg.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
            model_ema.apply_model_ema(net)
        return net
コード例 #3
0
    def _build_model(self, cfg, eval_only=False):
        """Construct the model described by ``cfg`` and apply optional setup steps.

        Args:
            cfg: config node; cloned first because building the model may
                modify it, so the caller's config stays untouched.
            eval_only: if True, fake-quant is enabled during QAT setup, weights
                are loaded from ``cfg.MODEL.WEIGHTS`` via
                ``self.build_checkpointer``, the model is switched to eval mode,
                and EMA weights are applied when both ``MODEL_EMA.ENABLED`` and
                ``MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY`` are set.

        Returns:
            The resulting model. It may not be the same object that
            ``build_model(cfg)`` created, since the results of
            ``freeze_matched_bn`` and ``setup_qat_model`` are reassigned to it.
        """
        # build_model might modify the cfg, thus clone
        cfg = cfg.clone()

        model = build_model(cfg)
        model_ema.may_build_model_ema(cfg, model)

        if cfg.MODEL.FROZEN_LAYER_REG_EXP:
            set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
            model = freeze_matched_bn(model, cfg.MODEL.FROZEN_LAYER_REG_EXP)

        if cfg.QUANTIZATION.QAT.ENABLED:
            # Fake-quant is only enabled for eval_only, so training runs as a
            # normal float model until QAT is turned on (controlled by
            # QUANTIZATION.QAT.START_ITER).
            # Observers are disabled too, EXCEPT for models exposing
            # ``get_rand_input``. For those, observers are enabled and one
            # forward pass on a random input runs right after QAT setup.
            # NOTE(review): observers are not disabled again after that pass —
            # confirm this is intended.
            example_inputs = None
            if hasattr(model, "get_rand_input"):
                # The random input is produced by the float model, before QAT prep.
                rand_input = model.get_rand_input(cfg.INPUT.MAX_SIZE_TRAIN)
                example_inputs = (rand_input, {})
            model = setup_qat_model(
                cfg,
                model,
                enable_fake_quant=eval_only,
                enable_observer=example_inputs is not None,
            )
            if example_inputs is not None:
                model(*example_inputs)

        if eval_only:
            checkpointer = self.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
            checkpointer.load(cfg.MODEL.WEIGHTS)
            model.eval()

            if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
                model_ema.apply_model_ema(model)

        return model