def build_model(self, cfg, eval_only=False):
    """Construct the model described by ``cfg``.

    The config is cloned first because model building may mutate it. On top of
    the base model this applies, in order: EMA bookkeeping, layer/BN freezing,
    QAT preparation, optional eval-only weight loading, and the experimental
    model-visualization hook on the main process.

    Args:
        cfg: d2go config node describing the model.
        eval_only: when True, load ``cfg.MODEL.WEIGHTS`` and put the model in
            eval mode (using EMA weights if configured).

    Returns:
        The fully prepared model.
    """
    # build_model might modify the cfg, thus clone
    cfg = cfg.clone()

    model = build_model(cfg)
    model_ema.may_build_model_ema(cfg, model)

    frozen_reg_exp = cfg.MODEL.FROZEN_LAYER_REG_EXP
    if frozen_reg_exp:
        # Stop gradients for matched layers and put matched BN into eval mode.
        set_requires_grad(model, frozen_reg_exp, False)
        model = freeze_matched_bn(model, frozen_reg_exp)

    if cfg.QUANTIZATION.QAT.ENABLED:
        # Disable fake_quant and observer so that the model will be trained
        # normally before QAT being turned on (controlled by
        # QUANTIZATION.QAT.START_ITER).
        model = setup_qat_model(
            cfg, model, enable_fake_quant=eval_only, enable_observer=False
        )

    if eval_only:
        checkpointer = self.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.load(cfg.MODEL.WEIGHTS)
        model.eval()
        if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
            model_ema.apply_model_ema(model)

    # Note: the _visualize_model API is experimental
    if comm.is_main_process() and hasattr(model, "_visualize_model"):
        logger.info("Adding model visualization ...")
        writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
        model._visualize_model(writer)

    return model
def _get_model_with_abnormal_checker(model):
    """Wrap ``model`` with an abnormal-loss checker when enabled via config.

    Reads ``cfg`` and ``start_iter`` from the enclosing scope. Returns the
    model unchanged when ``cfg.ABNORMAL_CHECKER.ENABLED`` is off.
    """
    if not cfg.ABNORMAL_CHECKER.ENABLED:
        return model

    writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
    loss_writers = abnormal_checker.get_writers(cfg, writer)
    loss_checker = abnormal_checker.AbnormalLossChecker(start_iter, loss_writers)
    return abnormal_checker.AbnormalLossCheckerWrapper(model, loss_checker)
def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
    """Build the d2go training data loader.

    On the main process, if the runner provides a data-loader visualization
    wrapper, the loader is wrapped so batches can be visualized in
    TensorBoard.

    Args:
        cfg: d2go config.
        mapper: optional dataset mapper; falls back to ``cls.get_mapper``.

    Returns:
        The (possibly visualization-wrapped) train data loader.
    """
    mapper = mapper or cls.get_mapper(cfg, is_train=True)
    loader = build_d2go_train_loader(cfg, mapper)

    # Visualization wrapping only happens on the main process.
    if not comm.is_main_process():
        return loader
    vis_wrapper_cls = cls.get_data_loader_vis_wrapper()
    if vis_wrapper_cls is None:
        return loader
    writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
    return vis_wrapper_cls(cfg, writer, loader)
def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
    """Build the training data loader, choosing the sampler from config.

    Uses a weighted loader when ``cfg.DATALOADER.SAMPLER_TRAIN`` is
    ``"WeightedTrainingSampler"``, otherwise the stock detectron2 loader.
    On the main process the loader is wrapped for TensorBoard visualization
    if the runner provides a wrapper type.

    Args:
        cfg: d2go config.
        mapper: optional dataset mapper; falls back to ``cls.get_mapper``.

    Returns:
        The (possibly visualization-wrapped) train data loader.
    """
    logger.info("Building detection train loader ...")
    mapper = mapper or cls.get_mapper(cfg, is_train=True)
    logger.info("Using dataset mapper:\n{}".format(mapper))

    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    if sampler_name == "WeightedTrainingSampler":
        data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
    else:
        data_loader = d2_build_detection_train_loader(
            cfg, *args, mapper=mapper, **kwargs
        )

    if comm.is_main_process():
        # Fix: guard against a runner that provides no visualization wrapper
        # (the sibling loader builder already checks this); calling
        # ``None(...)`` would raise TypeError.
        data_loader_type = cls.get_data_loader_vis_wrapper()
        if data_loader_type is not None:
            tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
            data_loader = data_loader_type(cfg, tbx_writer, data_loader)
    return data_loader
