コード例 #1
0
ファイル: default_runner.py プロジェクト: ananthsub/d2go
    def do_train(self, cfg, model, resume):
        """Run the training loop for ``model`` and return the trained config.

        Args:
            cfg: config node driving the run (solver, output dir, EMA, QAT,
                abnormal-checker and test settings are read from it).
            model: the model to train.
            resume: passed to ``checkpointer.resume_or_load``; when truthy and
                a checkpoint exists, training continues from the iteration
                after the one stored in that checkpoint.

        Returns:
            A dict ``{"model_final": trained_cfg}`` where ``trained_cfg`` is a
            clone of ``self.original_cfg`` (if that attribute exists) or of
            ``cfg``, with ``MODEL.WEIGHTS`` set to the checkpointer's
            checkpoint file.
        """
        add_print_flops_callback(cfg, model, disable_after_callback=True)

        optimizer = self.build_optimizer(cfg, model)
        scheduler = self.build_lr_scheduler(cfg, optimizer)

        checkpointer = self.build_checkpointer(
            cfg,
            model,
            save_dir=cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        checkpoint = checkpointer.resume_or_load(
            cfg.MODEL.WEIGHTS, resume=resume
        )

        # The checkpoint records the iteration that just finished, so training
        # picks up at the following one; without a usable checkpoint the
        # sentinel -1 makes training begin at iteration 0.
        if resume and checkpointer.has_checkpoint():
            last_finished_iter = checkpoint.get("iteration", -1)
        else:
            last_finished_iter = -1
        start_iter = last_finished_iter + 1

        max_iter = cfg.SOLVER.MAX_ITER
        periodic_checkpointer = PeriodicCheckpointer(
            checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
        )

        data_loader = self.build_detection_train_loader(cfg)

        if cfg.SOLVER.AMP.ENABLED:
            trainer_cls = AMPTrainer
        else:
            trainer_cls = SimpleTrainer

        # Optionally wrap the model with the abnormal-loss checker; hooks
        # below keep receiving the unwrapped ``model``.
        model_for_trainer = model
        if cfg.ABNORMAL_CHECKER.ENABLED:
            checker_tbx_writer = _get_tbx_writer(
                get_tensorboard_log_dir(cfg.OUTPUT_DIR)
            )
            checker_writers = abnormal_checker.get_writers(
                cfg, checker_tbx_writer
            )
            loss_checker = abnormal_checker.AbnormalLossChecker(
                start_iter, checker_writers
            )
            model_for_trainer = abnormal_checker.AbnormalLossCheckerWrapper(
                model, loss_checker
            )

        trainer = trainer_cls(model_for_trainer, data_loader, optimizer)

        def _evaluate_at_current_iter():
            # Reads ``trainer.iter`` at call time, i.e. when the hook fires.
            return self.do_test(cfg, model, train_iter=trainer.iter)

        # Disabled optional hooks are kept as ``None`` placeholders in the
        # list handed to ``register_hooks``.
        timer_hook = hooks.IterationTimer()
        ema_hook = None
        if cfg.MODEL_EMA.ENABLED:
            ema_hook = model_ema.EMAHook(cfg, model)
        after_step_hook = self._create_after_step_hook(
            cfg, model, optimizer, scheduler, periodic_checkpointer
        )
        eval_hook = hooks.EvalHook(
            cfg.TEST.EVAL_PERIOD, _evaluate_at_current_iter
        )
        anchors_hook = kmeans_anchors.compute_kmeans_anchors_hook(self, cfg)
        qat_hook = None
        if cfg.QUANTIZATION.QAT.ENABLED:
            qat_hook = self._create_qat_hook(cfg)

        trainer_hooks = [
            timer_hook,
            ema_hook,
            after_step_hook,
            eval_hook,
            anchors_hook,
            qat_hook,
        ]

        # Metric writers are attached on the main process only.
        if comm.is_main_process():
            metrics_path = os.path.join(cfg.OUTPUT_DIR, "metrics.json")
            metric_writers = [
                CommonMetricPrinter(max_iter),
                JSONWriter(metrics_path),
                _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)),
            ]
            trainer_hooks.append(hooks.PeriodicWriter(metric_writers))

        trainer.register_hooks(trainer_hooks)
        trainer.train(start_iter, max_iter)

        if not hasattr(self, 'original_cfg'):
            trained_cfg = cfg.clone()
        else:
            # Report how ``cfg`` differs from the original config, then hand
            # back a clone of the original one.
            diff_table = get_cfg_diff_table(cfg, self.original_cfg)
            logger.info(
                "GeneralizeRCNN Runner ignoring training config change: \n"
                + diff_table
            )
            trained_cfg = self.original_cfg.clone()

        with temp_defrost(trained_cfg):
            trained_cfg.MODEL.WEIGHTS = checkpointer.get_checkpoint_file()
        return {"model_final": trained_cfg}
