def build_model(self, cfg, eval_only=False):
    """Build the model described by ``cfg`` and prepare it for training or eval.

    The sequence is:

    * build the model and let ``model_ema`` optionally attach EMA state;
    * when ``cfg.MODEL.FROZEN_LAYER_REG_EXP`` is set, turn off
      ``requires_grad`` for matching parameters and freeze matching BN layers;
    * when QAT is enabled, wrap the model for QAT with observers off and
      fake-quant on only for ``eval_only``;
    * when ``eval_only``, load ``cfg.MODEL.WEIGHTS``, switch to eval mode and
      optionally apply the EMA weights;
    * on the main process, hand a tensorboard writer to the model's
      experimental ``_visualize_model`` hook if it defines one.

    Args:
        cfg: config node. It is cloned before use, so the caller's copy is
            never touched.
        eval_only: prepare the model for evaluation instead of training.

    Returns:
        The built model.
    """
    # build_model might modify the cfg, thus clone
    cfg = cfg.clone()

    net = build_model(cfg)
    model_ema.may_build_model_ema(cfg, net)

    frozen_regexp = cfg.MODEL.FROZEN_LAYER_REG_EXP
    if frozen_regexp:
        set_requires_grad(net, frozen_regexp, False)
        net = freeze_matched_bn(net, frozen_regexp)

    if cfg.QUANTIZATION.QAT.ENABLED:
        # Observers (and fake-quant, unless evaluating) start out disabled so
        # that training runs normally until QAT is switched on at
        # QUANTIZATION.QAT.START_ITER.
        net = setup_qat_model(
            cfg, net, enable_fake_quant=eval_only, enable_observer=False
        )

    if eval_only:
        ckpt = self.build_checkpointer(cfg, net, save_dir=cfg.OUTPUT_DIR)
        ckpt.load(cfg.MODEL.WEIGHTS)
        net.eval()
        use_ema_for_eval = (
            cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY
        )
        if use_ema_for_eval:
            model_ema.apply_model_ema(net)

    # Note: the _visualize_model API is experimental
    if comm.is_main_process() and hasattr(net, "_visualize_model"):
        logger.info("Adding model visualization ...")
        log_dir = get_tensorboard_log_dir(cfg.OUTPUT_DIR)
        writer = _get_tbx_writer(log_dir)
        net._visualize_model(writer)

    return net
def build_model(self, cfg, eval_only=False):
    """Build the model for ``cfg``, optionally set up for QAT and/or evaluation.

    Args:
        cfg: config node. It is cloned before use because ``build_model``
            might modify it.
        eval_only: when True, fake-quant is enabled in the QAT wrapper (if QAT
            is on), the weights in ``cfg.MODEL.WEIGHTS`` are loaded, and the
            model is put into eval mode.

    Returns:
        The built model.
    """
    # build_model might modify the cfg, thus clone
    cfg = cfg.clone()

    # silicon_qat_build_model_context is deprecated
    with silicon_qat_build_model_context(cfg):
        net = build_model(cfg)
        model_ema.may_build_model_ema(cfg, net)

    frozen_regexp = cfg.MODEL.FROZEN_LAYER_REG_EXP
    if frozen_regexp:
        set_requires_grad(net, frozen_regexp, False)

    if cfg.QUANTIZATION.QAT.ENABLED:
        # Observers (and fake-quant, unless evaluating) start out disabled so
        # that training runs normally until QAT is switched on at
        # QUANTIZATION.QAT.START_ITER.
        net = setup_qat_model(
            cfg, net, enable_fake_quant=eval_only, enable_observer=False
        )

    if not eval_only:
        return net

    ckpt = self.build_checkpointer(cfg, net, save_dir=cfg.OUTPUT_DIR)
    ckpt.load(cfg.MODEL.WEIGHTS)
    net.eval()
    if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
        model_ema.apply_model_ema(net)
    return net
def _build_model(self, cfg, eval_only=False):
    """Build the model from ``cfg``, optionally prepared for QAT and/or evaluation.

    Args:
        cfg: config node. It is cloned first because ``build_model`` might
            modify it.
        eval_only: when True, fake-quant is enabled in the QAT wrapper (if QAT
            is on), weights are loaded from ``cfg.MODEL.WEIGHTS``, the model is
            switched to eval mode, and EMA weights are applied if configured.

    Returns:
        The built (possibly QAT-wrapped and checkpoint-loaded) model.
    """
    # build_model might modify the cfg, thus clone
    cfg = cfg.clone()

    model = build_model(cfg)
    model_ema.may_build_model_ema(cfg, model)

    if cfg.MODEL.FROZEN_LAYER_REG_EXP:
        set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
        model = freeze_matched_bn(model, cfg.MODEL.FROZEN_LAYER_REG_EXP)

    if cfg.QUANTIZATION.QAT.ENABLED:
        # Fake-quant stays off (unless eval_only) so the model trains normally
        # before QAT is turned on (controlled by QUANTIZATION.QAT.START_ITER).
        if hasattr(model, "get_rand_input"):
            # Models that can produce a sample input get one forward pass
            # through the QAT-prepared model with observers enabled, presumably
            # so the observers are initialized from real tensor shapes.
            # NOTE(review): nothing here disables the observers again, and the
            # forward runs in whatever mode the model is currently in, so BN
            # running stats may be updated from random data — confirm intended.
            rand_input = model.get_rand_input(cfg.INPUT.MAX_SIZE_TRAIN)
            example_inputs = (rand_input, {})
            model = setup_qat_model(
                cfg,
                model,
                enable_fake_quant=eval_only,
                enable_observer=True,
            )
            model(*example_inputs)
        else:
            # Observers are disabled as well until QAT is switched on.
            model = setup_qat_model(
                cfg,
                model,
                enable_fake_quant=eval_only,
                enable_observer=False,
            )

    if eval_only:
        checkpointer = self.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.load(cfg.MODEL.WEIGHTS)
        model.eval()

        if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
            model_ema.apply_model_ema(model)

    return model