import itertools

import torch
import torch.nn as nn

from detectron2.config import CfgNode
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping


def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
    # Per-parameter LR overrides: the CSE embedder's "features" and "embeddings"
    # parameters train at configurable multiples of the base LR.
    params = get_default_optimizer_params(
        model,
        base_lr=cfg.SOLVER.BASE_LR,
        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
        overrides={
            "features": {
                "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
            },
            "embeddings": {
                "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
            },
        },
    )
    optimizer = torch.optim.SGD(
        params,
        cfg.SOLVER.BASE_LR,
        momentum=cfg.SOLVER.MOMENTUM,
        nesterov=cfg.SOLVER.NESTEROV,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )
    # Wrap with detectron2's gradient clipping if enabled in the config.
    return maybe_add_gradient_clipping(cfg, optimizer)
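# Illustration (a minimal sketch, not part of the trainer above): how the
# ``overrides`` argument of get_default_optimizer_params behaves. Its keys match
# *parameter* names, so an nn.Parameter attribute named "embeddings" picks up
# the custom LR for its param group. ``_ToyEmbedder`` and ``_demo_lr_overrides``
# are hypothetical names introduced only for this example.


class _ToyEmbedder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Parameter names mirror the DensePose CSE embedders, whose learnable
        # tensors are named "features" and "embeddings".
        self.features = nn.Parameter(torch.zeros(10, 8))
        self.embeddings = nn.Parameter(torch.zeros(10, 8))


def _demo_lr_overrides() -> None:
    params = get_default_optimizer_params(
        _ToyEmbedder(),
        base_lr=0.01,
        overrides={"embeddings": {"lr": 0.001}},  # i.e. an LR factor of 0.1
    )
    for group in params:
        # One group keeps the base LR (0.01); the "embeddings" group gets 0.001.
        print(len(group["params"]), group["lr"])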
def build_optimizer(cls, cfg, model):
    params = get_default_optimizer_params(
        model,
        base_lr=cfg.SOLVER.BASE_LR,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
    )

    def maybe_add_full_model_gradient_clipping(optim):
        # optim: the optimizer class.
        # detectron2 doesn't have full-model gradient clipping yet, so it is
        # implemented here by subclassing the optimizer and clipping in step().
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                # Clip the global gradient norm across all param groups,
                # then delegate to the wrapped optimizer's update.
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim

    optimizer_type = cfg.SOLVER.OPTIMIZER
    if optimizer_type == "SGD":
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    elif optimizer_type == "AdamW":
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params,
            cfg.SOLVER.BASE_LR,
            betas=(0.9, 0.999),
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")
    return optimizer
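# Illustration (a minimal sketch with plain PyTorch, independent of detectron2
# and of the trainer above): the same full-model clipping pattern in isolation.
# ``clip_optimizer_cls`` and ``_demo_full_model_clipping`` are hypothetical
# helpers, not detectron2 APIs. Subclassing keeps the clipping inside step(),
# so LR schedulers, checkpointing, and AMP wrappers still see an ordinary
# optimizer instance.


def clip_optimizer_cls(optim_cls, clip_norm_val):
    class _Clipped(optim_cls):
        def step(self, closure=None):
            # Clip the global norm over all parameter groups before the update.
            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
            torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
            super().step(closure=closure)

    return _Clipped


def _demo_full_model_clipping() -> None:
    model = nn.Linear(4, 2)
    optimizer = clip_optimizer_cls(torch.optim.SGD, clip_norm_val=1.0)(
        model.parameters(), lr=0.1
    )
    loss = model(torch.randn(3, 4)).pow(2).sum()
    loss.backward()
    optimizer.step()  # gradients are clipped to a global norm of 1.0 first
    optimizer.zero_grad()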