def configure_optimizers(self):
    """Build the optimizer (and optionally a NoamLR scheduler) for training.

    The optimizer is selected by ``self.args.optimizer`` (case-insensitive).
    Only the selected optimizer is constructed; the others in the table are
    never instantiated.

    Returns:
        The optimizer alone when ``self.args.use_scheduler`` is falsy, otherwise
        ``{"optimizer": optimizer, "lr_scheduler": scheduler_config}`` where the
        scheduler is stepped every batch (``"interval": "step"``).

    Raises:
        KeyError: if ``self.args.optimizer`` is not one of the supported names.
    """
    args = self.args
    # Lazy factories: building every optimizer up front would allocate state
    # for (and require the dependencies of) optimizers that are never used.
    factories = {
        "sgd": lambda: FusedSGD(self.parameters(), lr=self.lr, momentum=args.momentum),
        "adam": lambda: FusedAdam(self.parameters(), lr=self.lr, weight_decay=args.weight_decay),
        "adamw": lambda: torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=args.weight_decay),
        "radam": lambda: RAdam(self.parameters(), lr=self.lr, weight_decay=args.weight_decay),
        "adabelief": lambda: AdaBelief(self.parameters(), lr=self.lr, weight_decay=args.weight_decay),
        "adabound": lambda: AdaBound(self.parameters(), lr=self.lr, weight_decay=args.weight_decay),
        "adamp": lambda: AdamP(self.parameters(), lr=self.lr, weight_decay=args.weight_decay),
        "novograd": lambda: FusedNovoGrad(self.parameters(), lr=self.lr, weight_decay=args.weight_decay),
    }
    name = args.optimizer.lower()
    if name not in factories:
        raise KeyError(
            f"Unsupported optimizer {args.optimizer!r}; expected one of {sorted(factories)}")
    optimizer = factories[name]()

    if not args.use_scheduler:
        return optimizer

    # Steps are split across GPUs; treat 0/None GPUs (CPU run) as one device
    # instead of dividing by zero.
    num_devices = args.gpus or 1
    scheduler = {
        "scheduler": NoamLR(
            optimizer=optimizer,
            warmup_epochs=args.warmup,
            total_epochs=args.epochs,
            steps_per_epoch=len(self.train_dataloader()) // num_devices,
            init_lr=args.init_lr,
            max_lr=args.lr,
            final_lr=args.final_lr,
        ),
        "interval": "step",
        "frequency": 1,
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler}
def create_optimizer(args, model, filter_bias_and_bn=True, classification_layer_name=None):
    """Create an optimizer for ``model`` from the settings on ``args``.

    Args:
        args: namespace providing ``opt``, ``lr``, ``weight_decay``, ``momentum``
            and ``opt_eps``. ``opt`` may be prefixed with ``lookahead_`` to wrap
            the result in ``Lookahead``.
        model: module whose parameters are optimized.
        filter_bias_and_bn: when weight decay is non-zero, build param groups via
            ``add_weight_decay`` / ``set_lr_per_params`` instead of applying decay
            uniformly.
        classification_layer_name: if given, param groups are built by
            ``set_lr_per_params`` for this layer name.

    Returns:
        The configured optimizer.

    Raises:
        ValueError: if ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr

    if weight_decay and filter_bias_and_bn:
        # batch norm and bias params are handled by the group-building helpers
        if classification_layer_name is not None:
            parameters = set_lr_per_params(args, model, classification_layer_name, weight_decay)
        else:
            parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.  # decay is now carried by the param groups
    else:
        if classification_layer_name is not None:
            parameters = set_lr_per_params(args, model, classification_layer_name, weight_decay=0)
        else:
            parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    # An optional "<wrapper>_" prefix (e.g. "lookahead_adam") is split off here.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                              weight_decay=weight_decay, nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                              weight_decay=weight_decay, nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps,
                          delta=0.1, wd_ratio=0.01, nesterov=True)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, lr=args.lr, momentum=args.momentum,
                         weight_decay=weight_decay, eps=args.opt_eps, nesterov=True)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
                                  momentum=args.momentum, weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
                              momentum=args.momentum, weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters, lr=args.lr, momentum=args.momentum,
                             weight_decay=weight_decay, nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters, lr=args.lr, momentum=args.momentum,
                             weight_decay=weight_decay, nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, lr=args.lr, adam_w_mode=False,
                              weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, lr=args.lr, adam_w_mode=True,
                              weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters, lr=args.lr, betas=(0.95, 0.98),
                                  weight_decay=weight_decay, eps=args.opt_eps)
    else:
        # `assert False and "msg"` never shows the message and is stripped under -O.
        raise ValueError(f'Invalid optimizer "{args.opt}"')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def create_optimizer_param(args, parameters):
    """Create an optimizer over an explicit iterable of parameters / param groups.

    Args:
        args: namespace providing ``opt``, ``lr``, ``weight_decay``, ``momentum``
            and optionally ``opt_eps``, ``opt_betas`` and ``opt_args`` (a dict of
            extra keyword arguments merged last). ``opt`` may be prefixed with
            ``lookahead_`` to wrap the result in ``Lookahead``.
        parameters: parameters or param-group dicts handed to the optimizer.

    Returns:
        The configured optimizer.

    Raises:
        ValueError: if ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=args.weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas
    if hasattr(args, 'opt_args') and args.opt_args is not None:
        opt_args.update(args.opt_args)

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    if opt_lower in ('sgd', 'nesterov', 'momentum', 'fusedsgd', 'fusedmomentum'):
        # SGD variants accept neither `eps` nor `betas`; forwarding either raises TypeError.
        opt_args.pop('eps', None)
        opt_args.pop('betas', None)

    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # `assert False and "msg"` never shows the message and is stripped under -O.
        raise ValueError(f'Invalid optimizer "{args.opt}"')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Create an optimizer for ``model`` from the settings on ``args``.

    Args:
        args: namespace providing ``opt``, ``lr``, ``weight_decay``, ``momentum``
            and ``opt_eps``. ``opt`` may be prefixed with ``lookahead_`` to wrap
            the result in ``Lookahead``.
        model: module whose parameters are optimized.
        filter_bias_and_bn: when weight decay is non-zero, build param groups via
            ``add_weight_decay``; otherwise parameters come from ``unfrozen_params``.

    Returns:
        The configured optimizer.

    Raises:
        ValueError: if ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr

    if weight_decay and filter_bias_and_bn:
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.  # decay is now carried by the param groups
    else:
        parameters = unfrozen_params(model)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                              weight_decay=weight_decay, nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                              weight_decay=weight_decay, nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
                                  momentum=args.momentum, weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
                              momentum=args.momentum, weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters, lr=args.lr, momentum=args.momentum,
                             weight_decay=weight_decay, nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters, lr=args.lr, momentum=args.momentum,
                             weight_decay=weight_decay, nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, lr=args.lr, adam_w_mode=False,
                              weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, lr=args.lr, adam_w_mode=True,
                              weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters, lr=args.lr, betas=(0.95, 0.98),
                                  weight_decay=weight_decay, eps=args.opt_eps)
    else:
        # `assert False and "msg"` never shows the message and is stripped under -O.
        raise ValueError(f'Invalid optimizer "{args.opt}"')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def create_optimizer_v2(
        model_or_params,
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable] = None,
        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model_or_params (nn.Module or iterable): model containing parameters to optimize, or an
            iterable of parameters / param-group dicts passed to the optimizer unchanged
        opt: name of optimizer to create, optionally prefixed with ``lookahead_``
        lr: initial learning rate; when None, ``kwargs['lr']`` or the optimizer's own default is used
        weight_decay: weight decay to apply in optimizer
        momentum:  momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn:  filter out bias, bn and other 1d params from weight decay
        layer_decay: if set, param groups are built by ``param_groups_layer_decay``
        param_group_fn: if set, called with the model to build the param groups; takes
            precedence over ``layer_decay`` and ``filter_bias_and_bn``
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        ValueError: if ``opt`` does not name a supported optimizer.
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        no_weight_decay = {}
        if hasattr(model_or_params, 'no_weight_decay'):
            no_weight_decay = model_or_params.no_weight_decay()

        if param_group_fn:
            parameters = param_group_fn(model_or_params)
        elif layer_decay is not None:
            parameters = param_groups_layer_decay(
                model_or_params,
                weight_decay=weight_decay,
                layer_decay=layer_decay,
                no_weight_decay_list=no_weight_decay)
            weight_decay = 0.
        elif weight_decay and filter_bias_and_bn:
            parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)
    if lr is not None:
        opt_args.setdefault('lr', lr)

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        # NOTE: torch.optim has no `Nadam` attribute (the native class is spelled
        # `NAdam`), so the local Nadam implementation is what gets used here. It is
        # kept so that optimizer state stays compatible with existing checkpoints.
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    else:
        # `assert False and "msg"` never shows the message and is stripped under -O.
        raise ValueError(f'Invalid optimizer "{opt}"')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def create_optimizer_v2(
        model: nn.Module,
        optimizer_name: str = 'sgd',
        learning_rate: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        filter_bias_and_bn: bool = True,
        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model (nn.Module): model containing parameters to optimize
        optimizer_name: name of optimizer to create, optionally prefixed with ``lookahead_``
        learning_rate: initial learning rate; when None, ``kwargs['lr']`` or the optimizer's
            own default is used
        weight_decay: weight decay to apply in optimizer
        momentum:  momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn:  filter out bias, bn and other 1d params from weight decay
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        ValueError: if ``optimizer_name`` does not name a supported optimizer.
    """
    opt_lower = optimizer_name.lower()
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        weight_decay = 0.
    else:
        parameters = model.parameters()
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    # Only forward an explicit learning rate: passing lr=None makes standard
    # torch optimizers fail, so the default learning_rate=None used to crash.
    opt_args = dict(weight_decay=weight_decay, **kwargs)
    if learning_rate is not None:
        opt_args.setdefault('lr', learning_rate)

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, print_change_log=False, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not learning_rate:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # `assert False and "msg"` never shows the message and is stripped under -O.
        raise ValueError(f'Invalid optimizer "{optimizer_name}"')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def create_optimizer(args, model, filter_bias_and_bn=True, freeze_stage=""):
    """Optionally freeze layers of ``model``, then create an optimizer from ``args``.

    Args:
        args: namespace providing ``opt``, ``lr``, ``weight_decay``, ``momentum``
            and ``opt_eps``. ``opt`` may be prefixed with ``lookahead_`` to wrap
            the result in ``Lookahead``.
        model: module whose parameters are optimized.
        filter_bias_and_bn: when weight decay is non-zero, build param groups via
            ``add_weight_decay`` instead of applying decay uniformly.
        freeze_stage: ``"stage1"`` or ``"stage2"`` to run the corresponding
            freezing helpers before the optimizer is built; any other value
            freezes nothing.

    Returns:
        The configured optimizer.

    Raises:
        ValueError: if ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr

    # Freezing is applied regardless of weight decay; previously it only ran
    # inside the weight-decay branch, so freeze_stage was ignored when decay was 0.
    if freeze_stage == "stage1":
        stage1_train_attn(model, layer_names=['fc'])
        print('stage1, Freeze layer successfully')
    if freeze_stage == "stage2":
        stage1_train_attn(model, layer_names=['layer3', 'layer4', 'se', 'fc'])
        stage2_train_layer4(model)
        print('stage2, Freeze layer successfully')

    if weight_decay and filter_bias_and_bn:
        # Apply weight decay to the layers that are not frozen.
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    # Report which parameters are trainable after freezing.
    for name, param in model.named_parameters():
        print(name, param.requires_grad)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                              weight_decay=weight_decay, nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                              weight_decay=weight_decay, nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps,
                          delta=0.1, wd_ratio=0.01, nesterov=True)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, lr=args.lr, momentum=args.momentum,
                         weight_decay=weight_decay, eps=args.opt_eps, nesterov=True)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
                                  momentum=args.momentum, weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
                              momentum=args.momentum, weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters, lr=args.lr, momentum=args.momentum,
                             weight_decay=weight_decay, nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters, lr=args.lr, momentum=args.momentum,
                             weight_decay=weight_decay, nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, lr=args.lr, adam_w_mode=False,
                              weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, lr=args.lr, adam_w_mode=True,
                              weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters, lr=args.lr, betas=(0.95, 0.98),
                                  weight_decay=weight_decay, eps=args.opt_eps)
    else:
        # `assert False and "msg"` never shows the message and is stripped under -O.
        raise ValueError(f'Invalid optimizer "{args.opt}"')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def main():
    """Train a QuartzNet CTC speech-recognition model.

    Parses command-line arguments, optionally initializes NCCL distributed
    training (when the WORLD_SIZE env var is > 1), builds DALI or PyTorch data
    loaders, the model, optimizer, AMP grad scaler and LR policy, optionally
    resumes from a checkpoint, then runs the epoch/step training loop with
    gradient accumulation, periodic logging, evaluation and checkpointing.
    """
    args = parse_args()

    # This script only runs on GPU.
    assert (torch.cuda.is_available())
    # Predictions are printed only on logging steps, so the frequencies must nest.
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # Offset every seed by local_rank so each process gets a different seed.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    assert args.grad_accumulation >= 1
    batch_size = args.gpu_batch_size

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            # Padding is done by BaseFeatures, so the DALI pipeline must not do it.
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(
            gpu_id=args.local_rank,
            dataset_path=args.dataset_dir,
            config_data=train_dataset_kw,
            config_features=train_features_kw,
            json_names=args.train_manifests,
            batch_size=batch_size,
            grad_accumulation_steps=args.grad_accumulation,
            pipeline_type="train",
            device_type=args.dali_device,
            symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        # Non-DALI path: PyTorch data loaders plus FilterbankFeatures for
        # feature extraction on the GPU.
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir,
                                     args.train_manifests,
                                     symbols,
                                     **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir,
                                   args.val_manifests,
                                   symbols,
                                   **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        # Durations are divided by 3600 for the hours printed below
        # (presumably the dataset reports seconds — confirm in AudioDataset).
        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    # One optimizer step consumes grad_accumulation batches.
    steps_per_epoch = len(train_loader) // args.grad_accumulation

    # set up the model
    model = QuartzNet(encoder_kw=config.encoder(cfg),
                      decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    elif args.optimizer == 'lamb98':
        optimizer = FusedLAMB(model.parameters(), betas=(0.9, 0.98), eps=1e-9,
                              **kw)
    elif args.optimizer == 'fused_novograd':
        optimizer = FusedNovoGrad(model.parameters(), betas=(0.95, 0),
                                  bias_correction=False, reg_inside_moment=True,
                                  grad_averaging=False, **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # Called after every optimizer step to update the learning rate per the chosen policy.
    adjust_lr = lambda step, epoch, optimizer: lr_policy(
        step, epoch, args.lr, optimizer, steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
        num_epochs=args.epochs, policy=args.lr_policy, min_lr=args.min_lr,
        exp_gamma=args.lr_exp_gamma)

    # The EMA copy is taken before DDP wrapping, so it is a plain module.
    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'QuartzNet',
                                args.keep_milestones)
    if args.resume:
        # Prefer the newest checkpoint in output_dir; fall back to --ckpt.
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        # Checkpointer.load receives `meta` and, presumably, updates it in place
        # with the stored start_epoch / best_wer (not visible here).
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    epoch = 1
    # Global step counter, 1-based, continued from the resumed epoch.
    step = start_epoch * steps_per_epoch + 1

    if args.pyprof:
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    # training loop
    model.train()
    if args.ema > 0.0:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    # pre-allocate
    # Optionally run forward/backward passes on random inputs over a range of
    # frame lengths before training starts; results are discarded.
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(
                f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols) - 1, size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss.backward()
            model.zero_grad()
    torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()
        epoch_eval_time = 0

        for batch in train_loader:

            # First batch of a new accumulation window: reset per-step stats.
            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use context manager to prevent redundant accumulation of gradients
            # (DDP's no_sync skips the gradient all-reduce on all but the last
            # accumulated batch).
            if (multi_gpu and accumulated_batches + 1 < args.grad_accumulation):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)

                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation

                # The reduced loss is used only for logging / the NaN check.
                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    # Skip this batch entirely: no backward, not counted toward accumulation.
                    print_once(f'WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                    scaler.scale(loss).backward()

            # A full accumulation window has been processed: take an optimizer step.
            if accumulated_batches % args.grad_accumulation == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()

                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                if args.ema > 0.0:
                    apply_multi_tensor_ema(args.ema, *mt_ema_params)

                if step % args.log_frequency == 0:
                    # WER is computed on the last batch of the step only.
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens, symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f' Decoded: {pred_utt[:90]}')
                        print_once(f' Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log(
                        (epoch, step % steps_per_epoch or steps_per_epoch, steps_per_epoch),
                        step, 'train',
                        {
                            'loss': step_loss,
                            'wer': 100.0 * wer,
                            'throughput': step_utts / step_time,
                            'took': step_time,
                            'lrate': optimizer.param_groups[0]['lr']
                        })

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    tik = time.time()
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        checkpointer.save(model, ema_model, optimizer, scaler,
                                          epoch, step, best_wer, is_best=True)
                        best_wer = wer
                    epoch_eval_time += time.time() - tik

                step += 1
                accumulated_batches = 0
                # end of step

            # DALI iterator need to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log(
            (epoch, ), None, 'train_avg',
            {
                'throughput': epoch_utts / epoch_time,
                'took': epoch_time,
                'loss': epoch_loss
            })
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        # Stop early once this job has run epochs_this_job epochs (0 disables the limit).
        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    # Final evaluation and checkpoint only when the full schedule was completed.
    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                          best_wer)
    flush_log()
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build an optimizer for ``model`` from command-line style ``args``.

    Args:
        args: Namespace-like object. Reads ``opt`` (optimizer name, optionally
            prefixed like ``lookahead_<name>``), ``lr``, ``weight_decay``,
            ``momentum`` and, when present and not None, ``opt_eps`` and
            ``opt_betas``.
        model: Module whose parameters are optimized. If it has a
            ``no_weight_decay`` attribute (called when callable), the result
            is passed as ``skip`` to ``add_weight_decay``.
        filter_bias_and_bn: When True and ``args.weight_decay`` is non-zero,
            parameters are built with ``add_weight_decay`` and the
            optimizer-level weight decay is set to 0.

    Returns:
        The optimizer, wrapped in ``Lookahead`` when ``opt`` starts with
        ``lookahead_``.

    Raises:
        AssertionError: A ``fused*`` optimizer was requested without APEX/CUDA.
        ValueError: The optimizer name is not recognized.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay
            # Usually a method returning parameter names; passing the bound
            # method itself (as before) gives add_weight_decay the wrong type.
            if callable(skip):
                skip = skip()
        parameters = add_weight_decay(model, weight_decay, skip)
        # Presumably add_weight_decay encodes decay per parameter group;
        # zero the optimizer-level value so it is not applied twice.
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    # SGD-style optimizers accept no ``eps`` argument ('nesterov' is an SGD alias).
    no_eps = {'sgd', 'nesterov', 'momentum', 'fusedmomentum', 'fusedsgd'}
    if getattr(args, 'opt_eps', None) is not None and opt_lower not in no_eps:
        opt_args['eps'] = args.opt_eps
    if getattr(args, 'opt_betas', None) is not None:
        opt_args['betas'] = args.opt_betas

    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # ``assert False and "..."`` never showed its message and vanished under -O.
        raise ValueError(f"Invalid optimizer: {args.opt}")

    if len(opt_split) > 1 and opt_split[0] == 'lookahead':
        optimizer = Lookahead(optimizer)
    return optimizer
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build an optimizer for ``model`` from command-line style ``args``.

    Args:
        args: Namespace-like object. Reads ``opt`` (optimizer name, optionally
            prefixed like ``lookahead_<name>``), ``lr``, ``weight_decay``,
            ``momentum`` and, when present and not None, ``opt_eps``,
            ``opt_betas`` and ``opt_args`` (a dict merged into the keyword
            arguments, overriding the defaults built here).
        model: Module whose parameters are optimized. If it defines
            ``no_weight_decay()``, its result is passed as ``skip`` to
            ``add_weight_decay``.
        filter_bias_and_bn: When True and ``args.weight_decay`` is non-zero,
            parameters are built with ``add_weight_decay`` and the
            optimizer-level weight decay is set to 0.

    Returns:
        The optimizer, wrapped in ``Lookahead`` when ``opt`` starts with
        ``lookahead_``.

    Raises:
        AssertionError: A ``fused*`` optimizer was requested without APEX/CUDA.
        ValueError: The optimizer name is not recognized.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        # Presumably add_weight_decay encodes decay per parameter group;
        # zero the optimizer-level value so it is not applied twice.
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas
    if hasattr(args, 'opt_args') and args.opt_args is not None:
        opt_args.update(args.opt_args)

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # SGD takes no ``eps``.
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        # This variant always enables gradient centralization for AdamP.
        print(' ')
        print('Gradient centralization is enabled for AdamP optimizer.')
        print(' ')
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, use_gc=True,
                          gc_conv_only=True, gc_loc=False, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        # A falsy lr is passed as None (left to Adafactor's own handling).
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # ``assert False and "..."`` never showed its message and vanished under -O.
        raise ValueError(f"Invalid optimizer: {args.opt}")

    if len(opt_split) > 1 and opt_split[0] == 'lookahead':
        optimizer = Lookahead(optimizer)
    return optimizer
