Пример #1
0
    def build_model(self, cfg, eval_only=False):
        """Build the model described by ``cfg`` and apply the configured setup steps.

        Args:
            cfg: config node. It is cloned first, so the caller's object is
                not modified by the builder.
            eval_only: when true, weights are loaded from ``cfg.MODEL.WEIGHTS``,
                the model is switched to eval mode and, if configured, the EMA
                weights are applied.

        Returns:
            The built model (possibly replaced by the frozen-BN / QAT versions).
        """
        # The module-level build_model might modify the cfg, so work on a copy.
        cfg = cfg.clone()

        model = build_model(cfg)
        model_ema.may_build_model_ema(cfg, model)

        frozen_layers = cfg.MODEL.FROZEN_LAYER_REG_EXP
        if frozen_layers:
            set_requires_grad(model, frozen_layers, False)
            model = freeze_matched_bn(model, frozen_layers)

        if cfg.QUANTIZATION.QAT.ENABLED:
            # Keep fake_quant (outside of eval) and the observer off so that the
            # model trains normally until QAT is switched on (controlled by
            # QUANTIZATION.QAT.START_ITER).
            model = setup_qat_model(
                cfg,
                model,
                enable_fake_quant=eval_only,
                enable_observer=False,
            )

        if eval_only:
            weights_loader = self.build_checkpointer(
                cfg, model, save_dir=cfg.OUTPUT_DIR
            )
            weights_loader.load(cfg.MODEL.WEIGHTS)
            model.eval()

            ema_cfg = cfg.MODEL_EMA
            if ema_cfg.ENABLED and ema_cfg.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
                model_ema.apply_model_ema(model)

        # Note: the _visualize_model API is experimental
        if comm.is_main_process() and hasattr(model, "_visualize_model"):
            logger.info("Adding model visualization ...")
            log_dir = get_tensorboard_log_dir(cfg.OUTPUT_DIR)
            model._visualize_model(_get_tbx_writer(log_dir))

        return model
Пример #2
0
        def _get_model_with_abnormal_checker(model):
            """Return ``model`` wrapped with an abnormal-loss checker if enabled.

            Reads ``cfg`` and ``start_iter`` from the enclosing scope. When
            ``cfg.ABNORMAL_CHECKER.ENABLED`` is false the model is returned
            unchanged.
            """
            if cfg.ABNORMAL_CHECKER.ENABLED:
                log_dir = get_tensorboard_log_dir(cfg.OUTPUT_DIR)
                checker_writers = abnormal_checker.get_writers(
                    cfg, _get_tbx_writer(log_dir)
                )
                loss_checker = abnormal_checker.AbnormalLossChecker(
                    start_iter, checker_writers
                )
                return abnormal_checker.AbnormalLossCheckerWrapper(
                    model, loss_checker
                )
            return model
Пример #3
0
    def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
        """Build the training data loader, optionally wrapped for visualization.

        Args:
            cfg: config node passed to the loader builder and the wrapper.
            mapper: dataset mapper; when falsy, ``cls.get_mapper(cfg, is_train=True)``
                is used instead.
            *args, **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The data loader. On the main process it is wrapped by the type
            returned from ``cls.get_data_loader_vis_wrapper()`` when that is
            not None.
        """
        if not mapper:
            mapper = cls.get_mapper(cfg, is_train=True)
        loader = build_d2go_train_loader(cfg, mapper)

        # Only the main process gets the visualization wrapper.
        if not comm.is_main_process():
            return loader

        vis_wrapper_type = cls.get_data_loader_vis_wrapper()
        if vis_wrapper_type is None:
            return loader

        log_dir = get_tensorboard_log_dir(cfg.OUTPUT_DIR)
        return vis_wrapper_type(cfg, _get_tbx_writer(log_dir), loader)
Пример #4
0
    def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
        """Build the detection training data loader.

        The loader is built by ``build_weighted_detection_train_loader`` when
        ``cfg.DATALOADER.SAMPLER_TRAIN`` is ``"WeightedTrainingSampler"`` and by
        ``d2_build_detection_train_loader`` otherwise. On the main process the
        loader is additionally wrapped by the visualization wrapper type, if
        one is provided.

        Args:
            cfg: config node.
            *args, **kwargs: forwarded to ``d2_build_detection_train_loader``
                (not used for the weighted sampler path).
            mapper: dataset mapper; when falsy, ``cls.get_mapper(cfg, is_train=True)``
                is used instead.

        Returns:
            The (possibly wrapped) training data loader.
        """
        logger.info("Building detection train loader ...")
        mapper = mapper or cls.get_mapper(cfg, is_train=True)
        logger.info("Using dataset mapper:\n{}".format(mapper))

        if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler":
            data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
        else:
            data_loader = d2_build_detection_train_loader(
                cfg, *args, mapper=mapper, **kwargs
            )

        if comm.is_main_process():
            data_loader_type = cls.get_data_loader_vis_wrapper()
            # The wrapper type may be None (no visualization requested); calling
            # it unconditionally would raise a TypeError in that case.
            if data_loader_type is not None:
                tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
                data_loader = data_loader_type(cfg, tbx_writer, data_loader)
        return data_loader
