コード例 #1
0
    def test_checkpoint_resume(self):
        """Train with periodic checkpointing, then restore state into a fresh trainer.

        A first trainer runs iterations 0..11 with a checkpoint hook (period 10)
        registered *before* the LR-scheduler hook.  A second trainer and a new
        scheduler are then built on the same model/optimizer, and the
        checkpointer is asked to resume; the test checks the restored
        iteration counter and scheduler epoch.
        """
        net = _SimpleModel()
        loader = self._data_loader("cpu")
        sgd = torch.optim.SGD(net.parameters(), 0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(sgd, 3)

        with tempfile.TemporaryDirectory(prefix="detectron2_test") as ckpt_dir:
            # --- first run: train and write checkpoints into ckpt_dir ---
            trainer = SimpleTrainer(net, loader, sgd)
            checkpointer = Checkpointer(net, ckpt_dir, opt=sgd, trainer=trainer)
            first_run_hooks = [
                hooks.PeriodicCheckpointer(checkpointer, 10),
                hooks.LRScheduler(scheduler=scheduler),
            ]
            trainer.register_hooks(first_run_hooks)
            trainer.train(0, 12)
            del trainer

            # --- second run: fresh trainer + scheduler, state comes from disk ---
            trainer = SimpleTrainer(net, loader, sgd)
            scheduler = torch.optim.lr_scheduler.StepLR(sgd, 3)
            trainer.register_hooks([hooks.LRScheduler(scheduler=scheduler)])
            checkpointer = Checkpointer(net, ckpt_dir, opt=sgd, trainer=trainer)
            # The path passed here does not exist; the assertions below expect
            # state to be restored anyway (presumably from the latest
            # checkpoint found in ckpt_dir — verify against Checkpointer).
            checkpointer.resume_or_load("non_exist.pth")

            # Index of the last finished iteration.
            self.assertEqual(trainer.iter, 11)
            self.assertEqual(scheduler.last_epoch, 11)
コード例 #2
0
ファイル: test_engine.py プロジェクト: yuanluxu/detectron2
    def test_checkpoint_resume(self):
        """Check that trainer, scheduler and optimizer state survive a resume.

        The first trainer runs iterations 0..11 with the LR-scheduler hook
        registered before the periodic checkpointer (period 10).  A second
        trainer is then built around a *new* optimizer with a deliberately
        wrong learning rate; after resuming, the iteration counter, the
        scheduler epoch and the learning rate must match the first run.
        """
        net = _SimpleModel()
        loader = self._data_loader("cpu")
        sgd = torch.optim.SGD(net.parameters(), 0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(sgd, 3)

        with tempfile.TemporaryDirectory(prefix="detectron2_test") as ckpt_dir:
            trainer = SimpleTrainer(net, loader, sgd)
            checkpointer = Checkpointer(net, ckpt_dir, opt=sgd, trainer=trainer)

            # The checkpointer comes after the scheduler hook so that the
            # scheduler state it saves is up to date.
            first_run_hooks = [
                hooks.LRScheduler(scheduler=scheduler),
                hooks.PeriodicCheckpointer(checkpointer, 10),
            ]
            trainer.register_hooks(first_run_hooks)

            trainer.train(0, 12)
            first_run_lr = sgd.param_groups[0]["lr"]
            self.assertAlmostEqual(first_run_lr, 1e-5)
            self.assertEqual(scheduler.last_epoch, 12)
            del trainer

            # Bogus lr on purpose: the real value is expected to be loaded.
            sgd = torch.optim.SGD(net.parameters(), 999)
            trainer = SimpleTrainer(net, loader, sgd)
            scheduler = torch.optim.lr_scheduler.StepLR(sgd, 3)
            trainer.register_hooks([hooks.LRScheduler(scheduler=scheduler)])
            checkpointer = Checkpointer(net, ckpt_dir, opt=sgd, trainer=trainer)
            checkpointer.resume_or_load("non_exist.pth")

            # Last finished iteration number (0-based in the trainer).
            self.assertEqual(trainer.iter, 11)
            # Number of times `scheduler.step()` was called (1-based).
            self.assertEqual(scheduler.last_epoch, 12)
            self.assertAlmostEqual(sgd.param_groups[0]["lr"], 1e-5)
    def build_hooks(self):
        r"""Build the list of training hooks.

        In order: iteration timing, LR scheduling, periodic checkpointing
        (main process only), ValidationLoss, and periodic event writing
        (main process only).

        Returns:
            list: the hooks, in the order they were appended.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        # NOTE(review): the original comment attributed this override to
        # PreciseBN, which this variant never registers; presumably it only
        # matters for loaders built from this cfg copy (e.g. inside
        # ValidationLoss) — confirm.
        cfg.DATALOADER.NUM_WORKERS = 0

        hook_list = [hooks.IterationTimer()]
        hook_list.append(hooks.LRScheduler(self.optimizer, self.scheduler))

        # Checkpoints are written by the main process only.
        if comm.is_main_process():
            checkpoint_hook = hooks.PeriodicCheckpointer(
                self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD
            )
            hook_list.append(checkpoint_hook)

        hook_list.append(ValidationLoss(cfg, VAL_TRANSF, cfg.VAL_LOG_PERIOD))

        # Writers go last so that metrics produced by the earlier hooks are
        # already available when they run.
        if comm.is_main_process():
            writer_hook = hooks.PeriodicWriter(
                self.build_writers(), period=cfg.VAL_LOG_PERIOD
            )
            hook_list.append(writer_hook)
        return hook_list
コード例 #4
0
def do_train(cfg):
    """Instantiate model, optimizer, data loader and trainer from ``cfg`` and train.

    Args:
        cfg: config object; this function reads ``cfg.model``,
            ``cfg.optimizer``, ``cfg.dataloader.train``, ``cfg.lr_multiplier``
            and, under ``cfg.train``: ``device``, ``ddp``, ``amp.enabled``,
            ``output_dir``, ``checkpointer``, ``eval_period``, ``log_period``,
            ``max_iter`` and ``init_checkpoint``.

    The checkpointer is always asked to resume.  When a checkpoint already
    exists in the output directory, training continues at the iteration after
    the restored one instead of restarting at 0.
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    # The optimizer config needs the live model to collect its parameters.
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(
        model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        optimizer=optim,
        trainer=trainer,
    )
    # Checkpointing and event writing are registered on the main process only;
    # the other processes get a None placeholder in those slots.
    trainer.register_hooks([
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
        hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
        if comm.is_main_process() else None,
        hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
        hooks.PeriodicWriter(
            default_writers(cfg.train.output_dir, cfg.train.max_iter),
            period=cfg.train.log_period,
        ) if comm.is_main_process() else None,
    ])

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=True)
    if checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, so
        # continue at the next one.  Starting at 0 here would discard the
        # restored progress and repeat the whole schedule.
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
コード例 #5
0
ファイル: defaults.py プロジェクト: zyg11/centerX
    def build_hooks(self):
        """
        Assemble the default hooks: timing, LR scheduling, optional SWA,
        optional precise BN, checkpointing, evaluation and event writing.

        Returns:
            list[HookBase]:
        """
        log = logging.getLogger(__name__)
        cfg = self.cfg.clone()
        cfg.defrost()
        # Fewer workers: saves some memory and time for PreciseBN.
        cfg.DATALOADER.NUM_WORKERS = 0

        hook_list = [hooks.IterationTimer()]
        hook_list.append(hooks.LRScheduler(self.optimizer, self.scheduler))

        if cfg.SOLVER.SWA.ENABLED:
            swa_hook = additional_hooks.SWA(
                cfg.SOLVER.MAX_ITER,
                cfg.SOLVER.SWA.PERIOD,
                cfg.SOLVER.SWA.LR_START,
                cfg.SOLVER.SWA.ETA_MIN_LR,
                cfg.SOLVER.SWA.LR_SCHED,
            )
            hook_list.append(swa_hook)

        if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
            log.info("Prepare precise BN dataset")
            precise_bn_hook = hooks.PreciseBN(
                # Same frequency as evaluation, but registered ahead of it.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # A separate data loader, so training's loader is unaffected.
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            hook_list.append(precise_bn_hook)

        # PreciseBN is placed before the checkpointer because it updates the
        # model, and that update needs to be saved.  Caveat: if checkpointing
        # runs at a different frequency, some checkpoints may carry more
        # precise statistics than others.
        if comm.is_main_process():
            hook_list.append(
                hooks.PeriodicCheckpointer(
                    self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD
                )
            )

        def test_and_save_results():
            # Keep the latest evaluation results on the trainer.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Evaluation follows the checkpointer: if it fails, the checkpoint
        # that was just saved can be used to debug.
        hook_list.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        # Writers run last so that evaluation metrics get written.
        if comm.is_main_process():
            hook_list.append(hooks.PeriodicWriter(self.build_writers(), 20))

        return hook_list
コード例 #6
0
def do_train(args, cfg):
    """
    Build everything from a lazy config and run the training loop.

    Args:
        args: object whose ``resume`` attribute says whether to resume from
            an existing checkpoint.
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    log = logging.getLogger("detectron2")
    log.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    # The optimizer config needs the live model to collect its parameters.
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer_cls = AMPTrainer if cfg.train.amp.enabled else SimpleTrainer
    trainer = trainer_cls(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model, cfg.train.output_dir, optimizer=optim, trainer=trainer
    )

    def run_eval():
        return do_test(cfg, model)

    # Hooks are created in registration order.  Checkpointing and event
    # writing exist on the main process only; other processes pass None.
    timer_hook = hooks.IterationTimer()
    lr_hook = hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier))
    ckpt_hook = None
    if comm.is_main_process():
        ckpt_hook = hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
    eval_hook = hooks.EvalHook(cfg.train.eval_period, run_eval)
    writer_hook = None
    if comm.is_main_process():
        writer_hook = hooks.PeriodicWriter(
            default_writers(cfg.train.output_dir, cfg.train.max_iter),
            period=cfg.train.log_period,
        )
    trainer.register_hooks([timer_hook, lr_hook, ckpt_hook, eval_hook, writer_hook])

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    # The checkpoint stores the training iteration that just finished, so a
    # resumed run starts at the next iteration; otherwise start from scratch.
    resumed = args.resume and checkpointer.has_checkpoint()
    start_iter = trainer.iter + 1 if resumed else 0
    trainer.train(start_iter, cfg.train.max_iter)
コード例 #7
0
ファイル: engine.py プロジェクト: ImYuxuanLi/CDPN-pytorch
    def build_hooks(self):
        """
        Assemble the default hooks: timing, LR scheduling, precise BN,
        checkpointing (optionally counted in epochs), evaluation and
        event writing.

        Returns:
            list[HookBase]: may contain a ``None`` placeholder where precise
            BN is disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        # Fewer workers: saves some memory and time for PreciseBN.
        cfg.DATALOADER.NUM_WORKERS = 0

        hook_list = [hooks.IterationTimer()]
        hook_list.append(hooks.LRScheduler(self.optimizer, self.scheduler))
        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
            hook_list.append(
                hooks.PreciseBN(
                    # Same frequency as evaluation, but registered ahead of it.
                    cfg.TEST.EVAL_PERIOD,
                    self.model,
                    # A separate data loader, so training's loader is unaffected.
                    self.build_train_loader(cfg),
                    cfg.TEST.PRECISE_BN.NUM_ITER,
                )
            )
        else:
            # Keep the slot occupied, exactly like the conditional expression
            # this replaces.
            hook_list.append(None)

        # PreciseBN comes before the checkpointer because it updates the
        # model, and that update needs to be saved.  Caveat: if checkpointing
        # runs at a different frequency, some checkpoints may carry more
        # precise statistics than others.
        if comm.is_main_process():
            # CHECKPOINT_PERIOD is counted in epochs when CHECKPOINT_BY_EPOCH
            # is set, in iterations otherwise.
            ckpt_period = (
                cfg.SOLVER.CHECKPOINT_PERIOD * self.iters_per_epoch
                if cfg.SOLVER.CHECKPOINT_BY_EPOCH
                else cfg.SOLVER.CHECKPOINT_PERIOD
            )
            checkpoint_hook = MyPeriodicCheckpointer(
                self.checkpointer,
                ckpt_period,
                max_to_keep=cfg.SOLVER.get("NUM_CKPT_KEEP", 5),
                iters_per_epoch=self.iters_per_epoch,
            )
            hook_list.append(checkpoint_hook)

        def test_and_save_results():
            # Keep the latest evaluation results on the trainer.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Evaluation follows the checkpointer: if it fails, the checkpoint
        # that was just saved can be used to debug.
        hook_list.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        # Writers run last so that evaluation metrics get written.
        if comm.is_main_process():
            print_freq = cfg.TRAIN.get("PRINT_FREQ", 100)
            hook_list.append(
                hooks.PeriodicWriter(self.build_writers(), period=print_freq)
            )
        return hook_list
コード例 #8
0
    def build_hooks(self):
        """
        Assemble the hooks for student/teacher training: timing, LR
        scheduling, precise BN, checkpointing, one evaluation hook per model
        (student first, then teacher) and event writing.

        Returns:
            list: hooks in registration order; holds a ``None`` placeholder
            where precise BN is disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        # Fewer workers: saves some memory and time for PreciseBN.
        cfg.DATALOADER.NUM_WORKERS = 0

        hook_list = [hooks.IterationTimer()]
        hook_list.append(hooks.LRScheduler(self.optimizer, self.scheduler))
        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
            hook_list.append(
                hooks.PreciseBN(
                    # Same frequency as evaluation, but registered ahead of it.
                    cfg.TEST.EVAL_PERIOD,
                    self.model,
                    # A separate data loader, so training's loader is unaffected.
                    self.build_train_loader(cfg),
                    cfg.TEST.PRECISE_BN.NUM_ITER,
                )
            )
        else:
            # Keep the slot occupied, exactly like the conditional expression
            # this replaces.
            hook_list.append(None)

        # PreciseBN comes before the checkpointer because it updates the
        # model, and that update needs to be saved.  Caveat: if checkpointing
        # runs at a different frequency, some checkpoints may carry more
        # precise statistics than others.
        if comm.is_main_process():
            checkpoint_hook = hooks.PeriodicCheckpointer(
                self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD
            )
            hook_list.append(checkpoint_hook)

        def test_and_save_results_student():
            # Evaluate self.model; the returned copy has every key suffixed
            # with "_student", while the stored results keep the raw keys.
            raw = self.test(self.cfg, self.model)
            self._last_eval_results_student = raw
            renamed = {}
            for key in self._last_eval_results_student.keys():
                renamed[key + "_student"] = self._last_eval_results_student[key]
            return renamed

        def test_and_save_results_teacher():
            # Evaluate self.model_teacher; keys are returned unchanged.
            self._last_eval_results_teacher = self.test(
                self.cfg, self.model_teacher
            )
            return self._last_eval_results_teacher

        eval_period = cfg.TEST.EVAL_PERIOD
        hook_list.append(hooks.EvalHook(eval_period, test_and_save_results_student))
        hook_list.append(hooks.EvalHook(eval_period, test_and_save_results_teacher))

        # Writers run last so that evaluation metrics get written.
        if comm.is_main_process():
            hook_list.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return hook_list
コード例 #9
0
    def build_hooks(self):
        """
        Assemble the default hooks: timing, LR scheduling, precise BN,
        checkpointing, evaluation and event writing.

        Returns:
            list[HookBase]: may contain a ``None`` placeholder where precise
            BN is disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0

        hook_list = [hooks.IterationTimer()]
        hook_list.append(hooks.LRScheduler(self.optimizer, self.scheduler))
        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
            precise_bn_hook = hooks.PreciseBN(
                cfg.TEST.EVAL_PERIOD,
                self.model,
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            hook_list.append(precise_bn_hook)
        else:
            # Keep the slot occupied, exactly like the conditional expression
            # this replaces.
            hook_list.append(None)

        # Checkpoints are written by the main process only.
        if comm.is_main_process():
            checkpoint_hook = hooks.PeriodicCheckpointer(
                self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD
            )
            hook_list.append(checkpoint_hook)

        def test_and_save_results():
            # Keep the latest evaluation results on the trainer.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        hook_list.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        # Event writers are attached on the main process only.
        if comm.is_main_process():
            writer_hook = hooks.PeriodicWriter(self.build_writers(), period=20)
            hook_list.append(writer_hook)
        return hook_list
コード例 #10
0
    def build_hooks(self):
        """
        Assemble the hooks for training without periodic evaluation: timing,
        LR scheduling, precise BN, checkpointing and event writing.

        Returns:
            list: hooks in registration order; holds a ``None`` placeholder
            where precise BN is disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        # Fewer workers: saves some memory and time for PreciseBN.
        cfg.DATALOADER.NUM_WORKERS = 0

        hook_list = [hooks.IterationTimer()]
        hook_list.append(hooks.LRScheduler(self.optimizer, self.scheduler))
        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
            precise_bn_hook = hooks.PreciseBN(
                # Period taken from the evaluation setting.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # A separate data loader, so training's loader is unaffected.
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            hook_list.append(precise_bn_hook)
        else:
            # Keep the slot occupied, exactly like the conditional expression
            # this replaces.
            hook_list.append(None)

        # Checkpoints are written by the main process only.
        if comm.is_main_process():
            checkpoint_hook = hooks.PeriodicCheckpointer(
                self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD
            )
            hook_list.append(checkpoint_hook)

        # No EvalHook is registered in this variant: the evaluation callback
        # and its hook were commented out in the original code.

        # Writers run last, on the main process only.
        if comm.is_main_process():
            writer_hook = hooks.PeriodicWriter(self.build_writers(), period=20)
            hook_list.append(writer_hook)
        return hook_list
コード例 #11
0
    def build_hooks(self):
        """
        Assemble the training hooks: timing, LR scheduling, precise BN,
        checkpointing, evaluation (results also dumped to CSV) and event
        writing.

        Returns:
            list: hooks in registration order; holds a ``None`` placeholder
            where precise BN is disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        # Fewer workers: saves some memory and time for PreciseBN.
        cfg.DATALOADER.NUM_WORKERS = 0

        hook_list = [hooks.IterationTimer()]
        hook_list.append(hooks.LRScheduler(self.optimizer, self.scheduler))
        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
            precise_bn_hook = hooks.PreciseBN(
                # Same frequency as evaluation, but registered ahead of it.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # A separate data loader, so training's loader is unaffected.
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            hook_list.append(precise_bn_hook)
        else:
            # Keep the slot occupied, exactly like the conditional expression
            # this replaces.
            hook_list.append(None)

        # Checkpoints are written by the main process only.
        if comm.is_main_process():
            checkpoint_hook = hooks.PeriodicCheckpointer(
                self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD
            )
            hook_list.append(checkpoint_hook)

        def test_and_save_results():
            # Evaluate, remember the results on the trainer, and write them to
            # <OUTPUT_DIR>/evals/<round>.csv.
            # NOTE(review): this runs on every process, not just the main one;
            # in a multi-process run they would all write the same file —
            # confirm that is intended.
            results = self.test(self.cfg, self.model)
            self._last_eval_results = results
            out_dir = os.path.join(self.cfg.OUTPUT_DIR, 'evals')
            os.makedirs(out_dir, exist_ok=True)
            frame = pd.DataFrame(results)
            frame.to_csv(os.path.join(out_dir, f'{self.round}.csv'))
            return self._last_eval_results

        hook_list.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        # Writers run last so that evaluation metrics get written.
        if comm.is_main_process():
            writer_hook = hooks.PeriodicWriter(self.build_writers(), period=20)
            hook_list.append(writer_hook)
        return hook_list
コード例 #12
0
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        The event-writer period is derived from ``self.max_iter`` so that a
        run produces a bounded number of samples, and it is never smaller
        than 1.

        Returns:
            list[HookBase]: may contain a ``None`` placeholder where precise
            BN is disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg, self.mapper_object, self.isShuffleData),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(self.checkpointer,
                                           cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            # Run evaluation with this trainer's getter/dataset/mapper and
            # remember the results on the trainer.
            self._last_eval_results = self.test(self.cfg,
                                                self.model,
                                                self.getter,
                                                self.dataset_used,
                                                self.mapper_object,
                                                evaluators=None)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            target_samples = 25
            if self.max_iter <= target_samples:
                # Short run: sample twice, e.g. max_iter = 20 -> every 10 iterations.
                step = int(round(float(self.max_iter) / float(2), 2))
            else:
                # Long run: about `target_samples` samples,
                # e.g. max_iter = 10000 -> every 400 iterations.
                step = int(round(float(self.max_iter) / float(target_samples), 0))
            # The period must be at least 1.  The short-run branch yields 0
            # for max_iter of 0 or 1, which would be handed to PeriodicWriter
            # as its period.
            step = max(step, 1)

            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=step))
        return ret