Esempio n. 1
0
def main(args):
    """Build a runner from the global ``config`` and evaluate or train with it.

    ``args.opts`` is merged into ``config``, ``default_setup`` turns the result
    into ``cfg`` plus a logger, and the runner class looked up in ``RUNNERS``
    by ``cfg.TRAINER.NAME`` (wrapped by ``runner_decrator``) is instantiated
    with ``cfg`` and ``build_model``.

    Args:
        args: parsed command-line arguments; this function reads ``opts``,
            ``resume`` and ``eval_only``.

    Returns:
        None. When ``args.eval_only`` is set, ``runner.test`` is called and
        training is skipped; otherwise ``runner.train`` is called.
    """
    config.merge_from_list(args.opts)
    cfg, logger = default_setup(config, args)

    # If you'd like to do anything fancier than the standard training logic,
    # consider writing your own training loop or subclassing the runner.
    runner = runner_decrator(RUNNERS.get(cfg.TRAINER.NAME))(cfg, build_model)
    runner.resume_or_load(resume=args.resume)

    if not args.eval_only:
        # Check whether the workspace has enough storage space for all
        # periodic checkpoints, assuming a single dumped model is 700 MB.
        # The estimate is sized from the training schedule, so it is skipped
        # for evaluation-only runs, which never reach runner.train().
        file_sys = os.statvfs(cfg.OUTPUT_DIR)
        # f_bavail counts only the blocks available to unprivileged
        # processes; f_bfree would also count blocks reserved for the
        # superuser and therefore overstate the usable space.
        free_space_Gb = (file_sys.f_bavail * file_sys.f_frsize) / 2**30
        eval_space_Gb = (cfg.SOLVER.LR_SCHEDULER.MAX_ITER //
                         cfg.SOLVER.CHECKPOINT_PERIOD) * 700 / 2**10
        if eval_space_Gb > free_space_Gb:
            logger.warning(f"{Fore.RED}Remaining space({free_space_Gb}GB) "
                           f"is less than ({eval_space_Gb}GB){Style.RESET_ALL}")

    if cfg.TEST.AUG.ENABLED:
        runner.register_hooks([
            hooks.EvalHook(0, lambda: runner.test_with_TTA(cfg, runner.model))
        ])

    logger.info("Running with full config:\n{}".format(cfg))
    # Log how cfg differs from a default-constructed instance of its base class.
    base_config = cfg.__class__.__base__()
    logger.info("different config with base class:\n{}".format(
        cfg.diff(base_config)))

    if args.eval_only:
        runner.test(cfg, runner.model)
        return
    runner.train()
Esempio n. 2
0
def main(args, config, build_model):
    """Train a model with the configured runner and optionally export it.

    ``args.opts`` is merged into ``config`` and the result is passed through
    ``default_setup`` to obtain ``cfg``. The runner class looked up in
    ``RUNNERS`` by ``cfg.TRAINER.NAME`` (wrapped by ``runner_decrator``) is
    then built, extra hooks are registered, and training is started.

    Args:
        args: parsed command-line arguments; this function reads ``opts``,
            ``resume`` and ``clearml``.
        config: configuration object that is updated in place with
            ``args.opts``.
        build_model: passed to the runner constructor together with ``cfg``.

    Returns:
        None.
    """
    config.merge_from_list(args.opts)
    cfg = default_setup(config, args)

    # If you'd like to do anything fancier than the standard training logic,
    # consider writing your own training loop or subclassing the runner.
    runner_cls = runner_decrator(RUNNERS.get(cfg.TRAINER.NAME))
    runner = runner_cls(cfg, build_model)
    runner.resume_or_load(resume=args.resume)

    pending_hooks = []
    if args.clearml:
        # The ClearML integration is only imported when it was requested.
        from cvpods.engine.clearml import ClearMLHook
        if comm.is_main_process():
            pending_hooks.append(ClearMLHook())
    if cfg.TEST.AUG.ENABLED:

        def _evaluate_with_tta():
            return runner.test_with_TTA(cfg, runner.model)

        pending_hooks.append(hooks.EvalHook(0, _evaluate_with_tta))
    if len(pending_hooks) > 0:
        runner.register_hooks(pending_hooks)

    logger.info(f"Running with full config:\n{cfg}")
    # Log how cfg differs from a default-constructed instance of its base class.
    base_config = cfg.__class__.__base__()
    config_diff = cfg.diff(base_config)
    logger.info(f"different config with base class:\n{config_diff}")

    runner.train()

    if not comm.is_main_process() or not cfg.MODEL.AS_PRETRAIN:
        return

    # Convert the last checkpoint to the pretrain format.
    final_ckpt = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    pretrain_ckpt = os.path.join(cfg.OUTPUT_DIR,
                                 "model_final_pretrain_weight.pkl")
    convert_to_pretrained_model(input=final_ckpt, save_path=pretrain_ckpt)
