示例#1
0
    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None):
        """Register the default set of training hooks.

        Hooks are registered in this order:

        - LrUpdaterHook (through ``register_lr_hooks``)
        - OptimizerStepperHook (an ``OptimizerHook`` built from
          ``optimizer_config``)
        - CheckpointSaverHook (a ``CheckpointHook`` built from
          ``checkpoint_config``)
        - IterTimerHook
        - LoggerHook(s), only when ``log_config`` is given

        Args:
            lr_config: Forwarded unchanged to ``register_lr_hooks``.
            optimizer_config: Config for ``build_hook`` with default class
                ``OptimizerHook``; ``None`` is treated as ``{}``.
            checkpoint_config: Config for ``build_hook`` with default class
                ``CheckpointHook``; ``None`` is treated as ``{}``.
            log_config: Forwarded to ``register_logger_hooks``; when it is
                ``None`` no logger hooks are registered.
        """
        optim_cfg = {} if optimizer_config is None else optimizer_config
        ckpt_cfg = {} if checkpoint_config is None else checkpoint_config

        self.register_lr_hooks(lr_config)
        # Build and register the optimizer hook first, then the checkpoint
        # hook, each with its own default hook class.
        for cfg, default_cls in ((optim_cfg, OptimizerHook),
                                 (ckpt_cfg, CheckpointHook)):
            self.register_hook(self.build_hook(cfg, default_cls))
        self.register_hook(IterTimerHook())

        if log_config is None:
            return
        self.register_logger_hooks(log_config)
示例#2
0
 def register_training_hooks(self,
                             lr_config_b,
                             lr_config_g=None,
                             lr_config_d=None,
                             optimizer_b_config=None,
                             optimizer_g_config=None,
                             optimizer_d_config=None,
                             checkpoint_config=None,
                             log_config=None,
                             momentum_config=None,
                             e2e_training=False):
     """Register default training hooks for separate B/G/D optimizers.

     Hooks are registered in this order:

     - LR updater hooks for ``type='B'``, ``'G'`` and ``'D'``
     - Momentum hook (``momentum_config``)
     - Optimizer hook ``OptimHookB`` with priority ``"HIGH"``, only when
       ``e2e_training`` is true
     - Optimizer hooks ``OptimHookG`` (``"NORMAL"``) and ``OptimHookD``
       (``"LOW"``)
     - Checkpoint hook
     - IterTimerHook
     - Logger hook(s), only when ``log_config`` is not ``None``

     Args:
         lr_config_b: LR config passed to ``register_lr_hook`` with
             ``type='B'``.
         lr_config_g: LR config passed with ``type='G'``.
         lr_config_d: LR config passed with ``type='D'``.
         optimizer_b_config: Config for the ``OptimHookB`` optimizer hook;
             used only when ``e2e_training`` is true.
         optimizer_g_config: Config for the ``OptimHookG`` optimizer hook.
         optimizer_d_config: Config for the ``OptimHookD`` optimizer hook.
         checkpoint_config: Passed to ``register_checkpoint_hook``.
         log_config: Passed to ``register_logger_hooks``; skipped when
             ``None``.
         momentum_config: Passed to ``register_momentum_hook``.
         e2e_training: Whether to also register the ``OptimHookB`` hook.
     """
     self.register_lr_hook(lr_config_b, type='B')
     self.register_lr_hook(lr_config_g, type='G')
     self.register_lr_hook(lr_config_d, type='D')
     self.register_momentum_hook(momentum_config)
     if e2e_training:
         self.register_optimizer_hook(optimizer_b_config,
                                      priority="HIGH",
                                      optim_type="OptimHookB")
     self.register_optimizer_hook(optimizer_g_config,
                                  priority="NORMAL",
                                  optim_type="OptimHookG")
     self.register_optimizer_hook(optimizer_d_config,
                                  priority="LOW",
                                  optim_type="OptimHookD")
     self.register_checkpoint_hook(checkpoint_config)
     self.register_hook(IterTimerHook())
     # NOTE: guarded so the default ``log_config=None`` never reaches
     # register_logger_hooks, which may not accept None (some versions
     # index log_config['interval'] directly).
     if log_config is not None:
         self.register_logger_hooks(log_config)
