Esempio n. 1
0
    def build_optimizer(self, trn, epochs, adam_epsilon, weight_decay,
                        warmup_steps, lr, encoder_lr, **kwargs):
        """Build the shared AdamW optimizer, its LR schedule and per-task optimizers.

        Decoders of tasks whose ``separate_optimizer`` is truthy get their own
        optimizer from ``task.build_optimizer``. All other decoders and the
        encoder are optimized jointly by one AdamW instance. Any parameter of
        ``self.model_`` whose name contains ``bias``, ``LayerNorm.bias`` or
        ``LayerNorm.weight`` is moved out of its group into an extra group
        (one per distinct learning rate) with ``weight_decay=0.0``.

        Args:
            trn: Training data. Only ``len(trn)`` is used, to count scheduler
                steps (presumably number of batches per epoch; confirm with
                the caller).
            epochs: Number of training epochs.
            adam_epsilon: ``eps`` passed to AdamW.
            weight_decay: Weight decay for parameters not exempted by name.
            warmup_steps: Multiplied by the total step count to get the
                warmup length, i.e. used as a fraction of training.
            lr: Fallback learning rate for tasks whose ``lr`` is falsy; also
                the optimizer-level default ``lr``.
            encoder_lr: Learning rate of the encoder's parameters.
            **kwargs: Forwarded to ``task.build_optimizer`` for tasks with a
                separate optimizer.

        Returns:
            ``(encoder_optimizer, encoder_scheduler, decoder_optimizers)``,
            where ``decoder_optimizers`` maps each key of ``self.tasks`` that
            has a separate optimizer to whatever its ``build_optimizer``
            returned.
        """
        net = self.model_
        encoder_params = list(net.encoder.parameters())
        # Total scheduler steps: len(trn) * epochs, divided by the
        # gradient-accumulation factor from config (default 1).
        steps_per_run = len(trn) * epochs
        total_steps = steps_per_run // self.config.get(
            'gradient_accumulation', 1)

        groups: List[Dict[str, Any]] = []
        decoder_optimizers = {}
        decoders = net.decoders
        for key, task in self.tasks.items():
            decoder: torch.nn.Module = decoders[key]
            if task.separate_optimizer:
                # This decoder is kept out of the shared optimizer entirely.
                decoder_optimizers[key] = task.build_optimizer(
                    decoder=decoder, **kwargs)
                continue
            groups.append({'params': list(decoder.parameters()),
                           'lr': task.lr or lr})
        groups.append({'params': encoder_params, 'lr': encoder_lr})

        # Parameters whose name contains any of these substrings get no decay.
        exempt_markers = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
        exempt = {param for pname, param in net.named_parameters()
                  if any(marker in pname for marker in exempt_markers)}

        # Split every group: decayed params stay in place; exempt ones are
        # pooled per learning rate and re-added below with weight_decay=0.0.
        exempt_by_lr = defaultdict(list)
        for group in groups:
            rate = group['lr']
            kept = []
            for param in group['params']:
                if param in exempt:
                    exempt_by_lr[rate].append(param)
                else:
                    kept.append(param)
            group['params'] = kept
            group['weight_decay'] = weight_decay
        groups.extend({'params': pooled, 'lr': pooled_lr, 'weight_decay': 0.0}
                      for pooled_lr, pooled in exempt_by_lr.items())

        # noinspection PyTypeChecker
        encoder_optimizer = optimization.AdamW(
            groups, lr=lr, weight_decay=weight_decay, eps=adam_epsilon)
        encoder_scheduler = optimization.get_linear_schedule_with_warmup(
            encoder_optimizer, total_steps * warmup_steps, total_steps)
        return encoder_optimizer, encoder_scheduler, decoder_optimizers
