예제 #1
0
    def configure_optimizers(self):
        """Create the optimizer and, optionally, a per-step NoamLR scheduler.

        The optimizer is selected by ``self.args.optimizer`` (case-insensitive)
        and only that optimizer is constructed.

        Returns:
            The optimizer alone when ``self.args.use_scheduler`` is falsy;
            otherwise ``{"optimizer": optimizer, "lr_scheduler": config}``
            where ``config`` steps a ``NoamLR`` schedule once per training step.

        Raises:
            KeyError: if ``self.args.optimizer`` is not a supported name.
        """
        args = self.args
        # Map names to factories rather than instances so that only the
        # selected optimizer is built.  Building all of them would allocate
        # optimizer state for each and fail whenever any one implementation
        # (e.g. a fused/CUDA-only one) cannot be constructed here.
        factories = {
            "sgd": lambda: FusedSGD(self.parameters(),
                                    lr=self.lr,
                                    momentum=args.momentum),
            "adam": lambda: FusedAdam(self.parameters(),
                                      lr=self.lr,
                                      weight_decay=args.weight_decay),
            "adamw": lambda: torch.optim.AdamW(self.parameters(),
                                               lr=self.lr,
                                               weight_decay=args.weight_decay),
            "radam": lambda: RAdam(self.parameters(),
                                   lr=self.lr,
                                   weight_decay=args.weight_decay),
            "adabelief": lambda: AdaBelief(self.parameters(),
                                           lr=self.lr,
                                           weight_decay=args.weight_decay),
            "adabound": lambda: AdaBound(self.parameters(),
                                         lr=self.lr,
                                         weight_decay=args.weight_decay),
            "adamp": lambda: AdamP(self.parameters(),
                                   lr=self.lr,
                                   weight_decay=args.weight_decay),
            "novograd": lambda: FusedNovoGrad(self.parameters(),
                                              lr=self.lr,
                                              weight_decay=args.weight_decay),
        }

        try:
            build_optimizer = factories[args.optimizer.lower()]
        except KeyError:
            raise KeyError(
                f"Unsupported optimizer {args.optimizer!r}; "
                f"expected one of {sorted(factories)}") from None
        optimizer = build_optimizer()

        if not args.use_scheduler:
            return optimizer

        # gpus may be 0/None (e.g. CPU-only runs); avoid ZeroDivisionError.
        num_devices = args.gpus or 1
        # NOTE(review): the optimizer starts from self.lr while the schedule
        # peaks at args.lr — confirm these are meant to agree.
        scheduler = {
            "scheduler": NoamLR(
                optimizer=optimizer,
                warmup_epochs=args.warmup,
                total_epochs=args.epochs,
                steps_per_epoch=len(self.train_dataloader()) // num_devices,
                init_lr=args.init_lr,
                max_lr=args.lr,
                final_lr=args.final_lr,
            ),
            "interval": "step",
            "frequency": 1,
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
def create_optimizer(args,
                     model,
                     filter_bias_and_bn=True,
                     classification_layer_name=None):
    """Create an optimizer for ``model`` configured from ``args``.

    ``args.opt`` names the optimizer (case-insensitive).  A ``"lookahead_"``
    prefix wraps the result in ``Lookahead``; names containing ``"fused"``
    select APEX fused implementations.

    Args:
        args: namespace providing ``opt``, ``lr`` and ``weight_decay``; most
            optimizers additionally read ``momentum`` and/or ``opt_eps``.
        model: module whose parameters are optimized.
        filter_bias_and_bn: when true and weight decay is non-zero, parameters
            are split into groups by ``add_weight_decay`` /
            ``set_lr_per_params`` (defined elsewhere) instead of applying one
            global decay value.
        classification_layer_name: if given, parameter groups are built by
            ``set_lr_per_params`` for this layer (presumably to give it its
            own learning rate — see that helper).

    Returns:
        The constructed optimizer, wrapped in ``Lookahead`` when requested.

    Raises:
        AssertionError: a fused optimizer was requested without APEX/CUDA.
        ValueError: ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr

    if weight_decay and filter_bias_and_bn:
        # Build parameter groups so batch-norm / bias params can be treated
        # separately; the decay now lives in the groups, so the
        # optimizer-level value is reset to 0.
        if classification_layer_name is not None:
            parameters = set_lr_per_params(args, model,
                                           classification_layer_name,
                                           weight_decay)
        else:
            parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        if classification_layer_name is not None:
            parameters = set_lr_per_params(args,
                                           model,
                                           classification_layer_name,
                                           weight_decay=0)
        else:
            parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    # An optional wrapper prefix (e.g. "lookahead_adam") precedes the base name.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # 'sgd' means SGD with Nesterov momentum here.
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps,
                          delta=0.1,
                          wd_ratio=0.01,
                          nesterov=True)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=weight_decay,
                         eps=args.opt_eps,
                         nesterov=True)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters,
                                   lr=args.lr,
                                   weight_decay=weight_decay,
                                   eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  lr=args.lr,
                                  alpha=0.9,
                                  eps=args.opt_eps,
                                  momentum=args.momentum,
                                  weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              lr=args.lr,
                              alpha=0.9,
                              eps=args.opt_eps,
                              momentum=args.momentum,
                              weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters,
                             lr=args.lr,
                             weight_decay=weight_decay,
                             eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=False,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=True,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters,
                              lr=args.lr,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters,
                                  lr=args.lr,
                                  betas=(0.95, 0.98),
                                  weight_decay=weight_decay,
                                  eps=args.opt_eps)
    else:
        # `assert False and "msg"` would evaluate to `assert False` and drop
        # the message (and vanish entirely under -O); raise explicitly.
        raise ValueError(f"Invalid optimizer: {args.opt!r}")

    if len(opt_split) > 1 and opt_split[0] == 'lookahead':
        optimizer = Lookahead(optimizer)

    return optimizer
def create_optimizer_param(args, parameters):
    """Create an optimizer over an explicit ``parameters`` iterable.

    ``args.opt`` names the optimizer (case-insensitive); a ``"lookahead_"``
    prefix wraps the result in ``Lookahead``.  ``args.lr`` and
    ``args.weight_decay`` are always passed; ``args.opt_eps``,
    ``args.opt_betas`` and the dict ``args.opt_args`` are forwarded when
    present and not None (``opt_args`` entries override the others).

    Args:
        args: namespace with at least ``opt``, ``lr`` and ``weight_decay``;
            momentum-based optimizers also read ``args.momentum``.
        parameters: parameters or parameter groups to optimize.

    Returns:
        The constructed optimizer, wrapped in ``Lookahead`` when requested.

    Raises:
        AssertionError: a fused optimizer was requested without APEX/CUDA.
        ValueError: ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=args.weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas
    if hasattr(args, 'opt_args') and args.opt_args is not None:
        opt_args.update(args.opt_args)

    # An optional wrapper prefix (e.g. "lookahead_adam") precedes the base name.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # SGD does not accept `eps`.
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=True,
                              **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=False,
                              **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         momentum=args.momentum,
                         nesterov=True,
                         **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        # A falsy lr is passed to Adafactor as None.
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  alpha=0.9,
                                  momentum=args.momentum,
                                  **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              alpha=0.9,
                              momentum=args.momentum,
                              **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=False,
                             **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # `assert False and "msg"` would evaluate to `assert False` and drop
        # the message (and vanish entirely under -O); raise explicitly.
        raise ValueError(f"Invalid optimizer: {args.opt!r}")

    if len(opt_split) > 1 and opt_split[0] == 'lookahead':
        optimizer = Lookahead(optimizer)

    return optimizer
예제 #4
0
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Create an optimizer for ``model`` configured from ``args``.

    ``args.opt`` names the optimizer (case-insensitive).  A ``"lookahead_"``
    prefix wraps the result in ``Lookahead``; names containing ``"fused"``
    select APEX fused implementations.

    Args:
        args: namespace providing ``opt``, ``lr`` and ``weight_decay``; most
            optimizers additionally read ``momentum`` and/or ``opt_eps``.
        model: module whose parameters are optimized.
        filter_bias_and_bn: when true and weight decay is non-zero, parameter
            groups come from ``add_weight_decay`` (defined elsewhere);
            otherwise ``unfrozen_params(model)`` is used.

    Returns:
        The constructed optimizer, wrapped in ``Lookahead`` when requested.

    Raises:
        AssertionError: a fused optimizer was requested without APEX/CUDA.
        ValueError: ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        # Decay is carried by the parameter groups, so the optimizer-level
        # value is reset to 0.
        # NOTE(review): this path does not go through unfrozen_params —
        # confirm add_weight_decay skips frozen parameters.
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        parameters = unfrozen_params(model)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    # An optional wrapper prefix (e.g. "lookahead_adam") precedes the base name.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # 'sgd' means SGD with Nesterov momentum here.
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters,
                                   lr=args.lr,
                                   weight_decay=weight_decay,
                                   eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  lr=args.lr,
                                  alpha=0.9,
                                  eps=args.opt_eps,
                                  momentum=args.momentum,
                                  weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              lr=args.lr,
                              alpha=0.9,
                              eps=args.opt_eps,
                              momentum=args.momentum,
                              weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters,
                             lr=args.lr,
                             weight_decay=weight_decay,
                             eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=False,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=True,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters,
                              lr=args.lr,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters,
                                  lr=args.lr,
                                  betas=(0.95, 0.98),
                                  weight_decay=weight_decay,
                                  eps=args.opt_eps)
    else:
        # `assert False and "msg"` would evaluate to `assert False` and drop
        # the message (and vanish entirely under -O); raise explicitly.
        raise ValueError(f"Invalid optimizer: {args.opt!r}")

    if len(opt_split) > 1 and opt_split[0] == 'lookahead':
        optimizer = Lookahead(optimizer)

    return optimizer