Пример #5
0
 def _setup_visualization_evaluator(
     evaluator,
     dataset_name: str,
     model_tag: ModelTag,
 ) -> None:
     """Append a visualization evaluator to ``evaluator._evaluators``.

     Uses ``self`` from the enclosing scope. Nothing is appended when
     ``self.get_visualization_evaluator()`` returns None.

     Args:
         evaluator: object exposing an ``_evaluators`` list.
         dataset_name: name of the dataset being evaluated.
         model_tag: passed to the visualization evaluator as ``tag_postfix``.
     """
     logger.info("Adding visualization evaluator ...")
     dataset_mapper = self.get_mapper(self.cfg, is_train=False)
     evaluator_type = self.get_visualization_evaluator()
     # TODO: replace tbx_writer with Lightning's self.logger.experiment
     log_dir = get_tensorboard_log_dir(self.cfg.OUTPUT_DIR)
     tbx_writer = _get_tbx_writer(log_dir)

     if evaluator_type is None:
         return

     evaluator._evaluators.append(
         evaluator_type(
             self.cfg,
             tbx_writer,
             dataset_mapper,
             dataset_name,
             train_iter=self.trainer.global_step,
             tag_postfix=model_tag,
         )
     )
Пример #6
0
    def do_train(self, cfg, model, resume):
        """Train ``model`` according to ``cfg`` and return the trained config.

        Args:
            cfg: config node driving optimizer, scheduler, checkpointing,
                hooks and the number of iterations (``cfg.SOLVER.MAX_ITER``).
            model: the model to train.
            resume: passed to ``checkpointer.resume_or_load``; the stored
                iteration is only used when this is truthy and a checkpoint
                exists.

        Returns:
            ``{"model_final": trained_cfg}`` where ``trained_cfg`` is a clone of
            ``self.original_cfg`` (if that attribute exists) or of ``cfg``, with
            ``MODEL.WEIGHTS`` set to ``checkpointer.get_checkpoint_file()``.
        """
        add_print_flops_callback(cfg, model, disable_after_callback=True)

        optimizer = self.build_optimizer(cfg, model)
        scheduler = self.build_lr_scheduler(cfg, optimizer)

        checkpointer = self.build_checkpointer(
            cfg,
            model,
            save_dir=cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)

        # The checkpoint stores the training iteration that just finished, thus
        # we start at the next iteration (or iter zero if there's no checkpoint).
        if resume and checkpointer.has_checkpoint():
            start_iter = checkpoint.get("iteration", -1) + 1
        else:
            start_iter = 0

        max_iter = cfg.SOLVER.MAX_ITER
        periodic_checkpointer = PeriodicCheckpointer(
            checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
        )

        data_loader = self.build_detection_train_loader(cfg)

        def _get_model_with_abnormal_checker(model):
            """Wrap ``model`` with the abnormal-loss checker when enabled in cfg."""
            if cfg.ABNORMAL_CHECKER.ENABLED:
                checker_tbx_writer = _get_tbx_writer(
                    get_tensorboard_log_dir(cfg.OUTPUT_DIR)
                )
                checker_writers = abnormal_checker.get_writers(
                    cfg, checker_tbx_writer
                )
                loss_checker = abnormal_checker.AbnormalLossChecker(
                    start_iter, checker_writers
                )
                return abnormal_checker.AbnormalLossCheckerWrapper(
                    model, loss_checker
                )
            return model

        trainer_cls = AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer
        trainer = trainer_cls(
            _get_model_with_abnormal_checker(model), data_loader, optimizer
        )

        def _evaluate_current_model():
            # ``trainer.iter`` is read at call time, i.e. when the hook fires.
            return self.do_test(cfg, model, train_iter=trainer.iter)

        # Entries may be None when the matching feature is disabled; presumably
        # register_hooks drops them -- confirm against the trainer implementation.
        trainer_hooks = [
            hooks.IterationTimer(),
            model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
            self._create_after_step_hook(
                cfg, model, optimizer, scheduler, periodic_checkpointer
            ),
            hooks.EvalHook(cfg.TEST.EVAL_PERIOD, _evaluate_current_model),
            kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
            self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
        ]

        # Metric writers are only attached on the main process.
        if comm.is_main_process():
            tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
            metric_writers = [
                CommonMetricPrinter(max_iter),
                JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
                tbx_writer,
            ]
            trainer_hooks.append(hooks.PeriodicWriter(metric_writers))

        trainer.register_hooks(trainer_hooks)
        trainer.train(start_iter, max_iter)

        if hasattr(self, "original_cfg"):
            diff_table = get_cfg_diff_table(cfg, self.original_cfg)
            logger.info(
                "GeneralizeRCNN Runner ignoring training config change: \n"
                + diff_table
            )
            trained_cfg = self.original_cfg.clone()
        else:
            trained_cfg = cfg.clone()

        with temp_defrost(trained_cfg):
            trained_cfg.MODEL.WEIGHTS = checkpointer.get_checkpoint_file()
        return {"model_final": trained_cfg}
