Example #1
0
    def _make_optimizer(optimizer_name,
                        model,
                        total_iterations,
                        decay=0,
                        amsgrad=False,
                        nesterov=False,
                        control_mode=False):
        """Instantiate the optimizer class registered under `optimizer_name`.

        Args:
            optimizer_name: One of 'AdamW', 'NadamW', 'SGDW', 'Adam',
                'Nadam', 'SGD'. Any other value raises KeyError from the
                lookup table.
            model: Passed to `get_weight_decays` when `control_mode` is
                falsy; otherwise unused.
            total_iterations: Forwarded as `total_iterations` to the
                optimizers whose name is not exactly 'Adam', 'Nadam' or
                'SGD'.
            decay: Forwarded as `decay` unless the name contains 'Nadam'.
            amsgrad: Forwarded as `amsgrad` when the name contains 'Adam'
                (case-sensitive, so 'Nadam'/'NadamW' do not match).
            nesterov: Forwarded as `nesterov` when the name contains 'SGD'.
            control_mode: When truthy, weight decays and lr multipliers are
                set to None and cosine annealing is turned off.

        Returns:
            The constructed optimizer, always created with `lr=1e-4`.
        """
        optimizer_cls = {
            'AdamW': AdamW,
            'NadamW': NadamW,
            'SGDW': SGDW,
            'Adam': Adam,
            'Nadam': Nadam,
            'SGD': SGD
        }[optimizer_name]

        # Keyword arguments that depend on the optimizer family.
        if 'Adam' in optimizer_name:
            extra_kwargs = {'amsgrad': amsgrad}
        elif 'SGD' in optimizer_name:
            extra_kwargs = {'nesterov': nesterov, 'momentum': .9}
        else:
            extra_kwargs = {}
        if 'Nadam' not in optimizer_name:
            extra_kwargs['decay'] = decay

        # The decay settings are resolved for every optimizer name, even
        # those that end up not receiving them below.
        if control_mode:
            weight_decays = None
            lr_multipliers = None
            cosine_annealing = False
        else:
            decays_template = get_weight_decays(model)
            # First three entries get dedicated values; the rest get 2e-5.
            decay_values = [1e-5, 1e-5, 1e-6]
            decay_values = decay_values + [2e-5] * (len(decays_template) - 3)
            weight_decays = fill_dict_in_order(decays_template, decay_values)
            lr_multipliers = {'gru': 0.5}
            cosine_annealing = True

        # The plain optimizers take only the learning rate and family kwargs.
        if optimizer_name in ('Adam', 'Nadam', 'SGD'):
            return optimizer_cls(lr=1e-4, **extra_kwargs)

        return optimizer_cls(lr=1e-4,
                             weight_decays=weight_decays,
                             lr_multipliers=lr_multipliers,
                             use_cosine_annealing=cosine_annealing,
                             t_cur=0,
                             total_iterations=total_iterations,
                             **extra_kwargs)
def _valid_weight_decays(model):
    """Return True when every weight-decay entry found for `model` is zero.

    Each value of the mapping returned by `get_weight_decays(model)` is
    iterated and every element it yields is compared against 0
    (presumably each value is an (l1, l2) pair -- verify against
    `get_weight_decays`). An empty mapping yields True.
    """
    for penalties in get_weight_decays(model).values():
        for penalty in penalties:
            if not penalty == 0:
                return False
    return True
Example #3
0
 def _valid_weight_decays(model):
     """Return whether no weight decay reported for `model` differs from 0.

     Counts the values of `get_weight_decays(model)` for which `!= 0`
     holds and compares that count to zero. An empty mapping gives True.
     """
     nonzero_count = sum(
         decay != 0 for decay in get_weight_decays(model).values())
     return nonzero_count == 0