예제 #5
0
def create_optimizer_v2(model_or_params,
                        opt: str = 'sgd',
                        lr: Optional[float] = None,
                        weight_decay: float = 0.,
                        momentum: float = 0.9,
                        filter_bias_and_bn: bool = True,
                        layer_decay: Optional[float] = None,
                        param_group_fn: Optional[Callable] = None,
                        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model_or_params (nn.Module or iterable): model containing parameters to
            optimize, or an iterable of parameters / param groups used as-is
        opt: name of optimizer to create (case-insensitive; an optional
            ``"lookahead_"`` prefix wraps the result in ``Lookahead``)
        lr: initial learning rate; when None the optimizer's own default is used
        weight_decay: weight decay to apply in optimizer
        momentum:  momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn:  filter out bias, bn and other 1d params from weight decay
        layer_decay: if set (and ``param_group_fn`` is not), parameter groups
            are built by ``param_groups_layer_decay`` with this factor
        param_group_fn: if set, called with the model to produce the parameter
            groups; takes precedence over ``layer_decay`` and ``filter_bias_and_bn``
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        AssertionError: a fused optimizer was requested without APEX/CUDA.
        ValueError: ``opt`` does not name a supported optimizer.
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        no_weight_decay = {}
        if hasattr(model_or_params, 'no_weight_decay'):
            no_weight_decay = model_or_params.no_weight_decay()

        if param_group_fn:
            parameters = param_group_fn(model_or_params)
        elif layer_decay is not None:
            parameters = param_groups_layer_decay(
                model_or_params,
                weight_decay=weight_decay,
                layer_decay=layer_decay,
                no_weight_decay_list=no_weight_decay)
            weight_decay = 0.
        elif weight_decay and filter_bias_and_bn:
            parameters = param_groups_weight_decay(model_or_params,
                                                   weight_decay,
                                                   no_weight_decay)
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)
    if lr is not None:
        # an explicit 'lr' in kwargs takes precedence
        opt_args.setdefault('lr', lr)

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=momentum,
                              nesterov=True,
                              **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=momentum,
                              nesterov=False,
                              **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         momentum=momentum,
                         nesterov=True,
                         **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        # PyTorch >= 1.10 ships a native NAdam as `optim.NAdam` (capital A);
        # `optim.Nadam` never exists, so probe for the correct name and fall
        # back to the local implementation on older versions.
        if hasattr(optim, 'NAdam'):
            optimizer = optim.NAdam(parameters, **opt_args)
        else:
            optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters,
                         momentum=momentum,
                         trust_clip=True,
                         **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters,
                         momentum=momentum,
                         trust_clip=True,
                         nesterov=True,
                         **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters,
                         momentum=momentum,
                         nesterov=True,
                         **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters,
                            momentum=momentum,
                            decoupled_decay=True,
                            **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  alpha=0.9,
                                  momentum=momentum,
                                  **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              alpha=0.9,
                              momentum=momentum,
                              **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=momentum,
                             nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=momentum,
                             nesterov=False,
                             **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    else:
        # `assert False and "msg"` would evaluate to `assert False` and drop
        # the message (and vanish entirely under -O); raise explicitly.
        raise ValueError(f"Invalid optimizer: {opt!r}")

    if len(opt_split) > 1 and opt_split[0] == 'lookahead':
        optimizer = Lookahead(optimizer)

    return optimizer