示例#3
0
    def register_training_hooks(self,
                                lr_config,
                                weight_optim_config=None,
                                arch_optim_config=None,
                                checkpoint_config=None,
                                log_config=None):
        """Register default hooks for the weight and architecture phases.

        Weight-phase hooks (``register_hook``), in order:

        - LrUpdaterHook (through ``register_lr_hooks``)
        - ``OptimizerHook`` built from ``weight_optim_config``
        - ``CheckpointHook`` built from ``checkpoint_config``
        - ``DropProcessHook`` with priority ``'LOW'``
        - ``IterTimerHook``

        Architecture-phase hooks (``register_arch_hook``), in order:

        - ``ArchOptimizerHook`` built from ``arch_optim_config``
        - ``ModelInfoHook(self.cfg.model_info_interval)`` with priority
          ``'VERY_LOW'``
        - ``DropProcessHook`` with priority ``'LOW'``
        - ``IterTimerHook``

        Logger hook(s) are registered last, only when ``log_config`` is given.

        Args:
            lr_config: Forwarded unchanged to ``register_lr_hooks``.
            weight_optim_config: Config for the weight ``OptimizerHook``;
                ``None`` is treated as ``{}``.
            arch_optim_config: Config for the ``ArchOptimizerHook``; ``None``
                is treated as ``{}``.
            checkpoint_config: Config for the ``CheckpointHook``; ``None`` is
                treated as ``{}``.
            log_config: Forwarded to ``register_logger_hooks``; skipped when
                ``None``.
        """
        weight_cfg = {} if weight_optim_config is None else weight_optim_config
        arch_cfg = {} if arch_optim_config is None else arch_optim_config
        ckpt_cfg = {} if checkpoint_config is None else checkpoint_config

        # Weight-update phase.
        self.register_lr_hooks(lr_config)
        self.register_hook(self.build_hook(weight_cfg, OptimizerHook))
        self.register_hook(self.build_hook(ckpt_cfg, CheckpointHook))
        self.register_hook(DropProcessHook(), priority='LOW')
        self.register_hook(IterTimerHook())

        # Architecture-update phase.
        arch_optimizer_hook = self.build_hook(arch_cfg, ArchOptimizerHook)
        self.register_arch_hook(arch_optimizer_hook)
        model_info_hook = ModelInfoHook(self.cfg.model_info_interval)
        self.register_arch_hook(model_info_hook, priority='VERY_LOW')
        self.register_arch_hook(DropProcessHook(), priority='LOW')
        self.register_arch_hook(IterTimerHook())

        if log_config is not None:
            # register_logger_hooks is expected to add the arch-phase
            # logger hooks itself.
            self.register_logger_hooks(log_config)
示例#4
0
    def register_training_hooks(self, lr_config, optimizer_config=None):
        """Register the minimal set of default training hooks.

        Hooks are registered in this order: LrUpdaterHook (through
        ``register_lr_hooks``), an ``OptimizerHook`` built from
        ``optimizer_config`` (``None`` is treated as ``{}``), and an
        ``IterTimerHook``. No checkpoint or logger hooks are registered.

        Args:
            lr_config: Forwarded unchanged to ``register_lr_hooks``.
            optimizer_config: Config for ``build_hook`` with default class
                ``OptimizerHook``.
        """
        self.register_lr_hooks(lr_config)
        optim_cfg = optimizer_config if optimizer_config is not None else {}
        optimizer_hook = self.build_hook(optim_cfg, OptimizerHook)
        self.register_hook(optimizer_hook)
        self.register_hook(IterTimerHook())
示例#5
0
def test_cosine_cooldown_hook(multi_optimziers):
    """Check LRs produced by ``CosineAnnealingCooldownLrUpdaterHook``.

    A ``ValueCheckHook`` verifies ``current_lr()`` at steps 0, 5 and 9 for
    either a single optimizer or two optimizers (``model1`` / ``model2``).
    The runner's work dir is removed afterwards, even if a check fails.
    """
    # NOTE: the parameter name keeps its original spelling because it is
    # also the keyword accepted by _build_demo_runner.
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

    # add the cosine-annealing-with-cooldown LR scheduler
    hook_cfg = dict(type='CosineAnnealingCooldownLrUpdaterHook',
                    by_epoch=False,
                    cool_down_time=2,
                    cool_down_ratio=0.1,
                    min_lr_ratio=0.1,
                    warmup_iters=2,
                    warmup_ratio=0.9)
    runner.register_hook_from_cfg(hook_cfg)
    # A single timer hook is enough; it was previously registered twice.
    runner.register_hook(IterTimerHook())

    if multi_optimziers:
        check_hook = ValueCheckHook({
            0: {
                'current_lr()["model1"][0]': 0.02,
                'current_lr()["model2"][0]': 0.01,
            },
            5: {
                'current_lr()["model1"][0]': 0.0075558491,
                'current_lr()["model2"][0]': 0.0037779246,
            },
            9: {
                'current_lr()["model1"][0]': 0.0002,
                'current_lr()["model2"][0]': 0.0001,
            }
        })
    else:
        check_hook = ValueCheckHook({
            0: {
                'current_lr()[0]': 0.02,
            },
            5: {
                'current_lr()[0]': 0.0075558491,
            },
            9: {
                'current_lr()[0]': 0.0002,
            }
        })
    # Lowest priority so the check runs after the LR updater has acted.
    runner.register_hook(check_hook, priority='LOWEST')

    try:
        runner.run([loader], [('train', 1)])
    finally:
        # Clean up the work dir even when a value check raises.
        shutil.rmtree(runner.work_dir)
示例#6
0
def test_register_timer_hook(runner_class):
    """``register_timer_hook`` accepts ``None``, a config dict or a hook.

    ``None`` registers nothing. A config dict and a ready-made
    ``IterTimerHook`` instance each append one ``IterTimerHook`` to
    ``runner.hooks``.
    """
    runner = runner_class(model=Model(), logger=logging.getLogger())

    # A None config is a no-op.
    runner.register_timer_hook(None)
    assert len(runner.hooks) == 0

    # Hook built from a config dict.
    runner.register_timer_hook(dict(type='IterTimerHook'))
    assert len(runner.hooks) == 1
    assert isinstance(runner.hooks[-1], IterTimerHook)

    # Hook passed in as an already constructed instance.
    runner.register_timer_hook(IterTimerHook())
    assert len(runner.hooks) == 2
    assert isinstance(runner.hooks[-1], IterTimerHook)