def test_ema_hook(self):
    runner = default_runner.Detectron2GoRunner()
    cfg = runner.get_default_cfg()
    cfg.MODEL.DEVICE = "cpu"
    cfg.MODEL_EMA.ENABLED = True
    # use new model weights
    cfg.MODEL_EMA.DECAY = 0.0

    model = TestArch()
    model_ema.may_build_model_ema(cfg, model)
    self.assertTrue(hasattr(model, "ema_state"))

    ema_hook = model_ema.EMAHook(cfg, model)
    ema_hook.before_train()
    ema_hook.before_step()
    model.set_const_weights(2.0)
    ema_hook.after_step()
    ema_hook.after_train()

    ema_checkpointers = model_ema.may_get_ema_checkpointer(cfg, model)
    self.assertEqual(len(ema_checkpointers), 1)

    out_model = TestArch()
    ema_checkpointers["ema_state"].apply_to(out_model)
    self.assertTrue(_compare_state_dict(out_model, model))
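# A minimal sketch of the standard EMA update the test above leans on; this is
# not d2go's actual EMAState/EMAHook implementation, and `ema_update` is a
# hypothetical helper written only for illustration. With the usual rule
#     ema = decay * ema + (1 - decay) * value
# a decay of 0.0 makes the EMA copy track the latest model weights exactly,
# which is why setting cfg.MODEL_EMA.DECAY = 0.0 lets the test compare the EMA
# state against the freshly updated `model`.
import torch


def ema_update(ema_state: dict, model: torch.nn.Module, decay: float) -> None:
    # In-place EMA update of every parameter tracked in `ema_state`.
    with torch.no_grad():
        for name, param in model.named_parameters():
            ema_state[name].mul_(decay).add_(param, alpha=1.0 - decay)


# Usage:
#   model = torch.nn.Linear(4, 4)
#   ema_state = {k: v.detach().clone() for k, v in model.named_parameters()}
#   ema_update(ema_state, model, decay=0.0)  # ema_state now mirrors `model`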
def build_model(self, cfg, eval_only=False):
    # build_model might modify the cfg, thus clone
    cfg = cfg.clone()

    model = build_model(cfg)
    model_ema.may_build_model_ema(cfg, model)

    if cfg.MODEL.FROZEN_LAYER_REG_EXP:
        set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
        model = freeze_matched_bn(model, cfg.MODEL.FROZEN_LAYER_REG_EXP)

    if cfg.QUANTIZATION.QAT.ENABLED:
        # Disable fake_quant and observer so that the model is trained normally
        # before QAT is turned on (controlled by QUANTIZATION.QAT.START_ITER).
        model = setup_qat_model(
            cfg, model, enable_fake_quant=eval_only, enable_observer=False
        )

    if eval_only:
        checkpointer = self.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.load(cfg.MODEL.WEIGHTS)
        model.eval()

        if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
            model_ema.apply_model_ema(model)

    # Note: the _visualize_model API is experimental
    if comm.is_main_process() and hasattr(model, "_visualize_model"):
        logger.info("Adding model visualization ...")
        tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
        model._visualize_model(tbx_writer)

    return model
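# A minimal sketch of how the fake-quant/observer modules that build_model
# leaves disabled could later be switched on once training reaches
# QUANTIZATION.QAT.START_ITER. This is not d2go's actual QAT hook; the
# `EnableQATHook` class and its `after_step` signature are hypothetical, and
# only the torch.ao.quantization enable helpers are real APIs.
import torch
import torch.ao.quantization


class EnableQATHook:
    def __init__(self, model: torch.nn.Module, start_iter: int):
        self.model = model
        self.start_iter = start_iter

    def after_step(self, cur_iter: int) -> None:
        # Turn on fake quantization and observer statistics collection once the
        # configured start iteration is reached.
        if cur_iter == self.start_iter:
            self.model.apply(torch.ao.quantization.enable_fake_quant)
            self.model.apply(torch.ao.quantization.enable_observer)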
def build_model(self, cfg, eval_only=False):
    # build_model might modify the cfg, thus clone
    cfg = cfg.clone()

    # silicon_qat_build_model_context is deprecated
    with silicon_qat_build_model_context(cfg):
        model = build_model(cfg)
    model_ema.may_build_model_ema(cfg, model)

    if cfg.MODEL.FROZEN_LAYER_REG_EXP:
        set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)

    if cfg.QUANTIZATION.QAT.ENABLED:
        # Disable fake_quant and observer so that the model is trained normally
        # before QAT is turned on (controlled by QUANTIZATION.QAT.START_ITER).
        model = setup_qat_model(
            cfg, model, enable_fake_quant=eval_only, enable_observer=False
        )

    if eval_only:
        checkpointer = self.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.load(cfg.MODEL.WEIGHTS)
        model.eval()

        if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
            model_ema.apply_model_ema(model)

    return model
def _build_model(self, cfg, eval_only=False):
    # build_model might modify the cfg, thus clone
    cfg = cfg.clone()

    model = build_model(cfg)
    model_ema.may_build_model_ema(cfg, model)

    if cfg.MODEL.FROZEN_LAYER_REG_EXP:
        set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
        model = freeze_matched_bn(model, cfg.MODEL.FROZEN_LAYER_REG_EXP)

    if cfg.QUANTIZATION.QAT.ENABLED:
        # Disable fake_quant and observer so that the model is trained normally
        # before QAT is turned on (controlled by QUANTIZATION.QAT.START_ITER).
        if hasattr(model, "get_rand_input"):
            imsize = cfg.INPUT.MAX_SIZE_TRAIN
            rand_input = model.get_rand_input(imsize)
            example_inputs = (rand_input, {})
            model = setup_qat_model(
                cfg,
                model,
                enable_fake_quant=eval_only,
                enable_observer=True,
            )
            # Run one forward pass on the example input so the enabled observers
            # record activation statistics.
            model(*example_inputs)
        else:
            imsize = cfg.INPUT.MAX_SIZE_TRAIN
            model = setup_qat_model(
                cfg,
                model,
                enable_fake_quant=eval_only,
                enable_observer=False,
            )

    if eval_only:
        checkpointer = self.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.load(cfg.MODEL.WEIGHTS)
        model.eval()

        if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
            model_ema.apply_model_ema(model)

    return model