예제 #6
0
def create_optimizer_v2(
        model: nn.Module,
        optimizer_name: str = 'sgd',
        learning_rate: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        filter_bias_and_bn: bool = True,
        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model (nn.Module): model containing parameters to optimize
        optimizer_name: name of optimizer to create (case-insensitive; an
            optional ``"lookahead_"`` prefix wraps the result in ``Lookahead``)
        learning_rate: initial learning rate; when None the optimizer's own
            default is used (Adafactor receives ``lr=None`` for any falsy value)
        weight_decay: weight decay to apply in optimizer
        momentum:  momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn:  filter out bias, bn and other 1d params from weight decay
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        AssertionError: a fused optimizer was requested without APEX/CUDA.
        ValueError: ``optimizer_name`` does not name a supported optimizer.
    """
    opt_lower = optimizer_name.lower()
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        weight_decay = 0.
    else:
        parameters = model.parameters()
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)
    # Passing lr=None explicitly breaks most torch optimizers, so only forward
    # a learning rate that was actually given (an explicit kwargs 'lr' wins).
    if learning_rate is not None:
        opt_args.setdefault('lr', learning_rate)
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # 'sgd' means SGD with Nesterov momentum here; SGD does not accept eps.
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, print_change_log=False, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not learning_rate:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # `assert False and "msg"` would evaluate to `assert False` and drop
        # the message (and vanish entirely under -O); raise explicitly.
        raise ValueError(f"Invalid optimizer: {optimizer_name!r}")

    if len(opt_split) > 1 and opt_split[0] == 'lookahead':
        optimizer = Lookahead(optimizer)

    return optimizer
예제 #7
0
def create_optimizer(args, model, filter_bias_and_bn=True, freeze_stage=""):
    """Build an optimizer for ``model`` from command-line style ``args``.

    Args:
        args: Namespace-like object. Reads ``opt`` (optimizer name, optionally
            prefixed with ``lookahead_``), ``weight_decay``, ``lr``,
            ``momentum`` and ``opt_eps``.
        model: The module whose parameters are optimized.
        filter_bias_and_bn: When true and the effective weight decay is
            non-zero, parameters are grouped by ``add_weight_decay`` and the
            optimizer-level weight decay is set to 0.
        freeze_stage: ``"stage1"`` or ``"stage2"`` calls the project's
            layer-freezing helpers (``stage1_train_attn`` /
            ``stage2_train_layer4``) before the optimizer is built; any other
            value leaves the model untouched.

    Returns:
        The constructed optimizer, wrapped in ``Lookahead`` when the name
        starts with ``lookahead_``.

    Raises:
        AssertionError: If a ``fused*`` optimizer is requested without APEX
            and CUDA.
        ValueError: If the optimizer name is not recognised.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr

    # Layer freezing is independent of weight-decay filtering. Previously it
    # was silently skipped whenever weight decay was 0 or
    # filter_bias_and_bn was False.
    if freeze_stage == "stage1":
        stage1_train_attn(model, layer_names=['fc'])
        print('stage1, Freeze layer successfully')
    elif freeze_stage == "stage2":
        stage1_train_attn(model,
                          layer_names=['layer3', 'layer4', 'se', 'fc'])
        stage2_train_layer4(model)
        print('stage2, Freeze layer successfully')

    if weight_decay and filter_bias_and_bn:
        # Apply weight decay to the unfrozen layers. add_weight_decay is
        # assumed to return param groups carrying their own decay, hence the
        # optimizer-level value is zeroed below.
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    # Debug output: report which parameters are trainable.
    for name, param in model.named_parameters():
        print(name, param.requires_grad)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    # "lookahead_adam" -> ["lookahead", "adam"]; the last token picks the base optimizer.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps,
                          delta=0.1,
                          wd_ratio=0.01,
                          nesterov=True)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=weight_decay,
                         eps=args.opt_eps,
                         nesterov=True)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters,
                                   lr=args.lr,
                                   weight_decay=weight_decay,
                                   eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  lr=args.lr,
                                  alpha=0.9,
                                  eps=args.opt_eps,
                                  momentum=args.momentum,
                                  weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              lr=args.lr,
                              alpha=0.9,
                              eps=args.opt_eps,
                              momentum=args.momentum,
                              weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters,
                             lr=args.lr,
                             weight_decay=weight_decay,
                             eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=False,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=True,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters,
                              lr=args.lr,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters,
                                  lr=args.lr,
                                  betas=(0.95, 0.98),
                                  weight_decay=weight_decay,
                                  eps=args.opt_eps)
    else:
        # `assert False and "..."` never showed its message and is stripped
        # under -O; raise an informative error instead.
        raise ValueError(f'Invalid optimizer "{args.opt}"')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