Пример #7
0
    def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
        """Evaluate ``model`` on every dataset listed in ``cfg.DATASETS.TEST``.

        Args:
            cfg: config node; ``DATASETS.TEST`` and ``OUTPUT_DIR`` must be
                non-empty.
            model: model passed to ``inference_on_dataset``.
            train_iter: current iteration of the model, None means final
                iteration.
            model_tag: key under which results are stored; also part of the
                inference output directory and the visualization tag postfix.

        Returns:
            An ``OrderedDict`` of the form ``{model_tag: {dataset_name: results}}``.
            Per-dataset entries are only filled in on the main process.
        """
        assert len(cfg.DATASETS.TEST)
        assert cfg.OUTPUT_DIR

        is_final = (train_iter is None) or (train_iter == cfg.SOLVER.MAX_ITER - 1)

        logger.info(
            f"Running evaluation for model tag {model_tag} at iter {train_iter}..."
        )

        def _get_inference_dir_name(base_dir, inference_type, dataset_name):
            # <base_dir>/<inference_type>/<model_tag>/<iter or "final">/<dataset>
            return os.path.join(
                base_dir,
                inference_type,
                model_tag,
                str(train_iter) if train_iter is not None else "final",
                dataset_name,
            )

        add_print_flops_callback(cfg, model, disable_after_callback=True)

        results = OrderedDict()
        results[model_tag] = OrderedDict()
        for dataset_name in cfg.DATASETS.TEST:
            # Evaluator will create output folder, no need to create here
            output_folder = _get_inference_dir_name(
                cfg.OUTPUT_DIR, "inference", dataset_name
            )

            # NOTE: creating evaluator after dataset is loaded as there might be dependency.  # noqa
            data_loader = self.build_detection_test_loader(cfg, dataset_name)
            evaluator = self.get_evaluator(
                cfg, dataset_name, output_folder=output_folder
            )

            if not isinstance(evaluator, DatasetEvaluators):
                evaluator = DatasetEvaluators([evaluator])
            if comm.is_main_process():
                vis_eval_type = self.get_visualization_evaluator()
                # The visualization evaluator type may be None (nothing to
                # add); calling it unconditionally would raise a TypeError.
                if vis_eval_type is not None:
                    tbx_writer = _get_tbx_writer(
                        get_tensorboard_log_dir(cfg.OUTPUT_DIR)
                    )
                    logger.info("Adding visualization evaluator ...")
                    mapper = self.get_mapper(cfg, is_train=False)
                    evaluator._evaluators.append(
                        vis_eval_type(
                            cfg,
                            tbx_writer,
                            mapper,
                            dataset_name,
                            train_iter=train_iter,
                            tag_postfix=model_tag,
                        )
                    )

            results_per_dataset = inference_on_dataset(model, data_loader, evaluator)

            if comm.is_main_process():
                results[model_tag][dataset_name] = results_per_dataset
                if is_final:
                    print_csv_format(results_per_dataset)

            if is_final and cfg.TEST.AUG.ENABLED:
                # In the end of training, run an evaluation with TTA
                # Only support some R-CNN models.
                output_folder = _get_inference_dir_name(
                    cfg.OUTPUT_DIR, "inference_TTA", dataset_name
                )

                logger.info("Running inference with test-time augmentation ...")
                data_loader = self.build_detection_test_loader(
                    cfg, dataset_name, mapper=lambda x: x
                )
                evaluator = self.get_evaluator(
                    cfg, dataset_name, output_folder=output_folder
                )
                inference_on_dataset(
                    GeneralizedRCNNWithTTA(cfg, model), data_loader, evaluator
                )

        if is_final and cfg.TEST.EXPECTED_RESULTS and comm.is_main_process():
            # NOTE(review): ``results`` only ever holds the single ``model_tag``
            # key, so this assert cannot fail; the message suggests the intent
            # was to check the number of datasets. Left unchanged to avoid
            # breaking existing multi-dataset runs -- confirm the intent.
            assert len(results) == 1, "Results verification only supports one dataset!"
            verify_results(cfg, results[model_tag][cfg.DATASETS.TEST[0]])

        if comm.is_main_process():
            # Fetch the writer once instead of once per scalar.
            tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))

            # write results to tensorboard
            if results:
                from detectron2.evaluation.testing import flatten_results_dict

                for k, v in flatten_results_dict(results).items():
                    tbx_writer._writer.add_scalar("eval_{}".format(k), v, train_iter)

            tbx_writer._writer.flush()
        return results
Пример #8
0
 def get_tbx_writer(cls, cfg):
     """Return the TensorBoard writer for the log dir derived from ``cfg.OUTPUT_DIR``."""
     log_dir = get_tensorboard_log_dir(cfg.OUTPUT_DIR)
     return _get_tbx_writer(log_dir)