Example 1
    def build_hooks(self) -> List[HookBase]:
        """
        This method overrides the default one from DefaultTrainer.
        It adds the `LossEvalHook`, which evaluates the loss on the validation set.
        The default method builds a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]: The list of hooks to call during training.
        """
        hooks: List[HookBase] = super().build_hooks()

        # TODO remove as it can't work
        # input_example = next(iter(self.data_loader))
        # hooks.append(ModelWriter(model=self.model,
        # input_example=input_example,
        # log_dir=self.cfg.OUTPUT_DIR))

        # We add our custom validation hook
        if self.cfg.DATASETS.VALIDATION != "":
            data_set_mapper: PanelSegDatasetMapper = PanelSegDatasetMapper(
                cfg=self.cfg, is_train=True)
            data_loader: DataLoader = build_detection_test_loader(
                cfg=self.cfg,
                dataset_name=self.cfg.DATASETS.VALIDATION,
                mapper=data_set_mapper)

            loss_eval_hook: LossEvalHook = LossEvalHook(
                eval_period=self.cfg.VALIDATION.VALIDATION_PERIOD,
                model=self.model,
                data_loader=data_loader)

            hooks.insert(-1, loss_eval_hook)

        return hooks
Example 2
    def build_hooks(self) -> List[HookBase]:
        """
        This method overrides the default one from DefaultTrainer.
        When a validation dataset is configured, it adds the `LossEvalHook`, which
        evaluates the loss on the validation set.

        Returns:
            List[HookBase]: The augmented list of hooks.
        """
        # Build a list of default hooks, including timing, evaluation,
        # checkpointing, lr scheduling, precise BN, writing events.
        hooks = super().build_hooks()

        # We add our custom validation hook
        if self.cfg.DATASETS.VALIDATION != "":
            data_set_mapper: DatasetMapper = DatasetMapper(
                cfg=self.cfg, is_train=True)

            data_loader: DataLoader = build_detection_test_loader(
                cfg=self.cfg,
                dataset_name=self.cfg.DATASETS.VALIDATION,
                mapper=data_set_mapper)

            loss_eval_hook: LossEvalHook = LossEvalHook(
                eval_period=self.cfg.VALIDATION.VALIDATION_PERIOD,
                model=self.model,
                data_loader=data_loader)

            hooks.insert(-1, loss_eval_hook)

        return hooks
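Examples 1, 2, and 4 all register a `LossEvalHook`, which is not part of detectron2 itself. The sketch below shows one plausible implementation, assuming the common pattern of running the (still training-mode) model over the validation loader every `eval_period` iterations and logging the mean loss; the actual class used in these projects may differ.

import torch
from detectron2.engine import HookBase

class LossEvalHook(HookBase):
    """Periodically compute the loss on a validation data loader."""

    def __init__(self, eval_period, model, data_loader):
        self._period = eval_period
        self._model = model
        self._data_loader = data_loader

    def _do_loss_eval(self):
        losses = []
        with torch.no_grad():
            for inputs in self._data_loader:
                # In training mode, detectron2 meta-architectures return a dict of losses.
                loss_dict = self._model(inputs)
                losses.append(sum(loss.item() for loss in loss_dict.values()))
        self.trainer.storage.put_scalar("validation_loss",
                                        sum(losses) / max(len(losses), 1))

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self._period > 0 and next_iter % self._period == 0:
            self._do_loss_eval()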
Example 3
 def build_test_loader(cls, cfg, dataset_name):
     """
     It now calls :func:`detectron2.data.build_detection_test_loader`.
     Overwrite it if you'd like a different data loader.
     """
     mapper = cls.build_mapper(cfg, is_train=False)
     return build_detection_test_loader(cfg, dataset_name, mapper=mapper)
Example 4
 def build_hooks(self):
     hooks = super().build_hooks()
     if self.cfg.ISPRS.LABEL.BOXMODE == "ROTATED":
         eval_mapper = ISPRSCOCOStyleMapperRotated
     else:
         eval_mapper = ISPRSCOCOStyleMapperAxisAligned
     hooks.insert(
         -1,
         LossEvalHook(
             self.cfg.TEST.EVAL_PERIOD, self._trainer.model,
             build_detection_test_loader(self.cfg,
                                         self.cfg.DATASETS.TEST[0],
                                         eval_mapper(self.cfg, True))))
     return hooks
Example 5
    def build_test_loader(cls, cfg: CfgNode, dataset_name: str) -> DataLoader:
        """
        Instantiate the test data loader.

        Args:
            cfg (CfgNode):      The global config.
            dataset_name (str): The name of the test dataset.

        Returns:
            a DataLoader yielding formatted test examples.
        """
        mapper: PanelSegDatasetMapper = PanelSegDatasetMapper(cfg,
                                                              is_train=False)
        return build_detection_test_loader(cfg,
                                           dataset_name=dataset_name,
                                           mapper=mapper)
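A brief usage sketch (the class name PanelSegTrainer is assumed here; any DefaultTrainer subclass exposing the classmethod above works the same way):

loader = PanelSegTrainer.build_test_loader(cfg, cfg.DATASETS.TEST[0])
batch = next(iter(loader))  # batches are list[dict]; each dict carries "image", "height", "width", ...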
Example 6
 def build_test_loader(cls, cfg, dataset_name):
     return build_detection_test_loader(cfg,
                                        dataset_name,
                                        mapper=DatasetMapper(cfg, False))
Example 7
def custom_test_loader(cfg, dataset_name):
    return build_detection_test_loader(cfg,
                                       dataset_name,
                                       mapper=CustomDatasetMapper(cfg, False))
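The `CustomDatasetMapper` used above is not shown in the snippet. Any callable that turns a dataset dict into the model's input format can serve as a mapper; the sketch below is a hypothetical implementation following the standard detectron2 custom-mapper pattern (the class body and augmentation choices are assumptions, not the original project's code).

import copy
import torch
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

class CustomDatasetMapper:
    """Map a detectron2 dataset dict to the format consumed by the model."""

    def __init__(self, cfg, is_train=True):
        self.is_train = is_train
        self.image_format = cfg.INPUT.FORMAT
        # A single deterministic resize; a real mapper would add training-time augmentations.
        self.augmentations = T.AugmentationList(
            [T.ResizeShortestEdge(cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, "choice")])

    def __call__(self, dataset_dict):
        dataset_dict = copy.deepcopy(dataset_dict)
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        aug_input = T.AugInput(image)
        transforms = self.augmentations(aug_input)
        image = aug_input.image
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).copy())
        if "annotations" in dataset_dict:
            annos = [utils.transform_instance_annotations(obj, transforms, image.shape[:2])
                     for obj in dataset_dict.pop("annotations")]
            dataset_dict["instances"] = utils.annotations_to_instances(annos, image.shape[:2])
        return dataset_dict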
Example 8
def get_lvis_test_dataloader(cfg, h, w):
    default_mapper = DatasetMapper(cfg, is_train=False)
    mapper = partial(wrapper, default_m=default_mapper, h=h, w=w)
    dl = build_detection_test_loader(cfg, 'lvis_v0.5_val', mapper=mapper)
    return dl
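The `wrapper` bound with `functools.partial` above is not part of the snippet (which also assumes `from functools import partial`). One plausible sketch, as an assumption rather than the original code, applies the default mapper and then resizes the image tensor to a fixed (h, w):

import torch.nn.functional as F

def wrapper(dataset_dict, default_m, h, w):
    data = default_m(dataset_dict)              # default detectron2 test-time mapping
    image = data["image"].unsqueeze(0).float()  # CHW -> 1xCHW float for interpolate
    data["image"] = F.interpolate(image, size=(h, w), mode="bilinear",
                                  align_corners=False).squeeze(0)
    return data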
Example 9
    trainer.resume_or_load(resume=False)
    trainer.train()

    # Look at training curves in tensorboard:
    # %load_ext tensorboard
    # %tensorboard --logdir output/run_rdd/
else:
    # Inference & evaluation using the trained model
    # Now, let's run inference with the trained model on the validation dataset.
    # First, let's create a predictor using the model we just trained:
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold for this model
    cfg.DATASETS.TEST = ("rdd2020_val", )
    predictor = DefaultPredictor(cfg)
    trainer.resume_or_load(resume=True)

    # Then, we randomly select several samples to visualize the prediction results.
    from detectron2.utils.visualizer import ColorMode
    from detectron2.evaluation import COCOEvaluator, DatasetEvaluators, inference_on_dataset
    from detectron2.data import build_detection_test_loader

    evaluator = COCOEvaluator("rdd2020_val", cfg, False, "coco_eval")
    val_loader = build_detection_test_loader(cfg, "rdd2020_val")
    eval_results = inference_on_dataset(trainer.model, val_loader,
                                        DatasetEvaluators([evaluator]))
    # another equivalent way is to use trainer.test
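    # e.g. (sketch of the equivalent call, reusing the evaluator built above):
    # eval_results = trainer.test(cfg, trainer.model, evaluators=[evaluator])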
    print(eval_results)

# Empty the GPU Memory
torch.cuda.empty_cache()