def build_optimizer(self, trn, epochs, adam_epsilon, weight_decay, warmup_steps, lr, encoder_lr, **kwargs):
    """Build the optimizer(s) and LR scheduler for a multi-task model.

    Parameters are grouped so that each task decoder can run at its own
    learning rate (``task.lr`` falling back to ``lr``), the shared encoder
    runs at ``encoder_lr``, and bias/LayerNorm parameters are exempted from
    weight decay. Tasks flagged ``separate_optimizer`` are excluded from the
    shared optimizer and get their own via ``task.build_optimizer``.

    Args:
        trn: Training data; only ``len(trn)`` is used (presumably batches
            per epoch — confirm against the caller).
        epochs: Number of training epochs.
        adam_epsilon: ``eps`` for AdamW.
        weight_decay: Weight decay applied to decay-eligible parameters.
        warmup_steps: If a float in (0, 1), interpreted as a fraction of the
            total training steps; if an int, used as an absolute step count.
        lr: Default learning rate for task decoders.
        encoder_lr: Learning rate for the shared encoder.
        **kwargs: Forwarded to ``task.build_optimizer`` for separate tasks.

    Returns:
        A tuple ``(encoder_optimizer, encoder_scheduler, decoder_optimizers)``
        where ``decoder_optimizers`` maps task name to its own optimizer for
        tasks that requested one.
    """
    model = self.model_
    encoder = model.encoder
    num_training_steps = len(trn) * epochs // self.config.get(
        'gradient_accumulation', 1)
    encoder_parameters = list(encoder.parameters())
    parameter_groups: List[Dict[str, Any]] = []
    decoders = model.decoders
    decoder_optimizers = dict()
    for k, task in self.tasks.items():
        decoder: torch.nn.Module = decoders[k]
        decoder_parameters = list(decoder.parameters())
        if task.separate_optimizer:
            decoder_optimizers[k] = task.build_optimizer(decoder=decoder, **kwargs)
        else:
            # NOTE(review): `or` means an explicit task.lr of 0 falls back
            # to the default lr — presumably intended; confirm.
            task_lr = task.lr or lr
            parameter_groups.append({"params": decoder_parameters, 'lr': task_lr})
    parameter_groups.append({"params": encoder_parameters, 'lr': encoder_lr})
    # Parameters matching these name fragments get zero weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    no_decay_parameters = set()
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            no_decay_parameters.add(p)
    # Split every group into a decay part (kept in place) and a no-decay
    # part, pooled per learning rate so each lr gets one no-decay group.
    no_decay_by_lr = defaultdict(list)
    for group in parameter_groups:
        _lr = group['lr']
        ps = group['params']
        group['params'] = decay_parameters = []
        group['weight_decay'] = weight_decay
        for p in ps:
            if p in no_decay_parameters:
                no_decay_by_lr[_lr].append(p)
            else:
                decay_parameters.append(p)
    for _lr, ps in no_decay_by_lr.items():
        parameter_groups.append({"params": ps, 'lr': _lr, 'weight_decay': 0.0})
    # noinspection PyTypeChecker
    encoder_optimizer = optimization.AdamW(
        parameter_groups,
        lr=lr,
        weight_decay=weight_decay,
        eps=adam_epsilon,
    )
    # Fix: the original always computed warmup as num_training_steps *
    # warmup_steps, which is wrong when warmup_steps is an absolute (int)
    # step count. Mirror the convention used by
    # build_optimizer_scheduler_with_transformer: float = fraction of
    # total steps, int = absolute steps.
    if isinstance(warmup_steps, float):
        assert 0 < warmup_steps < 1, 'warmup_steps has to fall in range (0, 1) when it is float.'
        warmup_steps = num_training_steps * warmup_steps
    encoder_scheduler = optimization.get_linear_schedule_with_warmup(
        encoder_optimizer, warmup_steps, num_training_steps)
    return encoder_optimizer, encoder_scheduler, decoder_optimizers
def build_optimizer_scheduler_with_transformer(model: torch.nn.Module,
                                               transformer: torch.nn.Module,
                                               lr: float,
                                               transformer_lr: float,
                                               num_training_steps: int,
                                               warmup_steps: Union[float, int],
                                               weight_decay: float,
                                               adam_epsilon: float,
                                               no_decay=('bias', 'LayerNorm.bias', 'LayerNorm.weight')):
    """Create an AdamW optimizer plus a linear-warmup LR scheduler.

    The optimizer is built by ``build_optimizer_for_pretrained`` so that the
    pretrained ``transformer`` sub-module runs at ``transformer_lr`` while the
    rest of ``model`` runs at ``lr``, with weight decay disabled for
    parameters whose names match ``no_decay``.

    Args:
        model: The full model to optimize.
        transformer: The pretrained transformer sub-module of ``model``.
        lr: Learning rate for non-transformer parameters.
        transformer_lr: Learning rate for the transformer parameters.
        num_training_steps: Total number of optimizer steps.
        warmup_steps: A float in (0, 1) meaning a fraction of
            ``num_training_steps``, or an int of absolute warmup steps.
        weight_decay: Weight decay for decay-eligible parameters.
        adam_epsilon: ``eps`` for AdamW.
        no_decay: Name fragments marking parameters exempt from weight decay.

    Returns:
        A ``(optimizer, scheduler)`` tuple.
    """
    # Resolve a fractional warmup specification into an absolute step count
    # before wiring up the schedule.
    if isinstance(warmup_steps, float):
        assert 0 < warmup_steps < 1, 'warmup_steps has to fall in range (0, 1) when it is float.'
        warmup_steps = num_training_steps * warmup_steps
    optimizer = build_optimizer_for_pretrained(model,
                                               transformer,
                                               lr,
                                               weight_decay,
                                               eps=adam_epsilon,
                                               transformer_lr=transformer_lr,
                                               no_decay=no_decay)
    scheduler = optimization.get_linear_schedule_with_warmup(optimizer,
                                                             warmup_steps,
                                                             num_training_steps)
    return optimizer, scheduler