예제 #8
0
def main():
    """Train QuartzNet with CTC loss, periodic evaluation and checkpointing.

    All options come from ``parse_args()``. The function:

    * sets up optional multi-GPU training (when ``WORLD_SIZE`` > 1);
    * builds DALI or PyTorch data loaders plus feature processors;
    * builds the model, the optimizer, the AMP ``GradScaler`` and the LR policy;
    * optionally restores a checkpoint;
    * runs the training loop with gradient accumulation and optional EMA of
      the weights.

    Train/eval metrics are logged, and checkpoints are written through
    ``Checkpointer``.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
    """
    args = parse_args()

    # Training is GPU-only. Predictions are printed only on logging steps, so
    # the prediction frequency must be a multiple of the logging frequency.
    assert (torch.cuda.is_available())
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # Offset seeds by local rank so each process draws different randomness.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    assert args.grad_accumulation >= 1
    batch_size = args.gpu_batch_size

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(
            gpu_id=args.local_rank,
            dataset_path=args.dataset_dir,
            config_data=train_dataset_kw,
            config_features=train_features_kw,
            json_names=args.train_manifests,
            batch_size=batch_size,
            grad_accumulation_steps=args.grad_accumulation,
            pipeline_type="train",
            device_type=args.dali_device,
            symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir, args.train_manifests,
                                     symbols, **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir, args.val_manifests,
                                   symbols, **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    # One optimizer step consumes `grad_accumulation` loader batches.
    steps_per_epoch = len(train_loader) // args.grad_accumulation

    # set up the model
    model = QuartzNet(encoder_kw=config.encoder(cfg),
                      decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    elif args.optimizer == 'lamb98':
        optimizer = FusedLAMB(model.parameters(),
                              betas=(0.9, 0.98),
                              eps=1e-9,
                              **kw)
    elif args.optimizer == 'fused_novograd':
        optimizer = FusedNovoGrad(model.parameters(),
                                  betas=(0.95, 0),
                                  bias_correction=False,
                                  reg_inside_moment=True,
                                  grad_averaging=False,
                                  **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    adjust_lr = lambda step, epoch, optimizer: lr_policy(
        step,
        epoch,
        args.lr,
        optimizer,
        steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs,
        hold_epochs=args.hold_epochs,
        num_epochs=args.epochs,
        policy=args.lr_policy,
        min_lr=args.min_lr,
        exp_gamma=args.lr_exp_gamma)

    # The EMA copy is taken before DDP wrapping, so it is a plain module.
    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'QuartzNet',
                                args.keep_milestones)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    if args.pyprof:
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    # training loop
    model.train()
    if args.ema > 0.0:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    # pre-allocate: run forward/backward once per padded length in the
    # requested range before training starts.
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(
                f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols) - 1,
                                size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss.backward()
            model.zero_grad()
    torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()
        epoch_eval_time = 0

        for batch in train_loader:

            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use context manager to prevent redundant accumulation of gradients
            # (skip DDP gradient sync on all but the last micro-batch of a step).
            if (multi_gpu
                    and accumulated_batches + 1 < args.grad_accumulation):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)

                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    print_once('WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                    scaler.scale(loss).backward()

            # A full accumulation window has been collected: take one optimizer step.
            if accumulated_batches % args.grad_accumulation == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()

                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                if args.ema > 0.0:
                    apply_multi_tensor_ema(args.ema, *mt_ema_params)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens,
                                                    symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log((epoch, step % steps_per_epoch or steps_per_epoch,
                         steps_per_epoch),
                        step,
                        'train',
                        {'loss': step_loss,
                         'wer': 100.0 * wer,
                         'throughput': step_utts / step_time,
                         'took': step_time,
                         'lrate': optimizer.param_groups[0]['lr']})

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    tik = time.time()
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        # Update first so the "best" checkpoint records the
                        # new best WER (it previously stored the stale value,
                        # which was then restored on resume).
                        best_wer = wer
                        checkpointer.save(model,
                                          ema_model,
                                          optimizer,
                                          scaler,
                                          epoch,
                                          step,
                                          best_wer,
                                          is_best=True)
                    epoch_eval_time += time.time() - tik

                step += 1
                accumulated_batches = 0
                # end of step

            # DALI iterator need to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log((epoch, ), None, 'train_avg', {
            'throughput': epoch_utts / epoch_time,
            'took': epoch_time,
            'loss': epoch_loss
        })
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    # Final evaluation and checkpoint only when the last epoch was reached.
    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                          best_wer)
    flush_log()
