def register_training_hooks(self,
                            lr_config,
                            optimizer_config=None,
                            checkpoint_config=None,
                            log_config=None):
    """Register default hooks for training.

    Default hooks include:

    - LrUpdaterHook
    - OptimizerStepperHook
    - CheckpointSaverHook
    - IterTimerHook
    - LoggerHook(s)
    """
    if optimizer_config is None:
        optimizer_config = {}
    if checkpoint_config is None:
        checkpoint_config = {}
    self.register_lr_hooks(lr_config)
    self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
    self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
    self.register_hook(IterTimerHook())
    if log_config is not None:
        self.register_logger_hooks(log_config)
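# Hypothetical usage sketch for the variant above (not taken from the source):
# the config keys shown are common mmcv-style conventions and may differ from
# what the concrete hook classes expect.
def _example_register_training_hooks(runner):
    runner.register_training_hooks(
        lr_config=dict(policy='step', step=[8, 11]),
        optimizer_config=dict(grad_clip=dict(max_norm=35, norm_type=2)),
        checkpoint_config=dict(interval=1),
        log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]))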
def register_training_hooks(self,
                            lr_config_b,
                            lr_config_g=None,
                            lr_config_d=None,
                            optimizer_b_config=None,
                            optimizer_g_config=None,
                            optimizer_d_config=None,
                            checkpoint_config=None,
                            log_config=None,
                            momentum_config=None,
                            e2e_training=False):
    """Register default training hooks for separate 'B', 'G' and 'D' optimizers.

    Optimizer stepping hooks are only registered when ``e2e_training`` is
    True; LR, momentum, checkpoint, timer and logger hooks are always
    registered.
    """
    self.register_lr_hook(lr_config_b, type='B')
    self.register_lr_hook(lr_config_g, type='G')
    self.register_lr_hook(lr_config_d, type='D')
    self.register_momentum_hook(momentum_config)
    if e2e_training:
        self.register_optimizer_hook(
            optimizer_b_config, priority="HIGH", optim_type="OptimHookB")
        self.register_optimizer_hook(
            optimizer_g_config, priority="NORMAL", optim_type="OptimHookG")
        self.register_optimizer_hook(
            optimizer_d_config, priority="LOW", optim_type="OptimHookD")
        # self.register_optimizer_hook(
        #     optimizer_b_config, priority="NORMAL", optim_type="MultiOptimHook")
    self.register_checkpoint_hook(checkpoint_config)
    self.register_hook(IterTimerHook())
    self.register_logger_hooks(log_config)
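# Hypothetical usage sketch for the multi-optimizer variant above; the three
# sub-optimizer configs ('B', 'G', 'D') and all values are illustrative only.
def _example_register_multi_optimizer_hooks(runner):
    runner.register_training_hooks(
        lr_config_b=dict(policy='fixed'),
        lr_config_g=dict(policy='step', step=[30]),
        lr_config_d=dict(policy='step', step=[30]),
        optimizer_b_config=dict(),
        optimizer_g_config=dict(),
        optimizer_d_config=dict(),
        checkpoint_config=dict(interval=1),
        log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]),
        e2e_training=True)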
def register_training_hooks(self,
                            lr_config,
                            weight_optim_config=None,
                            arch_optim_config=None,
                            checkpoint_config=None,
                            log_config=None):
    """Register default hooks for training.

    Default hooks include:

    - LrUpdaterHook
    - Weight/Arch_OptimizerStepperHook
    - CheckpointSaverHook
    - IterTimerHook
    - LoggerHook(s)
    """
    if weight_optim_config is None:
        weight_optim_config = {}
    if arch_optim_config is None:
        arch_optim_config = {}
    if checkpoint_config is None:
        checkpoint_config = {}
    self.register_lr_hooks(lr_config)
    self.register_hook(self.build_hook(weight_optim_config, OptimizerHook))
    self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
    # self.register_hook(
    #     ModelInfoHook(self.cfg.model_info_interval), priority='VERY_LOW')
    self.register_hook(DropProcessHook(), priority='LOW')
    self.register_hook(IterTimerHook())
    self.register_arch_hook(
        self.build_hook(arch_optim_config, ArchOptimizerHook))
    self.register_arch_hook(
        ModelInfoHook(self.cfg.model_info_interval), priority='VERY_LOW')
    self.register_arch_hook(DropProcessHook(), priority='LOW')
    self.register_arch_hook(IterTimerHook())
    if log_config is not None:
        # logger hooks for the arch hooks are added inside
        self.register_logger_hooks(log_config)
def register_training_hooks(self, lr_config, optimizer_config=None):
    """Register default hooks for training.

    Default hooks include:

    - LrUpdaterHook
    - OptimizerStepperHook
    - IterTimerHook
    """
    if optimizer_config is None:
        optimizer_config = {}
    self.register_lr_hooks(lr_config)
    self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
    self.register_hook(IterTimerHook())
@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_cosine_cooldown_hook(multi_optimziers):
    """xdoctest -m tests/test_hooks.py test_cosine_cooldown_hook."""
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

    # add cosine-annealing LR scheduler with cooldown
    hook_cfg = dict(
        type='CosineAnnealingCooldownLrUpdaterHook',
        by_epoch=False,
        cool_down_time=2,
        cool_down_ratio=0.1,
        min_lr_ratio=0.1,
        warmup_iters=2,
        warmup_ratio=0.9)
    runner.register_hook_from_cfg(hook_cfg)
    runner.register_hook_from_cfg(dict(type='IterTimerHook'))
    runner.register_hook(IterTimerHook())

    if multi_optimziers:
        check_hook = ValueCheckHook({
            0: {
                'current_lr()["model1"][0]': 0.02,
                'current_lr()["model2"][0]': 0.01,
            },
            5: {
                'current_lr()["model1"][0]': 0.0075558491,
                'current_lr()["model2"][0]': 0.0037779246,
            },
            9: {
                'current_lr()["model1"][0]': 0.0002,
                'current_lr()["model2"][0]': 0.0001,
            }
        })
    else:
        check_hook = ValueCheckHook({
            0: {
                'current_lr()[0]': 0.02,
            },
            5: {
                'current_lr()[0]': 0.0075558491,
            },
            9: {
                'current_lr()[0]': 0.0002,
            }
        })
    runner.register_hook(check_hook, priority='LOWEST')

    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)
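# The test above relies on a ValueCheckHook helper that is not shown in this
# excerpt. A minimal sketch of such a hook, assuming the mmcv Hook API and that
# the dict keys are expressions evaluated against the runner at the given
# iterations (both assumptions, not taken from the source):
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class ValueCheckHook(Hook):

    def __init__(self, check_dict, rtol=1e-5):
        # Maps an iteration index to {expression: expected value}, e.g.
        # {5: {'current_lr()[0]': 0.0075558491}}.
        self.check_dict = check_dict
        self.rtol = rtol

    def after_train_iter(self, runner):
        if runner.iter not in self.check_dict:
            return
        for expr, expected in self.check_dict[runner.iter].items():
            value = eval(f'runner.{expr}')  # test-only helper
            assert abs(value - expected) <= self.rtol * abs(expected), \
                f'{expr} at iter {runner.iter}: {value} != {expected}'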
def test_register_timer_hook(runner_class):
    model = Model()
    runner = runner_class(model=model, logger=logging.getLogger())

    # test register None
    timer_config = None
    runner.register_timer_hook(timer_config)
    assert len(runner.hooks) == 0

    # test register IterTimerHook from a config dict
    timer_config = dict(type='IterTimerHook')
    runner.register_timer_hook(timer_config)
    assert len(runner.hooks) == 1
    assert isinstance(runner.hooks[0], IterTimerHook)

    # test register an IterTimerHook instance directly
    timer_config = IterTimerHook()
    runner.register_timer_hook(timer_config)
    assert len(runner.hooks) == 2
    assert isinstance(runner.hooks[1], IterTimerHook)
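# For reference, a minimal sketch of the register_timer_hook method the test
# above exercises, assuming an mmcv-style runner and the HOOKS registry; this
# mirrors common mmcv behaviour but is not copied from the source.
import copy

import mmcv
from mmcv.runner.hooks import HOOKS


def register_timer_hook(self, timer_config):
    # A None config registers nothing, a dict is built through the registry,
    # and a Hook instance is registered as-is.
    if timer_config is None:
        return
    if isinstance(timer_config, dict):
        timer_config_ = copy.deepcopy(timer_config)
        hook = mmcv.build_from_cfg(timer_config_, HOOKS)
    else:
        hook = timer_config
    self.register_hook(hook, priority='LOW')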