Esempio n. 2
0
def build_optimizer_for_pretrained(model: torch.nn.Module,
                                   pretrained: torch.nn.Module,
                                   lr=1e-5,
                                   weight_decay=0.01,
                                   eps=1e-8,
                                   transformer_lr=None,
                                   transformer_weight_decay=None,
                                   no_decay=('bias', 'LayerNorm.bias',
                                             'LayerNorm.weight'),
                                   **kwargs):
    """Build an AdamW optimizer with separate settings for pretrained weights.

    Every parameter of ``model`` is placed in one of four parameter groups,
    depending on whether it is also a parameter of ``pretrained`` and whether
    ``no_decay`` exempts its name from weight decay:

    - pretrained, decayed: ``transformer_weight_decay``, ``transformer_lr``
    - pretrained, exempt: weight decay ``0.0``, ``transformer_lr``
    - other, decayed: ``weight_decay``, ``lr``
    - other, exempt: weight decay ``0.0``, ``lr``

    All four groups are always passed to the optimizer, even when empty.

    Args:
        model: Module whose parameters are optimized.
        pretrained: Module whose parameters get the ``transformer_*``
            settings (presumably a submodule of ``model``). Membership is
            decided by parameter identity.
        lr: Learning rate for non-pretrained parameters; also the optimizer's
            default ``lr``.
        weight_decay: Weight decay for non-pretrained parameters; also the
            optimizer's default ``weight_decay``.
        eps: ``eps`` passed to AdamW.
        transformer_lr: Learning rate for pretrained parameters; ``None``
            means use ``lr``.
        transformer_weight_decay: Weight decay for pretrained parameters;
            ``None`` means use ``weight_decay``.
        no_decay: Either a callable ``name -> bool`` that returns a truthy
            value for parameters that must not be decayed, or substring
            pattern(s): a single ``str`` or any iterable of ``str`` (tuple,
            list, set, ...). A parameter is exempt when its name contains any
            of the patterns.
        **kwargs: Forwarded to ``optimization.AdamW``.

    Returns:
        The constructed ``optimization.AdamW`` instance.

    Raises:
        TypeError: If ``no_decay`` is neither callable nor a string or an
            iterable of strings.
    """
    if transformer_lr is None:
        transformer_lr = lr
    if transformer_weight_decay is None:
        transformer_weight_decay = weight_decay

    if isinstance(no_decay, str):
        # Wrap a bare pattern; iterating a str would yield single characters.
        no_decay = (no_decay,)
    if not isinstance(no_decay, tuple) and callable(no_decay):
        no_decay_fn = no_decay
    else:
        try:
            patterns = tuple(no_decay)
        except TypeError:
            patterns = None
        # Explicit raise rather than `assert`, so validation survives `-O`.
        if patterns is None or not all(isinstance(nd, str) for nd in patterns):
            raise TypeError('no_decay has to be callable, a str, or an '
                            'iterable of str')

        def no_decay_fn(name):
            return any(nd in name for nd in patterns)

    # Identity-based lookup of the pretrained module's parameters.
    pretrained_params = set(pretrained.parameters())
    params = defaultdict(lambda: defaultdict(list))
    for n, p in model.named_parameters():
        is_pretrained = 'pretrained' if p in pretrained_params else 'non_pretrained'
        is_no_decay = 'no_decay' if no_decay_fn(n) else 'decay'
        params[is_pretrained][is_no_decay].append(p)

    grouped_parameters = [
        {
            'params': params['pretrained']['decay'],
            'weight_decay': transformer_weight_decay,
            'lr': transformer_lr
        },
        {
            'params': params['pretrained']['no_decay'],
            'weight_decay': 0.0,
            'lr': transformer_lr
        },
        {
            'params': params['non_pretrained']['decay'],
            'weight_decay': weight_decay,
            'lr': lr
        },
        {
            'params': params['non_pretrained']['no_decay'],
            'weight_decay': 0.0,
            'lr': lr
        },
    ]

    return optimization.AdamW(grouped_parameters,
                              lr=lr,
                              weight_decay=weight_decay,
                              eps=eps,
                              **kwargs)