def _setup_visualization_evaluator(
    evaluator,
    dataset_name: str,
    model_tag: ModelTag,
) -> None:
    """Append a visualization evaluator to ``evaluator`` when one is available.

    Uses ``self`` from the enclosing scope (nested helper). No-op when the
    runner provides no visualization evaluator type.
    """
    vis_eval_type = self.get_visualization_evaluator()
    if vis_eval_type is None:
        # Fix: return before logging and before constructing the mapper and
        # tbx writer, which were previously built even when unused.
        return

    logger.info("Adding visualization evaluator ...")
    mapper = self.get_mapper(self.cfg, is_train=False)
    # TODO: replace tbx_writer with Lightning's self.logger.experiment
    tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(self.cfg.OUTPUT_DIR))
    evaluator._evaluators.append(
        vis_eval_type(
            self.cfg,
            tbx_writer,
            mapper,
            dataset_name,
            train_iter=self.trainer.global_step,
            tag_postfix=model_tag,
        )
    )
def do_train(self, cfg, model, resume):
    """Run the full training loop and return the trained config.

    Args:
        cfg: d2go config describing solver, checkpointing, EMA, QAT, etc.
        model: the model to train (may be wrapped for abnormal-loss checking).
        resume: when True, resume optimizer/scheduler/iteration state from the
            last checkpoint found in ``cfg.OUTPUT_DIR``.

    Returns:
        dict: ``{"model_final": trained_cfg}`` where ``trained_cfg`` has
        ``MODEL.WEIGHTS`` pointing at the final checkpoint file.
    """
    # Optionally report model FLOPs once, then disable the callback.
    add_print_flops_callback(cfg, model, disable_after_callback=True)

    optimizer = self.build_optimizer(cfg, model)
    scheduler = self.build_lr_scheduler(cfg, optimizer)

    checkpointer = self.build_checkpointer(
        cfg,
        model,
        save_dir=cfg.OUTPUT_DIR,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
    # Only trust the stored iteration when actually resuming from a checkpoint;
    # otherwise start from -1 so the +1 below yields iter 0.
    start_iter = (
        checkpoint.get("iteration", -1)
        if resume and checkpointer.has_checkpoint()
        else -1
    )
    # The checkpoint stores the training iteration that just finished, thus we start
    # at the next iteration (or iter zero if there's no checkpoint).
    start_iter += 1
    max_iter = cfg.SOLVER.MAX_ITER
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    data_loader = self.build_detection_train_loader(cfg)

    def _get_model_with_abnormal_checker(model):
        # Wrap the model with an abnormal-loss checker when enabled;
        # closes over ``cfg`` and ``start_iter`` from this function.
        if not cfg.ABNORMAL_CHECKER.ENABLED:
            return model

        tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
        writers = abnormal_checker.get_writers(cfg, tbx_writer)
        checker = abnormal_checker.AbnormalLossChecker(start_iter, writers)
        ret = abnormal_checker.AbnormalLossCheckerWrapper(model, checker)
        return ret

    # Pick the AMP trainer only when mixed precision is enabled.
    trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
        _get_model_with_abnormal_checker(model), data_loader, optimizer
    )
    # NOTE(review): some entries below may be None (EMA/QAT disabled);
    # presumably ``register_hooks`` filters None entries — confirm.
    trainer_hooks = [
        hooks.IterationTimer(),
        model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
        self._create_after_step_hook(
            cfg, model, optimizer, scheduler, periodic_checkpointer
        ),
        hooks.EvalHook(
            cfg.TEST.EVAL_PERIOD,
            # Lazily reads trainer.iter so the eval logs the current iteration.
            lambda: self.do_test(cfg, model, train_iter=trainer.iter),
        ),
        kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
        self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
    ]

    # Metric writers only run on the main process.
    if comm.is_main_process():
        tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
        writers = [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            tbx_writer,
        ]
        trainer_hooks.append(hooks.PeriodicWriter(writers))

    trainer.register_hooks(trainer_hooks)
    trainer.train(start_iter, max_iter)

    # If the runner kept a pre-training copy of the config, report the diff
    # and hand back the original rather than the (possibly mutated) cfg.
    if hasattr(self, 'original_cfg'):
        table = get_cfg_diff_table(cfg, self.original_cfg)
        logger.info(
            "GeneralizeRCNN Runner ignoring training config change: \n" + table
        )
        trained_cfg = self.original_cfg.clone()
    else:
        trained_cfg = cfg.clone()
    # Point the returned config at the final checkpoint file.
    with temp_defrost(trained_cfg):
        trained_cfg.MODEL.WEIGHTS = checkpointer.get_checkpoint_file()
    return {"model_final": trained_cfg}