コード例 #2
0
ファイル: default_runner.py プロジェクト: ananthsub/d2go
    def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
        """Evaluate ``model`` on every dataset listed in ``cfg.DATASETS.TEST``.

        Args:
            cfg: config node; ``DATASETS.TEST`` and ``OUTPUT_DIR`` must be
                non-empty.
            model: the model handed to ``inference_on_dataset``.
            train_iter: Current iteration of the model, None means final
                iteration. ``cfg.SOLVER.MAX_ITER - 1`` is also treated as
                final.
            model_tag: label used as the top-level key of the returned
                results, as a path component of the inference output folder,
                and as the tag postfix of the visualization evaluator.

        Returns:
            An ``OrderedDict`` shaped ``{model_tag: {dataset_name: results}}``.
            The per-dataset entries are only filled in on the main process.

        Raises:
            AssertionError: if ``cfg.DATASETS.TEST`` or ``cfg.OUTPUT_DIR`` is
                empty, or if expected-results verification is requested on
                the final iteration while more than one dataset was
                evaluated.
        """
        assert len(cfg.DATASETS.TEST)
        assert cfg.OUTPUT_DIR

        is_final = (train_iter is None) or (train_iter
                                            == cfg.SOLVER.MAX_ITER - 1)

        logger.info(
            f"Running evaluation for model tag {model_tag} at iter {train_iter}..."
        )

        def _get_inference_dir_name(base_dir, inference_type, dataset_name):
            # <base_dir>/<inference_type>/<model_tag>/<iter or "final">/<dataset>
            return os.path.join(
                base_dir,
                inference_type,
                model_tag,
                str(train_iter) if train_iter is not None else "final",
                dataset_name,
            )

        add_print_flops_callback(cfg, model, disable_after_callback=True)

        results = OrderedDict()
        results[model_tag] = OrderedDict()
        for dataset_name in cfg.DATASETS.TEST:
            # Evaluator will create output folder, no need to create here
            output_folder = _get_inference_dir_name(cfg.OUTPUT_DIR,
                                                    "inference", dataset_name)

            # NOTE: creating evaluator after dataset is loaded as there might be dependency.  # noqa
            data_loader = self.build_detection_test_loader(cfg, dataset_name)
            evaluator = self.get_evaluator(cfg,
                                           dataset_name,
                                           output_folder=output_folder)

            # Normalize to a DatasetEvaluators container so the visualization
            # evaluator can be appended below.
            if not isinstance(evaluator, DatasetEvaluators):
                evaluator = DatasetEvaluators([evaluator])
            if comm.is_main_process():
                tbx_writer = _get_tbx_writer(
                    get_tensorboard_log_dir(cfg.OUTPUT_DIR))
                logger.info("Adding visualization evaluator ...")
                mapper = self.get_mapper(cfg, is_train=False)
                evaluator._evaluators.append(
                    self.get_visualization_evaluator()(
                        cfg,
                        tbx_writer,
                        mapper,
                        dataset_name,
                        train_iter=train_iter,
                        tag_postfix=model_tag,
                    ))

            results_per_dataset = inference_on_dataset(model, data_loader,
                                                       evaluator)

            if comm.is_main_process():
                results[model_tag][dataset_name] = results_per_dataset
                if is_final:
                    print_csv_format(results_per_dataset)

            if is_final and cfg.TEST.AUG.ENABLED:
                # In the end of training, run an evaluation with TTA
                # Only support some R-CNN models.
                output_folder = _get_inference_dir_name(
                    cfg.OUTPUT_DIR, "inference_TTA", dataset_name)

                logger.info(
                    "Running inference with test-time augmentation ...")
                data_loader = self.build_detection_test_loader(
                    cfg, dataset_name, mapper=lambda x: x)
                evaluator = self.get_evaluator(cfg,
                                               dataset_name,
                                               output_folder=output_folder)
                # The TTA results are not stored in ``results``.
                inference_on_dataset(GeneralizedRCNNWithTTA(cfg, model),
                                     data_loader, evaluator)

        if is_final and cfg.TEST.EXPECTED_RESULTS and comm.is_main_process():
            # ``results`` always holds exactly one top-level key (model_tag),
            # so the dataset count has to be checked one level down.
            assert (
                len(results[model_tag]) == 1
            ), "Results verification only supports one dataset!"
            verify_results(cfg, results[model_tag][cfg.DATASETS.TEST[0]])

        if comm.is_main_process():
            from detectron2.evaluation.testing import flatten_results_dict

            # Fetch the writer once, log every flattened metric, then flush.
            # NOTE(review): assumes _get_tbx_writer returns an equivalent
            # writer for the same log dir on every call - confirm.
            tbx_writer = _get_tbx_writer(
                get_tensorboard_log_dir(cfg.OUTPUT_DIR))
            # write results to tensorboard
            for k, v in flatten_results_dict(results).items():
                tbx_writer._writer.add_scalar("eval_{}".format(k), v,
                                              train_iter)
            tbx_writer._writer.flush()
        return results