def _make_optimizer(optimizer_name, model, total_iterations, decay=0,
                    amsgrad=False, nesterov=False, control_mode=False):
    """Instantiate one of six optimizers by name with a fixed test config.

    Args:
        optimizer_name: one of 'AdamW', 'NadamW', 'SGDW', 'Adam', 'Nadam',
            'SGD'; any other name raises KeyError.
        model: model whose weight decays are read via ``get_weight_decays``
            (only when ``control_mode`` is False).
        total_iterations: forwarded to the W-variant optimizers only.
        decay: learning-rate ``decay`` kwarg; not passed to Nadam variants.
        amsgrad: forwarded to names containing 'Adam' (i.e. 'Adam', 'AdamW';
            the check is case-sensitive, so Nadam variants do not get it).
        nesterov: forwarded to names containing 'SGD'.
        control_mode: when True, the W-variant optimizers receive no weight
            decays or lr multipliers and cosine annealing is disabled.

    Returns:
        The constructed optimizer, always with ``lr=1e-4``.
    """
    optimizer_classes = {
        'AdamW': AdamW, 'NadamW': NadamW, 'SGDW': SGDW,
        'Adam': Adam, 'Nadam': Nadam, 'SGD': SGD,
    }
    optimizer_cls = optimizer_classes[optimizer_name]

    # Per-family keyword arguments; note 'Nadam' does not contain 'Adam'.
    if 'Adam' in optimizer_name:
        extra_kwargs = {'amsgrad': amsgrad}
    elif 'SGD' in optimizer_name:
        extra_kwargs = {'nesterov': nesterov, 'momentum': .9}
    else:
        extra_kwargs = {}
    if 'Nadam' not in optimizer_name:
        extra_kwargs['decay'] = decay

    if control_mode:
        wd, lr_m = None, None
        use_cosine_annealing = False
    else:
        wd_dict = get_weight_decays(model)
        # Presumably assigns values to decay entries in order: the first
        # three get 1e-5, 1e-5, 1e-6 and the rest 2e-5 -- confirm against
        # fill_dict_in_order.
        l2_extra = [2e-5] * (len(wd_dict) - 3)
        wd = fill_dict_in_order(wd_dict, [1e-5, 1e-5, 1e-6] + l2_extra)
        lr_m = {'gru': 0.5}
        use_cosine_annealing = True

    # The plain (non-W) optimizers take none of the decoupled-decay options.
    if optimizer_name in ('Adam', 'Nadam', 'SGD'):
        return optimizer_cls(lr=1e-4, **extra_kwargs)
    return optimizer_cls(lr=1e-4, weight_decays=wd, lr_multipliers=lr_m,
                         use_cosine_annealing=use_cosine_annealing, t_cur=0,
                         total_iterations=total_iterations, **extra_kwargs)
def _valid_weight_decays(model):
    """Return True if every weight decay reported for ``model`` is zero.

    ``get_weight_decays(model)`` is expected to return a mapping whose
    values are either an iterable of penalty components (e.g. an
    ``(l1, l2)`` pair) or a single scalar penalty. Both forms are accepted;
    previously a scalar value raised ``TypeError`` here.

    NOTE(review): this module defines ``_valid_weight_decays`` a second time
    further down; the later definition shadows this one at import time.

    Args:
        model: the model passed to ``get_weight_decays``.

    Returns:
        bool: True when all decay components are zero, or when no decays
        are reported.
    """
    def _components(penalty):
        # Iterables (tuples, 1-d arrays) are expanded; scalars, including
        # 0-d arrays, which are not iterable, are wrapped as a single item.
        try:
            return tuple(penalty)
        except TypeError:
            return (penalty,)

    weight_decays = get_weight_decays(model)
    return all(x == 0
               for penalty in weight_decays.values()
               for x in _components(penalty))
def _valid_weight_decays(model):
    """Return True if every weight decay reported for ``model`` is zero.

    Values in the mapping returned by ``get_weight_decays(model)`` may be a
    single scalar penalty or an iterable of penalty components (e.g. an
    ``(l1, l2)`` pair). Each component is compared to zero individually.
    The previous ``wd != 0`` test was always True for a tuple such as
    ``(0, 0)``, so any reported entry made the check fail.

    Args:
        model: the model passed to ``get_weight_decays``.

    Returns:
        bool: True when all decay components are zero, or when no decays
        are reported.
    """
    weight_decays = get_weight_decays(model)
    for penalty in weight_decays.values():
        try:
            components = tuple(penalty)
        except TypeError:
            # Scalar penalty (e.g. a float or 0-d array): check it directly.
            components = (penalty,)
        if any(x != 0 for x in components):
            return False
    return True