Example #1
def add_dec_adapters(dec_model: TransformerDecoder,
                     config: AdapterConfig) -> TransformerDecoder:
    """Wraps each decoder layer with an adapter and freezes every parameter
    except those of the adapters and the layer norms."""

    # Replace each decoder layer with its adapter-augmented counterpart
    for i in range(len(dec_model.transformer_layers)):
        dec_model.transformer_layers[i] = adapt_transformer_output(config)(
            dec_model.transformer_layers[i])

    # Freeze all parameters
    for param in dec_model.parameters():
        param.requires_grad = False

    # Unfreeze trainable parts — layer norms and adapters
    for name, sub_module in dec_model.named_modules():
        if isinstance(sub_module, (Adapter_func, nn.LayerNorm)):
            for param in sub_module.parameters():
                param.requires_grad = True
    return dec_model
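
A minimal usage sketch for add_dec_adapters (not part of the original example): it assumes TransformerDecoder, AdapterConfig, and add_dec_adapters are importable from this project, and the numeric values and activation name below are placeholders that mirror the argument names used in Example #2.

import torch.nn as nn

# Build a small decoder (constructor arguments mirror Example #2)
tgt_embeddings = nn.Embedding(30522, 768, padding_idx=0)
decoder = TransformerDecoder(6, 768, heads=8, d_ff=2048,
                             dropout=0.1, embeddings=tgt_embeddings)

# Adapter configuration (fields mirror the AdapterConfig usage in Example #2)
config = AdapterConfig(hidden_size=768,
                       adapter_size=64,
                       adapter_act='relu',
                       adapter_initializer_range=0.02)
decoder = add_dec_adapters(decoder, config)

# After the call, only adapter and LayerNorm parameters should require grad
trainable = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
total = sum(p.numel() for p in decoder.parameters())
print(f"{trainable} / {total} decoder parameters are trainable")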
Example #2
class MTLAbsSummarizer(nn.Module):
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 bert_from_extractive=None):
        super(MTLAbsSummarizer, self).__init__()
        self.args = args
        self.device = device

        # Initialize the BERT encoder
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        # Load checkpoint weights from an extractive model: keep only the
        # 'bert.model.*' entries and strip that 11-character prefix
        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_from_extractive.items()
                      if n.startswith('bert.model')]),
                strict=True)

        # Baseline encoder: a freshly configured, non-pretrained BertModel
        if args.encoder == 'baseline':
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        # The original BERT has 512 positional embeddings; for max_pos > 512,
        # extend them by repeating the last position embedding
        if (args.max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[
                    -1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        # Initialize the Transformer decoder
        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)

        # Initialize the generator and tie its weights to the decoder embeddings
        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        # Insert adapter modules
        if args.enc_adapter:
            enc_hidden_size = \
                self.bert.model.embeddings.word_embeddings.weight.shape[1]
            config = AdapterConfig(
                hidden_size=enc_hidden_size,
                adapter_size=args.adapter_size,
                adapter_act=args.adapter_act,
                adapter_initializer_range=args.adapter_initializer_range)
            self.bert.model = add_enc_adapters(self.bert.model, config)
            self.bert.model = add_layer_norm(self.bert.model,
                                             d_model=enc_hidden_size,
                                             eps=args.layer_norm_eps)
        if args.dec_adapter:
            config = AdapterConfig(
                hidden_size=args.dec_hidden_size,
                adapter_size=args.adapter_size,
                adapter_act=args.adapter_act,
                adapter_initializer_range=args.adapter_initializer_range)
            self.decoder = add_dec_adapters(self.decoder, config)
            self.decoder = add_layer_norm(self.decoder,
                                          d_model=args.dec_hidden_size,
                                          eps=args.layer_norm_eps)

            self.generator[0].weight.requires_grad = False
            self.generator[0].bias.requires_grad = False

        # Load checkpoint (the helpers below adapt key names when needed)
        def modify_ckpt_for_enc_adapter(ckpt):
            """Renames keys of a no-adapter ckpt to match the adapter-equipped
            encoder: every 'output' segment in a key becomes
            'output.self_output'."""
            keys_need_modified_enc = []
            for k in list(ckpt['model'].keys()):
                if 'output' in k:
                    keys_need_modified_enc.append(k)
            for mk in keys_need_modified_enc:
                ckpt['model'] = OrderedDict([
                    (mk.replace('output', 'output.self_output'),
                     v) if k == mk else (k, v)
                    for k, v in ckpt['model'].items()
                ])

        def modify_ckpt_for_dec_adapter(ckpt):
            """Renames keys of a no-adapter ckpt to match the adapter-equipped
            decoder: '.dec_layer' is inserted right after the (single-digit)
            decoder-layer index."""
            keys_need_modified_dec = []
            for k in list(ckpt['model'].keys()):
                if 'layers' in k:
                    keys_need_modified_dec.append(k)
            for mk in keys_need_modified_dec:
                p = mk.find('layers.')
                # The +8 offset keeps 'layers.<digit>' intact and inserts
                # '.dec_layer' after it, e.g. '...transformer_layers.0.<rest>'
                # becomes '...transformer_layers.0.dec_layer.<rest>'
                new_k = mk[:p + 8] + '.dec_layer' + mk[p + 8:]
                ckpt['model'] = OrderedDict([(new_k, v) if k == mk else (k, v)
                                             for k, v in ckpt['model'].items()])

        def identify_unmatched_keys(ckpt1, ckpt2):
            """Report the unmatched keys in ckpt1 for loading ckpt2 to ckpt1. (debug use) """
            fp = open("unmatched_keys.txt", 'w')
            num = 0
            ckpt1_keys = list(ckpt1.keys())
            ckpt2_keys = list(ckpt2.keys())
            for k in ckpt1_keys:
                # NOTE: 'var' and 'feed_forward' keys are skipped because they
                # reuse weights shared with other modules
                if k not in ckpt2_keys and "var" not in k and "feed_forward" not in k:
                    fp.write(k + '\n')
                    print(k)
                    num += 1
            print("# of Unmatched Keys: {}".format(num))
            fp.close()

        if checkpoint is not None:
            if (self.args.enc_adapter and self.args.ckpt_from_no_adapter):
                modify_ckpt_for_enc_adapter(checkpoint)
            if (self.args.dec_adapter and self.args.ckpt_from_no_adapter):
                modify_ckpt_for_dec_adapter(checkpoint)

            # NOTE: load non-strictly, since adapter parameters may be
            # absent from the checkpoint
            #identify_unmatched_keys(self.state_dict(), checkpoint['model']) # DEBUG
            self.load_state_dict(checkpoint['model'], strict=False)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.model.config.hidden_size,
                    padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        """Forward process.

        Args:
            src (tensor(batch, max_src_len_batch)):
                Source token ids.
            tgt (tensor(batch, max_tgt_len_batch)):
                Target token ids.
            segs (tensor(batch, max_src_len_batch)):
                Segment ids (0 or 1) separating source sentences.
            clss (tensor(batch, max_cls_num_batch)):
                Positions of the [CLS] tokens.
            mask_src (tensor(batch, max_src_len_batch)):
                Mask (0 or 1) for source padding tokens.
            mask_tgt (tensor(batch, max_tgt_len_batch)):
                Mask (0 or 1) for target padding tokens.
            mask_cls (tensor(batch, max_cls_num_batch)):
                Mask (0 or 1) for [CLS] positions.

        Returns:
            A tuple of variables:
                decoder_outputs (tensor(batch, max_tgt_len_batch, dec_hidden_dim)):
                    The hidden states from decoder.
                top_vec (tensor(batch, max_src_len_batch, enc_hidden_dim)):
                    The hidden states from encoder.
        """
        # top_vec -> tensor(batch, max_src_len_batch, enc_hidden_dim)
        top_vec = self.bert(src, segs, mask_src)

        # dec_state -> models.decoder.TransformerDecoderState
        dec_state = self.decoder.init_decoder_state(src, top_vec)

        # decoder_outputs -> tensor(batch, max_tgt_len_batch, dec_hidden_dim)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)

        return decoder_outputs, top_vec

    # [For Inner Loop]
    def _cascade_fast_weights_grad(self, fast_weights):
        """Sets fast-weight mode for adapter and layer norm modules. """
        offset = 0
        for name, sub_module in self.named_modules():
            if (isinstance(sub_module, (Adapter_func, LayerNorm_func))
                    and sub_module.trainable):
                param_num = len(sub_module._parameters)
                setattr(sub_module, 'fast_weights_flag', True)
                delattr(sub_module, 'fast_weights')
                setattr(sub_module, 'fast_weights',
                        fast_weights[offset:offset + param_num])
                offset += param_num
        return offset

    # [For Outer Loop]
    def _clean_fast_weights_mode(self):
        """Cleans fast-weight mode for adapter and layer norm modules. """
        module_num = 0
        for name, sub_module in self.named_modules():
            if (isinstance(sub_module, (Adapter_func, LayerNorm_func))
                    and sub_module.trainable):
                setattr(sub_module, 'fast_weights_flag', False)
                delattr(sub_module, 'fast_weights')
                setattr(sub_module, 'fast_weights', None)
                module_num += 1
        return module_num

    def _adapter_fast_weights(self):
        """Returns fast (task) weights from full model. """
        for name, sub_module in self.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.fast_weights:
                    yield param

    def _adapter_fast_weights_bert(self):
        """Returns fast (task) weights from encoder. """
        for name, sub_module in self.bert.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.fast_weights:
                    yield param

    def _adapter_fast_weights_dec(self):
        """Returns fast (task) weights from decoder. """
        for name, sub_module in self.decoder.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.fast_weights:
                    yield param

    def _adapter_vars(self):
        """Returns true (meta) parameters from full model. """
        for name, sub_module in self.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.vars:
                    yield param

    def _adapter_vars_bert(self):
        """Returns true (meta) parameters from encoder. """
        for name, sub_module in self.bert.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.vars:
                    yield param

    def _adapter_vars_dec(self):
        """Returns true (meta) parameters from decoder. """
        for name, sub_module in self.decoder.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.vars:
                    yield param
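
The fast-weight helpers above are intended for a MAML-style inner/outer loop. Below is a hypothetical sketch of how they might be wired together, using only methods defined in this example; compute_loss, the support/query batches, and inner_lr are placeholders, and it assumes _adapter_vars() yields parameters in the same order that _cascade_fast_weights_grad() consumes them.

import torch

def meta_step(model, support, query, compute_loss, inner_lr=1e-3):
    # Inner loop: adapt fast weights on the support batch
    meta_params = list(model._adapter_vars())
    dec_out, _ = model(*support)
    inner_loss = compute_loss(dec_out)
    grads = torch.autograd.grad(inner_loss, meta_params, create_graph=True)
    fast_weights = [p - inner_lr * g for p, g in zip(meta_params, grads)]
    model._cascade_fast_weights_grad(fast_weights)

    # Outer loop: evaluate the adapted model on the query batch;
    # backpropagating this loss updates the true (meta) parameters
    # through the fast weights
    dec_out, _ = model(*query)
    outer_loss = compute_loss(dec_out)

    # Leave fast-weight mode before the next task
    model._clean_fast_weights_mode()
    return outer_loss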