from collections import defaultdict
from typing import Any, Dict, List

import torch
# ``optimization`` is assumed to be ``transformers.optimization``, which provides
# ``AdamW`` and ``get_linear_schedule_with_warmup``.
from transformers import optimization


def build_optimizer(self, trn, epochs, adam_epsilon, weight_decay, warmup_steps, lr, encoder_lr, **kwargs):
    model = self.model_
    encoder = model.encoder
    # Total optimizer steps, accounting for gradient accumulation.
    num_training_steps = len(trn) * epochs // self.config.get('gradient_accumulation', 1)
    encoder_parameters = list(encoder.parameters())
    parameter_groups: List[Dict[str, Any]] = []
    decoders = model.decoders
    decoder_optimizers = dict()
    for k, task in self.tasks.items():
        decoder: torch.nn.Module = decoders[k]
        decoder_parameters = list(decoder.parameters())
        if task.separate_optimizer:
            # This task manages its own optimizer; keep its parameters out of the shared one.
            decoder_optimizers[k] = task.build_optimizer(decoder=decoder, **kwargs)
        else:
            task_lr = task.lr or lr
            parameter_groups.append({"params": decoder_parameters, 'lr': task_lr})
    parameter_groups.append({"params": encoder_parameters, 'lr': encoder_lr})
    # Parameters whose names match these patterns are exempt from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    no_decay_parameters = set()
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            no_decay_parameters.add(p)
    # Split every group into decayed and non-decayed parameters while preserving
    # each group's learning rate.
    no_decay_by_lr = defaultdict(list)
    for group in parameter_groups:
        _lr = group['lr']
        ps = group['params']
        group['params'] = decay_parameters = []
        group['weight_decay'] = weight_decay
        for p in ps:
            if p in no_decay_parameters:
                no_decay_by_lr[_lr].append(p)
            else:
                decay_parameters.append(p)
    for _lr, ps in no_decay_by_lr.items():
        parameter_groups.append({"params": ps, 'lr': _lr, 'weight_decay': 0.0})
    # noinspection PyTypeChecker
    encoder_optimizer = optimization.AdamW(
        parameter_groups,
        lr=lr,
        weight_decay=weight_decay,
        eps=adam_epsilon)
    # ``warmup_steps`` is interpreted as a fraction of the total training steps.
    encoder_scheduler = optimization.get_linear_schedule_with_warmup(
        encoder_optimizer,
        num_training_steps * warmup_steps,
        num_training_steps)
    return encoder_optimizer, encoder_scheduler, decoder_optimizers
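
# --- Illustrative sketch (not part of the original module) ------------------
# The scheduler call above passes ``num_training_steps * warmup_steps`` as the
# warmup argument, so ``warmup_steps`` acts as a fraction of the total steps.
# The toy function below, assuming ``optimization`` exposes ``AdamW`` and
# ``get_linear_schedule_with_warmup`` as in ``transformers.optimization``,
# traces the resulting learning-rate schedule on a dummy parameter; all names
# in it are hypothetical.
def _demo_linear_warmup(total_steps: int = 10, warmup_ratio: float = 0.1):
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = optimization.AdamW([param], lr=1e-3)
    scheduler = optimization.get_linear_schedule_with_warmup(
        optimizer, int(total_steps * warmup_ratio), total_steps)
    lrs = []
    for _ in range(total_steps):
        lrs.append(scheduler.get_last_lr()[0])  # learning rate used at this step
        optimizer.step()
        scheduler.step()
    # Ramps up linearly during warmup, then decays linearly towards 0.
    return lrs
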
def build_optimizer_for_pretrained(model: torch.nn.Module,
                                   pretrained: torch.nn.Module,
                                   lr=1e-5,
                                   weight_decay=0.01,
                                   eps=1e-8,
                                   transformer_lr=None,
                                   transformer_weight_decay=None,
                                   no_decay=('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
                                   **kwargs):
    # Fall back to the global hyperparameters when transformer-specific ones are not given.
    if transformer_lr is None:
        transformer_lr = lr
    if transformer_weight_decay is None:
        transformer_weight_decay = weight_decay
    # params[pretrained / non_pretrained][decay / no_decay] -> list of parameters
    params = defaultdict(lambda: defaultdict(list))
    pretrained = set(pretrained.parameters())
    if isinstance(no_decay, tuple):
        def no_decay_fn(name):
            return any(nd in name for nd in no_decay)
    else:
        assert callable(no_decay), 'no_decay has to be callable or a tuple of str'
        no_decay_fn = no_decay
    for n, p in model.named_parameters():
        is_pretrained = 'pretrained' if p in pretrained else 'non_pretrained'
        is_no_decay = 'no_decay' if no_decay_fn(n) else 'decay'
        params[is_pretrained][is_no_decay].append(p)
    grouped_parameters = [
        {'params': params['pretrained']['decay'], 'weight_decay': transformer_weight_decay, 'lr': transformer_lr},
        {'params': params['pretrained']['no_decay'], 'weight_decay': 0.0, 'lr': transformer_lr},
        {'params': params['non_pretrained']['decay'], 'weight_decay': weight_decay, 'lr': lr},
        {'params': params['non_pretrained']['no_decay'], 'weight_decay': 0.0, 'lr': lr},
    ]
    return optimization.AdamW(
        grouped_parameters,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        **kwargs)
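
# --- Usage sketch (hypothetical, for illustration only) ---------------------
# ``build_optimizer_for_pretrained`` takes the full ``model`` together with the
# ``pretrained`` sub-module whose parameters should receive the (typically
# smaller) ``transformer_lr``. The toy model below stands in for a transformer
# backbone plus a freshly initialized task head; ``ToyModel`` and the chosen
# learning rates are illustrative, not part of the original API.
def _demo_build_optimizer_for_pretrained():
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = torch.nn.Linear(8, 8)     # stand-in for a pretrained backbone
            self.classifier = torch.nn.Linear(8, 2)  # randomly initialized head

    model = ToyModel()
    optimizer = build_optimizer_for_pretrained(
        model,
        pretrained=model.encoder,
        lr=1e-3,              # higher learning rate for the fresh head
        transformer_lr=2e-5)  # smaller learning rate for the backbone
    # Four groups come back: pretrained / non-pretrained x decay / no-decay.
    return [(g['lr'], g['weight_decay'], len(g['params'])) for g in optimizer.param_groups]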