Esempio n. 3
0
def stage_main(args, cfg, build):
    """Run one stage: evaluate only, or train and optionally export weights.

    ``args.opts`` is merged into ``cfg``, ``default_setup`` returns the final
    ``cfg`` and a logger, and a ``Trainer`` is built from ``cfg`` and
    ``build``.

    Args:
        args: parsed command-line arguments; this function reads ``opts``,
            ``resume`` and ``eval_only``.
        cfg: configuration object that is updated in place with ``args.opts``.
        build: model-building callable handed to ``Trainer``.

    Returns:
        In eval-only mode, the object returned by ``Trainer.test`` (updated
        with the ``Trainer.test_with_TTA`` results when
        ``cfg.TEST.AUG.ENABLED``); otherwise None.
    """
    cfg.merge_from_list(args.opts)
    cfg, logger = default_setup(cfg, args)

    # If you'd like to do anything fancier than the standard training logic,
    # consider writing your own training loop or subclassing the trainer.
    trainer = Trainer(cfg, build)
    trainer.resume_or_load(resume=args.resume)

    if args.eval_only:
        DefaultCheckpointer(trainer.model,
                            save_dir=cfg.OUTPUT_DIR,
                            resume=args.resume).resume_or_load(
                                cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, trainer.model)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, trainer.model))
        return res

    # Check whether the workspace has enough storage space for all periodic
    # checkpoints, assuming a single dumped model is 700 MB.
    file_sys = os.statvfs(cfg.OUTPUT_DIR)
    # f_bavail counts only the blocks available to unprivileged processes;
    # f_bfree would also count blocks reserved for the superuser and therefore
    # overstate the usable space.
    free_space_Gb = (file_sys.f_bavail * file_sys.f_frsize) / 2**30
    eval_space_Gb = (cfg.SOLVER.LR_SCHEDULER.MAX_ITER //
                     cfg.SOLVER.CHECKPOINT_PERIOD) * 700 / 2**10
    if eval_space_Gb > free_space_Gb:
        logger.warning(f"{Fore.RED}Remaining space({free_space_Gb}GB) "
                       f"is less than ({eval_space_Gb}GB){Style.RESET_ALL}")

    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])

    trainer.train()

    if comm.is_main_process() and cfg.MODEL.AS_PRETRAIN:
        # Convert the last checkpoint to the pretrain format.
        convert_to_pretrained_model(input=os.path.join(cfg.OUTPUT_DIR,
                                                       "model_final.pth"),
                                    save_path=os.path.join(
                                        cfg.OUTPUT_DIR,
                                        "model_final_pretrain_weight.pkl"))
Esempio n. 4
0
    def build_hooks(self):
        """
        Assemble the list of hooks for this trainer.

        The list holds, in this order: the SwAV optimization hook, the LR
        scheduler hook, the iteration timer, a PreciseBN hook (or ``None``
        when PreciseBN is disabled or ``get_bn_modules(self.model)`` is
        falsy), the periodic checkpointer (main process only), the evaluation
        hook, and finally the periodic writer plus the optional log dumper
        (main process only).

        Returns:
            list[HookBase]: the hooks; may contain one ``None`` placeholder.
        """
        cfg = self.cfg

        hook_list = [
            SwavOptimizationHook(
                accumulate_grad_steps=cfg.SOLVER.BATCH_SUBDIVISIONS,
                grad_clipper=None,
                mixed_precision=cfg.TRAINER.FP16.ENABLED,
                cancel_epochs=cfg.MODEL.SWAV.CANCEL_EPOCHS,
            ),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.IterationTimer(),
        ]

        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
            precise_bn = hooks.PreciseBN(
                # Same frequency as evaluation, but listed ahead of it.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # A separate data loader, so the training one is not affected.
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
        else:
            precise_bn = None
        hook_list.append(precise_bn)

        # PreciseBN comes before the checkpointer because it updates the model
        # and those updates need to be saved. This is not always ideal: with a
        # different checkpointing frequency, some checkpoints may carry more
        # precise statistics than others.
        if comm.is_main_process():
            checkpoint_hook = hooks.PeriodicCheckpointer(
                self.checkpointer,
                cfg.SOLVER.CHECKPOINT_PERIOD,
                max_iter=self.max_iter,
                max_epoch=self.max_epoch,
            )
            hook_list.append(checkpoint_hook)

        def evaluate_and_remember():
            results = self.test(self.cfg, self.model)
            self._last_eval_results = results
            return results

        # Evaluation comes after the checkpointer so that, if it fails, the
        # saved checkpoint can be used for debugging.
        hook_list.append(
            hooks.EvalHook(cfg.TEST.EVAL_PERIOD, evaluate_and_remember))

        if comm.is_main_process():
            # Each writer keeps its own default print/log frequency. Writers go
            # last so that evaluation metrics are written as well.
            writer_hook = hooks.PeriodicWriter(self.build_writers(),
                                               period=self.window_size)
            hook_list.append(writer_hook)
            # `PeriodicDumpLog` follows the writers so it can dump every file,
            # including those the writers generate.
            if cfg.OSS.DUMP_LOG_ENABLED:
                # A dump period of 0 falls back to the checkpoint period.
                dump_period = (cfg.SOLVER.CHECKPOINT_PERIOD
                               if cfg.OSS.DUMP_PERIOD == 0 else
                               cfg.OSS.DUMP_PERIOD)
                hook_list.append(
                    hooks.PeriodicDumpLog(cfg.OUTPUT_DIR, cfg.OSS.DUMP_PREFIX,
                                          dump_period))
        return hook_list