def create_optimizer(optimizer_config, model, master_params=None):
    """Creates optimizer and schedule from configuration

    Parameters
    ----------
    optimizer_config : dict
        Dictionary containing the configuration options for the optimizer.
        Keys read: "type", "learning_rate", "weight_decay", "schedule"
        (a dict with "type" and, for most schedule types, "params"), plus
        "momentum"/"nesterov" for SGD, optional "eps" (default 1e-8) for the
        Adam/NovoGrad family, and optional "classifier_lr" (-1 or absent
        disables the separate classifier learning rate).
    model : Model
        The network model.
    master_params : optional
        Unused; kept so existing callers remain compatible.

    Returns
    -------
    optimizer : Optimizer
        The optimizer.
    scheduler : LRScheduler
        The learning rate scheduler.

    Raises
    ------
    KeyError
        If the optimizer type or the schedule type is not recognized.
    """
    # Parameters whose names contain any of these substrings get no weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', "_bn0.weight", "_bn1.weight", "_bn2.weight"]

    def make_params(param_optimizer, lr=None):
        """Split (name, param) pairs into a decayed group and a no-decay group."""
        params = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay': optimizer_config["weight_decay"]},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        if lr is not None:
            for group in params:
                group["lr"] = lr
        return params

    if optimizer_config.get("classifier_lr", -1) != -1:
        # Parameters with "encoder" in their name use the default learning
        # rate; all other trainable parameters use "classifier_lr".
        net_params = []
        classifier_params = []
        for k, v in model.named_parameters():
            if not v.requires_grad:
                continue
            if k.find("encoder") != -1:
                net_params.append((k, v))
            else:
                classifier_params.append((k, v))
        params = []
        params.extend(make_params(classifier_params, optimizer_config["classifier_lr"]))
        params.extend(make_params(net_params))
    else:
        param_optimizer = list(model.named_parameters())
        params = make_params(param_optimizer)
    print("param_groups", len(params))

    if optimizer_config["type"] == "SGD":
        optimizer = optim.SGD(params,
                              lr=optimizer_config["learning_rate"],
                              momentum=optimizer_config["momentum"],
                              nesterov=optimizer_config["nesterov"])
    elif optimizer_config["type"] == "Adam":
        optimizer = optim.Adam(params,
                               eps=optimizer_config.get("eps", 1e-8),
                               lr=optimizer_config["learning_rate"],
                               weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "FusedAdam":
        optimizer = FusedAdam(params,
                              eps=optimizer_config.get("eps", 1e-8),
                              lr=optimizer_config["learning_rate"],
                              weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "FusedNovoGrad":
        optimizer = FusedNovoGrad(params,
                                  eps=optimizer_config.get("eps", 1e-8),
                                  lr=optimizer_config["learning_rate"],
                                  weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "AdamW":
        optimizer = AdamW(params,
                          eps=optimizer_config.get("eps", 1e-8),
                          lr=optimizer_config["learning_rate"],
                          weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "RmsProp":
        optimizer = RMSprop(params,
                            lr=optimizer_config["learning_rate"],
                            weight_decay=optimizer_config["weight_decay"])
    else:
        raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))

    schedule = optimizer_config["schedule"]
    schedule_type = schedule["type"]
    if schedule_type == "step":
        scheduler = LRStepScheduler(optimizer, **schedule["params"])
    elif schedule_type == "cosine":
        scheduler = CosineAnnealingLR(optimizer, **schedule["params"])
    elif schedule_type == "clr":
        scheduler = CyclicLR(optimizer, **schedule["params"])
    elif schedule_type == "multistep":
        scheduler = MultiStepLR(optimizer, **schedule["params"])
    elif schedule_type == "exponential":
        scheduler = ExponentialLRScheduler(optimizer, **schedule["params"])
    elif schedule_type == "poly":
        scheduler = PolyLR(optimizer, **schedule["params"])
    elif schedule_type == "constant":
        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
    elif schedule_type == "linear":
        def linear_lr(it):
            # LR multiplier = alpha * iteration + beta.
            return it * schedule["params"]["alpha"] + schedule["params"]["beta"]

        scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)
    else:
        # Previously fell through and crashed with UnboundLocalError at return.
        raise KeyError("unrecognized schedule {}".format(schedule_type))

    return optimizer, scheduler