示例#1
0
    def init_optimizer(self):
        parameters = OmegaConf.to_container(self.args.optimizer.parameters,
                                            resolve=True)
        parameters = {k: v for k, v in parameters.items() if v is not None}
        parameters["params"] = self.model.parameters()

        try:
            self.optimizer = getattr(torch_optimizer, self.args.optimizer.name)
        except Exception as e:
            try:
                self.optimizer = getattr(optim, self.args.optimizer.name)
            except:
                print(
                    f"This optimizer is not implemented ({self.args.optimizer.name}), go ahead and commit it"
                )
                exit()

        self.optimizer = self.optimizer(**parameters)

        if self.args.optimizer.use_SAM:
            self.optimizer = optimizers['SAM'](base_optimizer=self.optimizer,
                                               rho=self.args.optimizer.SAM_rho)

        if self.args.optimizer.use_lookahead:
            self.optimizer = torch_optimizer.Lookahead(
                self.optimizer,
                k=self.args.optimizer.lookahead_k,
                alpha=self.args.optimizer.lookahead_alpha)
示例#2
0
def build_optimizer(model, hparams):
    optimizer_type = OptimizersTypes[hparams.optimizer]
    optimizer_opts = {} if hparams.optim_options is None else hparams.optim_options

    if optimizer_type in OptimizersTypes:
        if not all(arg in optimizers_options[optimizer_type] for arg in optimizer_opts):
            raise ValueError("You tried to pass options incompatible with {} optimizer. "
                             "Check your parameters according to the description of the optimizer:\n\n{}".
                             format(optimizer_type, optimizers[optimizer_type].__doc__))

        optimizer = optimizers[optimizer_type](
            model.parameters(),
            lr=hparams.learning_rate,
            weight_decay=hparams.weight_decay,
            **optimizer_opts
        )
    else:
        raise ValueError(f"`{optimizer_type}` is not a valid optimizer type")

    if hparams.with_lookahead:
        optimizer = torch_optimizer.Lookahead(optimizer, k=5, alpha=0.5)

    return optimizer
示例#3
0
def build_lookahead(*a, **kw):
    base = optim.Yogi(*a, **kw)
    return optim.Lookahead(base)
示例#4
0
def LookaheadYogi(*a, **kw):
    base = optim.Yogi(*a, **kw)
    return optim.Lookahead(base)
示例#5
0
def build_optimizer(cfg, model):
    name_optimizer = cfg.optimizer.type
    optimizer = None

    if name_optimizer == 'A2GradExp':
        optimizer = optim.A2GradExp(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'A2GradInc':
        optimizer = optim.A2GradInc(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'A2GradUni':
        optimizer = optim.A2GradUni(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AccSGD':
        optimizer = optim.AccSGD(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AdaBelief':
        optimizer = optim.AdaBelief(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AdaBound':
        optimizer = optim.AdaBound(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AdaMod':
        optimizer = optim.AdaMod(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Adafactor':
        optimizer = optim.Adafactor(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AdamP':
        optimizer = optim.AdamP(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AggMo':
        optimizer = optim.AggMo(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Apollo':
        optimizer = optim.Apollo(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'DiffGrad':
        optimizer = optim.DiffGrad(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Lamb':
        optimizer = optim.Lamb(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Lookahead':
        yogi = optim.Yogi(model.parameters(), lr=cfg.optimizer.lr)
        optimizer = optim.Lookahead(yogi, k=5, alpha=0.5)
    elif name_optimizer == 'NovoGrad':
        optimizer = optim.NovoGrad(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'PID':
        optimizer = optim.PID(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'QHAdam':
        optimizer = optim.QHAdam(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'QHM':
        optimizer = optim.QHM(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'RAdam':
        optimizer = optim.RAdam(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Ranger':
        optimizer = optim.Ranger(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'RangerQH':
        optimizer = optim.RangerQH(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'RangerVA':
        optimizer = optim.RangerVA(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'SGDP':
        optimizer = optim.SGDP(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'SGDW':
        optimizer = optim.SGDW(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'SWATS':
        optimizer = optim.SWATS(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Shampoo':
        optimizer = optim.Shampoo(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Yogi':
        optimizer = optim.Yogi(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=cfg.optimizer.lr,
                                     weight_decay=cfg.optimizer.weight_decay)
    elif name_optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=cfg.optimizer.lr,
                                    momentum=cfg.optimizer.momentum,
                                    weight_decay=cfg.optimizer.weight_decay)
    if optimizer is None:
        raise Exception('optimizer is wrong')
    return optimizer