Example #1
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            ) if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Custom: insert a LossEvalHook that computes the validation loss at
        # the same period as evaluation. At this point the hook lands just
        # before the PreciseBN (or None) entry at the end of ret.
        ret.insert(
            -1,
            LossEvalHook(
                self.cfg.TEST.EVAL_PERIOD, self.model,
                build_detection_test_loader(self.cfg,
                                            self.cfg.DATASETS.TEST[0],
                                            DatasetMapper(self.cfg, True))))
        # Custom: cycle through the test datasets, wrapping at Test_index_MAX.
        # Note that this mutates only the local clone; self.cfg is unchanged
        # after this method returns.
        cfg.Test_index += 1
        if cfg.Test_index == cfg.Test_index_MAX:
            cfg.Test_index = 0

        # Do PreciseBN before the checkpointer, because it updates the model and
        # needs to be saved by the checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(self.checkpointer,
                                           cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret
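Two details worth noting about Example #1: the None placeholder left in ret when precise BN is disabled is harmless, because detectron2's TrainerBase.register_hooks drops None entries before registering; and Test_index / Test_index_MAX are custom config keys that must be added to the config before the trainer is built. A minimal usage sketch for wiring such an override into training, where MyTrainer and the config path are assumptions for illustration:

from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

class MyTrainer(DefaultTrainer):  # hypothetical subclass name
    # ... the build_hooks override shown above goes here ...
    pass

cfg = get_cfg()
cfg.merge_from_file("config.yaml")  # placeholder path; use your own config
cfg.Test_index = 0                  # register the custom keys used above
cfg.Test_index_MAX = len(cfg.DATASETS.TEST)
trainer = MyTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()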
Example #2
    def build_hooks(self):
        # Extend the default hook list with a validation-loss hook.
        ret = super().build_hooks()
        # Insert before the last hook (PeriodicWriter) so the validation
        # loss is included in the written metrics.
        ret.insert(
            -1,
            LossEvalHook(
                self.cfg, self.val_period, self.model, self.scheduler,
                build_detection_test_loader(self.cfg, self.val_data,
                                            DatasetMapper(self.cfg, True))))
        return ret
Example #3
    def build_hooks(self):
        ret = super().build_hooks()
        # Insert before the final PeriodicWriter hook so the validation
        # loss shows up in TensorBoard and the metrics file. Note: the
        # original snippet referenced an undefined cfg; use self.cfg.
        ret.insert(
            -1,
            LossEvalHook(
                self.cfg.TEST.EVAL_PERIOD, self.model,
                build_detection_test_loader(
                    self.cfg, self.cfg.DATASETS.TEST[0],
                    DatasetMapper(self.cfg, True))))
        return ret
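All three examples rely on a LossEvalHook class that the snippets do not define. A minimal sketch of such a hook, assuming detectron2's HookBase API and the constructor signature used in Examples #1 and #3 (period, model, data loader); Example #2's variant takes extra project-specific arguments. The validation-loss logic follows the common pattern of running the model in training mode over a held-out loader and logging the mean loss:

import numpy as np
import torch
import detectron2.utils.comm as comm
from detectron2.engine import HookBase


class LossEvalHook(HookBase):
    def __init__(self, eval_period, model, data_loader):
        self._period = eval_period
        self._model = model
        self._data_loader = data_loader

    def _get_loss(self, data):
        # In training mode, detectron2 models return a dict of loss tensors.
        with torch.no_grad():
            metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        return sum(metrics_dict.values())

    def _do_loss_eval(self):
        losses = [self._get_loss(inputs) for inputs in self._data_loader]
        # Log to the trainer's EventStorage so PeriodicWriter picks it up.
        self.trainer.storage.put_scalar("validation_loss", np.mean(losses))
        comm.synchronize()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()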