예제 #9
0
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build an optimizer for ``model`` from command-line style ``args``.

    Args:
        args: Namespace-like object. Reads ``opt`` (optimizer name, optionally
            prefixed with ``lookahead_``), ``weight_decay``, ``lr`` and
            ``momentum``. Also reads the optional ``opt_eps`` (ignored for
            SGD-style optimizers) and ``opt_betas``.
        model: The module whose parameters are optimized. If it defines
            ``no_weight_decay()``, its result is passed to
            ``add_weight_decay`` as the skip list.
        filter_bias_and_bn: When true and weight decay is non-zero,
            parameters are grouped by ``add_weight_decay`` and the
            optimizer-level weight decay is set to 0.

    Returns:
        The constructed optimizer, wrapped in ``Lookahead`` when the name
        starts with ``lookahead_``.

    Raises:
        AssertionError: If a ``fused*`` optimizer is requested without APEX
            and CUDA.
        ValueError: If the optimizer name is not recognised.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = set()
        if hasattr(model, 'no_weight_decay'):
            # Must be called: passing the bound method itself gives
            # add_weight_decay a non-iterable skip list.
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    # "lookahead_adam" -> ["lookahead", "adam"]; the last token picks the base optimizer.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    # SGD variants take no `eps` argument.
    if getattr(args, 'opt_eps', None) is not None and opt_lower not in [
            'sgd', 'momentum', 'fusedmomentum', 'fusedsgd'
    ]:
        opt_args['eps'] = args.opt_eps
    if getattr(args, 'opt_betas', None) is not None:
        opt_args['betas'] = args.opt_betas

    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=True,
                              **opt_args)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=False,
                              **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=False,
                             **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # `assert False and "..."` never showed its message and is stripped
        # under -O; raise an informative error instead.
        raise ValueError(f'Invalid optimizer "{args.opt}"')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build an optimizer for ``model`` from command-line style ``args``.

    Args:
        args: Namespace-like object. Reads ``opt`` (optimizer name, optionally
            prefixed with ``lookahead_``), ``weight_decay``, ``lr`` and
            ``momentum``. Also reads the optional ``opt_eps``, ``opt_betas``
            and ``opt_args`` (a dict merged into the optimizer kwargs last).
        model: The module whose parameters are optimized. If it defines
            ``no_weight_decay()``, its result is passed to
            ``add_weight_decay`` as the skip list.
        filter_bias_and_bn: When true and weight decay is non-zero,
            parameters are grouped by ``add_weight_decay`` and the
            optimizer-level weight decay is set to 0.

    Returns:
        The constructed optimizer, wrapped in ``Lookahead`` when the name
        starts with ``lookahead_``.

    Raises:
        AssertionError: If a ``fused*`` optimizer is requested without APEX
            and CUDA.
        ValueError: If the optimizer name is not recognised.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = set()
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if getattr(args, 'opt_eps', None) is not None:
        opt_args['eps'] = args.opt_eps
    if getattr(args, 'opt_betas', None) is not None:
        opt_args['betas'] = args.opt_betas
    if getattr(args, 'opt_args', None) is not None:
        opt_args.update(args.opt_args)

    # "lookahead_adam" -> ["lookahead", "adam"]; the last token picks the base optimizer.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # SGD takes no `eps` argument.
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=True,
                              **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=False,
                              **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        # AdamP is built with the project's gradient-centralization options
        # (use_gc / gc_conv_only / gc_loc).
        print(' ')
        print('Gradient centralization is enabled for AdamP optimizer.')
        print(' ')

        optimizer = AdamP(parameters,
                          wd_ratio=0.01,
                          nesterov=True,
                          use_gc=True,
                          gc_conv_only=True,
                          gc_loc=False,
                          **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         momentum=args.momentum,
                         nesterov=True,
                         **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        # A falsy lr is forwarded as None to Adafactor.
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  alpha=0.9,
                                  momentum=args.momentum,
                                  **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              alpha=0.9,
                              momentum=args.momentum,
                              **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=False,
                             **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # `assert False and "..."` never showed its message and is stripped
        # under -O; raise an informative error instead.
        raise ValueError(f'Invalid optimizer "{args.opt}"')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
def create_optimizer(optimizer_config, model, master_params=None):
    """Creates optimizer and schedule from configuration

    Parameters
    ----------
    optimizer_config : dict
        Dictionary containing the configuration options for the optimizer.
        Keys read here: ``"type"``, ``"learning_rate"``, ``"weight_decay"``,
        ``"schedule"`` (a dict with ``"type"`` and, for most schedule types,
        ``"params"``), the optional ``"classifier_lr"``, and per-optimizer
        keys (``"momentum"``/``"nesterov"`` for SGD, optional ``"eps"`` for
        the Adam/NovoGrad family).
    model : Model
        The network model; only ``model.named_parameters()`` is used.
    master_params : optional
        Currently unused by this function.

    Returns
    -------
    optimizer : Optimizer
        The optimizer.
    scheduler : LRScheduler
        The learning rate scheduler.

    Raises
    ------
    KeyError
        If the optimizer type or the schedule type is not recognized, or a
        required configuration key is missing.
    """

    # Parameters whose name contains any of these substrings get weight_decay = 0.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', "_bn0.weight", "_bn1.weight", "_bn2.weight"]

    def make_params(param_optimizer, lr=None):
        """Split (name, param) pairs into a decayed group and a no-decay group.

        If ``lr`` is given it is set on both groups, overriding the optimizer's
        default learning rate for those parameters.
        """
        decayed = [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
        not_decayed = [p for n, p in param_optimizer if any(nd in n for nd in no_decay)]
        groups = [
            {'params': decayed, 'weight_decay': optimizer_config["weight_decay"]},
            {'params': not_decayed, 'weight_decay': 0.0},
        ]
        if lr is not None:
            for group in groups:
                group["lr"] = lr
        return groups

    if optimizer_config.get("classifier_lr", -1) != -1:
        # Parameters whose name contains "encoder" use the default learning rate;
        # every other trainable parameter uses "classifier_lr". Frozen
        # (requires_grad=False) parameters are skipped in this branch only.
        net_params = []
        classifier_params = []
        for k, v in model.named_parameters():
            if not v.requires_grad:
                continue
            if k.find("encoder") != -1:
                net_params.append((k, v))
            else:
                classifier_params.append((k, v))
        params = []
        params.extend(make_params(classifier_params, optimizer_config["classifier_lr"]))
        params.extend(make_params(net_params))
    else:
        params = make_params(list(model.named_parameters()))
    print("param_groups", len(params))

    opt_type = optimizer_config["type"]
    if opt_type == "SGD":
        optimizer = optim.SGD(params,
                              lr=optimizer_config["learning_rate"],
                              momentum=optimizer_config["momentum"],
                              nesterov=optimizer_config["nesterov"])
    elif opt_type == "Adam":
        optimizer = optim.Adam(params,
                               eps=optimizer_config.get("eps", 1e-8),
                               lr=optimizer_config["learning_rate"],
                               weight_decay=optimizer_config["weight_decay"])
    elif opt_type == "FusedAdam":
        optimizer = FusedAdam(params,
                              eps=optimizer_config.get("eps", 1e-8),
                              lr=optimizer_config["learning_rate"],
                              weight_decay=optimizer_config["weight_decay"])
    elif opt_type == "FusedNovoGrad":
        optimizer = FusedNovoGrad(params,
                                  eps=optimizer_config.get("eps", 1e-8),
                                  lr=optimizer_config["learning_rate"],
                                  weight_decay=optimizer_config["weight_decay"])
    elif opt_type == "AdamW":
        optimizer = AdamW(params,
                          eps=optimizer_config.get("eps", 1e-8),
                          lr=optimizer_config["learning_rate"],
                          weight_decay=optimizer_config["weight_decay"])
    elif opt_type == "RmsProp":
        optimizer = RMSprop(params,
                            lr=optimizer_config["learning_rate"],
                            weight_decay=optimizer_config["weight_decay"])
    else:
        raise KeyError("unrecognized optimizer {}".format(opt_type))

    schedule = optimizer_config["schedule"]
    schedule_type = schedule["type"]
    if schedule_type == "step":
        scheduler = LRStepScheduler(optimizer, **schedule["params"])
    elif schedule_type == "cosine":
        scheduler = CosineAnnealingLR(optimizer, **schedule["params"])
    elif schedule_type == "clr":
        scheduler = CyclicLR(optimizer, **schedule["params"])
    elif schedule_type == "multistep":
        scheduler = MultiStepLR(optimizer, **schedule["params"])
    elif schedule_type == "exponential":
        scheduler = ExponentialLRScheduler(optimizer, **schedule["params"])
    elif schedule_type == "poly":
        scheduler = PolyLR(optimizer, **schedule["params"])
    elif schedule_type == "constant":
        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
    elif schedule_type == "linear":
        def linear_lr(it):
            # LambdaLR multiplier: alpha * it + beta, read from the schedule params.
            return it * schedule["params"]["alpha"] + schedule["params"]["beta"]

        scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)
    else:
        # Previously fell through and failed later with UnboundLocalError on `scheduler`.
        raise KeyError("unrecognized schedule {}".format(schedule_type))

    return optimizer, scheduler