def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
    """Evaluate ``model`` on every dataset in ``cfg.DATASETS.TEST``.

    Args:
        cfg: d2go config; must have non-empty ``DATASETS.TEST`` and
            ``OUTPUT_DIR``.
        model: the model to evaluate.
        train_iter: Current iteration of the model, None means final iteration.
        model_tag: tag used to namespace output folders and results.

    Returns:
        OrderedDict: ``results[model_tag][dataset_name] -> metrics`` (only
        populated on the main process).
    """
    assert len(cfg.DATASETS.TEST)
    assert cfg.OUTPUT_DIR

    # "Final" evaluation triggers CSV printing, optional TTA, and result
    # verification.
    is_final = (train_iter is None) or (train_iter == cfg.SOLVER.MAX_ITER - 1)

    logger.info(
        f"Running evaluation for model tag {model_tag} at iter {train_iter}..."
    )

    def _get_inference_dir_name(base_dir, inference_type, dataset_name):
        # Layout: base/inference_type/model_tag/<iter or "final">/dataset.
        return os.path.join(
            base_dir,
            inference_type,
            model_tag,
            str(train_iter) if train_iter is not None else "final",
            dataset_name,
        )

    add_print_flops_callback(cfg, model, disable_after_callback=True)

    results = OrderedDict()
    results[model_tag] = OrderedDict()
    for dataset_name in cfg.DATASETS.TEST:
        # Evaluator will create output folder, no need to create here
        output_folder = _get_inference_dir_name(
            cfg.OUTPUT_DIR, "inference", dataset_name
        )

        # NOTE: creating evaluator after dataset is loaded as there might be dependency.  # noqa
        data_loader = self.build_detection_test_loader(cfg, dataset_name)
        evaluator = self.get_evaluator(cfg, dataset_name, output_folder=output_folder)

        if not isinstance(evaluator, DatasetEvaluators):
            evaluator = DatasetEvaluators([evaluator])
        if comm.is_main_process():
            tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
            logger.info("Adding visualization evaluator ...")
            mapper = self.get_mapper(cfg, is_train=False)
            evaluator._evaluators.append(
                self.get_visualization_evaluator()(
                    cfg,
                    tbx_writer,
                    mapper,
                    dataset_name,
                    train_iter=train_iter,
                    tag_postfix=model_tag,
                )
            )

        results_per_dataset = inference_on_dataset(model, data_loader, evaluator)

        if comm.is_main_process():
            results[model_tag][dataset_name] = results_per_dataset
            if is_final:
                print_csv_format(results_per_dataset)

        if is_final and cfg.TEST.AUG.ENABLED:
            # In the end of training, run an evaluation with TTA.
            # Only support some R-CNN models.
            output_folder = _get_inference_dir_name(
                cfg.OUTPUT_DIR, "inference_TTA", dataset_name
            )

            logger.info("Running inference with test-time augmentation ...")
            data_loader = self.build_detection_test_loader(
                cfg, dataset_name, mapper=lambda x: x
            )
            evaluator = self.get_evaluator(
                cfg, dataset_name, output_folder=output_folder
            )
            inference_on_dataset(
                GeneralizedRCNNWithTTA(cfg, model), data_loader, evaluator
            )

    if is_final and cfg.TEST.EXPECTED_RESULTS and comm.is_main_process():
        assert len(results) == 1, "Results verification only supports one dataset!"
        verify_results(cfg, results[model_tag][cfg.DATASETS.TEST[0]])

    # write results to tensorboard
    if comm.is_main_process() and results:
        from detectron2.evaluation.testing import flatten_results_dict

        flattened_results = flatten_results_dict(results)
        # Fix: the writer lookup is loop-invariant; fetch it once instead of
        # once per scalar.
        tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
        for k, v in flattened_results.items():
            tbx_writer._writer.add_scalar("eval_{}".format(k), v, train_iter)

    if comm.is_main_process():
        tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
        tbx_writer._writer.flush()
    return results
def get_tbx_writer(cls, cfg):
    """Return the shared TensorBoard writer for ``cfg.OUTPUT_DIR``."""
    log_dir = get_tensorboard_log_dir(cfg.OUTPUT_DIR)
    return _get_tbx_writer(log_dir)