def test_ema_hook(self):
    runner = default_runner.Detectron2GoRunner()
    cfg = runner.get_default_cfg()
    cfg.MODEL.DEVICE = "cpu"
    cfg.MODEL_EMA.ENABLED = True
    # use new model weights
    cfg.MODEL_EMA.DECAY = 0.0
    model = TestArch()
    model_ema.may_build_model_ema(cfg, model)
    self.assertTrue(hasattr(model, "ema_state"))

    ema_hook = model_ema.EMAHook(cfg, model)
    ema_hook.before_train()
    ema_hook.before_step()
    model.set_const_weights(2.0)
    ema_hook.after_step()
    ema_hook.after_train()

    ema_checkpointers = model_ema.may_get_ema_checkpointer(cfg, model)
    self.assertEqual(len(ema_checkpointers), 1)

    out_model = TestArch()
    ema_checkpointers["ema_state"].apply_to(out_model)
    self.assertTrue(_compare_state_dict(out_model, model))
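With MODEL_EMA.DECAY set to 0.0, the usual EMA update ema = decay * ema + (1 - decay) * weight reduces to copying the latest weights, which is why applying ema_state to a fresh TestArch should reproduce model exactly. The test also relies on a _compare_state_dict helper that is not shown here; a minimal sketch of what it could look like, assuming both arguments are plain torch.nn.Module instances:

import torch

def _compare_state_dict(model_a, model_b):
    # Hypothetical stand-in for the helper used by the test above: returns True
    # when both models expose the same parameter/buffer names and every tensor
    # matches exactly.
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    if sd_a.keys() != sd_b.keys():
        return False
    return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)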
def _get_trainer_hooks(
    self, cfg, model, optimizer, scheduler, periodic_checkpointer, trainer
):
    return [
        hooks.IterationTimer(),
        model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
        self._create_data_loader_hook(cfg),
        self._create_after_step_hook(
            cfg, model, optimizer, scheduler, periodic_checkpointer
        ),
        hooks.EvalHook(
            cfg.TEST.EVAL_PERIOD,
            lambda: self.do_test(cfg, model, train_iter=trainer.iter),
            # eval after train is done by a separate do_test call in tools/train_net.py
            eval_after_train=False,
        ),
        kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
        self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
    ]
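Note that several entries in this list evaluate to None when the corresponding feature is disabled. That is safe because detectron2's TrainerBase.register_hooks discards None entries before attaching hooks to the trainer, so a sketch of the intended call site (hypothetical, not part of the snippet above) is simply:

# Hypothetical call site: the None placeholders left by disabled hooks are
# filtered out inside register_hooks(), so no manual filtering is needed here.
trainer_hooks = self._get_trainer_hooks(
    cfg, model, optimizer, scheduler, periodic_checkpointer, trainer
)
trainer.register_hooks(trainer_hooks)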
def do_train(self, cfg, model, resume):
    add_print_flops_callback(cfg, model, disable_after_callback=True)

    optimizer = self.build_optimizer(cfg, model)
    scheduler = self.build_lr_scheduler(cfg, optimizer)

    checkpointer = self.build_checkpointer(
        cfg,
        model,
        save_dir=cfg.OUTPUT_DIR,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
    start_iter = (
        checkpoint.get("iteration", -1)
        if resume and checkpointer.has_checkpoint()
        else -1
    )
    # The checkpoint stores the training iteration that just finished, thus we start
    # at the next iteration (or iter zero if there's no checkpoint).
    start_iter += 1
    max_iter = cfg.SOLVER.MAX_ITER
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    data_loader = self.build_detection_train_loader(cfg)

    def _get_model_with_abnormal_checker(model):
        if not cfg.ABNORMAL_CHECKER.ENABLED:
            return model

        tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
        writers = abnormal_checker.get_writers(cfg, tbx_writer)
        checker = abnormal_checker.AbnormalLossChecker(start_iter, writers)
        ret = abnormal_checker.AbnormalLossCheckerWrapper(model, checker)
        return ret

    trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
        _get_model_with_abnormal_checker(model), data_loader, optimizer
    )
    trainer_hooks = [
        hooks.IterationTimer(),
        model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
        self._create_after_step_hook(
            cfg, model, optimizer, scheduler, periodic_checkpointer
        ),
        hooks.EvalHook(
            cfg.TEST.EVAL_PERIOD,
            lambda: self.do_test(cfg, model, train_iter=trainer.iter),
        ),
        kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
        self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
    ]
    if comm.is_main_process():
        tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
        writers = [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            tbx_writer,
        ]
        trainer_hooks.append(hooks.PeriodicWriter(writers))
    trainer.register_hooks(trainer_hooks)
    trainer.train(start_iter, max_iter)

    if hasattr(self, "original_cfg"):
        table = get_cfg_diff_table(cfg, self.original_cfg)
        logger.info(
            "GeneralizeRCNN Runner ignoring training config change: \n" + table
        )
        trained_cfg = self.original_cfg.clone()
    else:
        trained_cfg = cfg.clone()
    with temp_defrost(trained_cfg):
        trained_cfg.MODEL.WEIGHTS = checkpointer.get_checkpoint_file()
    return {"model_final": trained_cfg}
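A minimal end-to-end driver sketch for do_train, assuming a d2go-style runner; the build_model call and the CPU override are illustrative assumptions, not taken from the snippet above:

# Hypothetical driver: builds a model from the runner's default config, trains
# it, and reads back the returned config, whose MODEL.WEIGHTS points at the
# last checkpoint written by the run.
runner = default_runner.Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.MODEL.DEVICE = "cpu"  # assumption: keeps the example runnable without a GPU
model = runner.build_model(cfg)
trained = runner.do_train(cfg, model, resume=False)
print(trained["model_final"].MODEL.WEIGHTS)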