Example #1
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_from_extractive.items() if n.startswith('bert.model')]), strict=True)

        if (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.enc_hidden_size,
                                     num_hidden_layers=args.enc_layers, num_attention_heads=8,
                                     intermediate_size=args.enc_ff_size,
                                     hidden_dropout_prob=args.enc_dropout,
                                     attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if(args.max_pos>512):
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
        if (self.args.share_emb):
            tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size, heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings)

        self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight


        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if(args.use_bert_emb):
                tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)
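
Note: get_generator is not shown in these snippets. In the PreSumm-style codebases these examples come from, it is typically a small nn.Sequential of a vocabulary projection followed by a log-softmax, which is why self.generator[0].weight can be tied to the decoder embedding matrix above. A minimal sketch, assuming that structure rather than quoting the project's actual code:

import torch.nn as nn

def get_generator(vocab_size, dec_hidden_size, device):
    # generator[0] is the Linear layer whose weight can be tied to the
    # target embeddings; LogSoftmax turns the logits into log-probabilities.
    generator = nn.Sequential(
        nn.Linear(dec_hidden_size, vocab_size),
        nn.LogSoftmax(dim=-1),
    )
    generator.to(device)
    return generator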
Example #2
class BertSummarizer(nn.Module):
    def __init__(self, checkpoint, device, temp_dir='/temp'):
        super(BertSummarizer, self).__init__()
        self.device = device
        self.bert = Bert(False, temp_dir, True)

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)

        self.decoder = TransformerDecoder(6,
                                          768,
                                          heads=8,
                                          d_ff=2048,
                                          dropout=0.2,
                                          embeddings=tgt_embeddings)
        self.generator = get_generator(self.vocab_size, 768, self.device)
        self.generator[0].weight = self.decoder.embeddings.weight

        self.load_state_dict(checkpoint['model'], strict=True)
        self.to(self.device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        top_vec = self.bert(src, segs, mask_src)
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        return decoder_outputs, None
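
A rough usage sketch for this class, assuming the checkpoint was saved as a dict with a 'model' state_dict; the path and dummy tensors are illustrative only, and clss/mask_cls are unused by this forward:

import torch

checkpoint = torch.load('bertsumabs_checkpoint.pt', map_location='cpu')  # hypothetical path
model = BertSummarizer(checkpoint, device=torch.device('cpu'))
batch_size, src_len, tgt_len = 2, 64, 16
src = torch.randint(1, model.vocab_size, (batch_size, src_len))
tgt = torch.randint(1, model.vocab_size, (batch_size, tgt_len))
segs = torch.zeros_like(src)   # single-segment input
mask_src = (src != 0)          # 1 for real tokens, 0 for padding
mask_tgt = (tgt != 0)
decoder_outputs, _ = model(src, tgt, segs, None, mask_src, mask_tgt, None)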
Example #3
def add_dec_adapters(dec_model: TransformerDecoder,
                     config: AdapterConfig) -> TransformerDecoder:

    # Replace specific layer with adapter-added layer
    for i in range(len(dec_model.transformer_layers)):
        dec_model.transformer_layers[i] = adapt_transformer_output(config)(
            dec_model.transformer_layers[i])

    # Freeze all parameters
    for param in dec_model.parameters():
        param.requires_grad = False

    # Unfreeze trainable parts — layer norms and adapters
    for name, sub_module in dec_model.named_modules():
        if isinstance(sub_module, (Adapter_func, nn.LayerNorm)):
            for param_name, param in sub_module.named_parameters():
                param.requires_grad = True
    return dec_model
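
Adapter_func and adapt_transformer_output are defined elsewhere in this codebase and are not shown here. As a rough sketch of the kind of module this function unfreezes, a standard bottleneck adapter (down-projection, nonlinearity, up-projection, residual connection) might look like the following; the class name and layer choices are illustrative assumptions, not the project's actual implementation:

import torch.nn as nn

class BottleneckAdapter(nn.Module):
    """Hypothetical stand-in for Adapter_func: a small bottleneck MLP
    applied to a sub-layer output, with a residual connection."""

    def __init__(self, hidden_size, adapter_size):
        super().__init__()
        self.down_project = nn.Linear(hidden_size, adapter_size)
        self.activation = nn.GELU()
        self.up_project = nn.Linear(adapter_size, hidden_size)

    def forward(self, hidden_states):
        # The residual path keeps the original representation intact when the
        # adapter is near-identity at initialization.
        return hidden_states + self.up_project(self.activation(self.down_project(hidden_states)))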
Example #4
    def __init__(self, checkpoint, device, temp_dir='/temp'):
        super(BertSummarizer, self).__init__()
        self.device = device
        self.bert = Bert(False, temp_dir, True)

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)

        self.decoder = TransformerDecoder(6,
                                          768,
                                          heads=8,
                                          d_ff=2048,
                                          dropout=0.2,
                                          embeddings=tgt_embeddings)
        self.generator = get_generator(self.vocab_size, 768, self.device)
        self.generator[0].weight = self.decoder.embeddings.weight

        self.load_state_dict(checkpoint['model'], strict=True)
        self.to(self.device)
Example #5
 def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers, use_doc=False, aggr='last'):
     super(TransformerDecoderSeq, self).__init__()
     self.aggr = aggr
     self.decoder = TransformerDecoder(d_model, d_ff, heads, dropout, num_inter_layers)
     self.dropout = nn.Dropout(dropout)
     self.use_doc = use_doc
     if self.use_doc:
         self.linear_doc1 = nn.Linear(2 * d_model, d_model)
         self.linear_doc2 = nn.Linear(d_model, 1)
         self.bilinear = nn.Bilinear(d_model, d_model, 1)
     self.linear_sent1 = nn.Linear(2 * d_model, d_model)
     self.linear_sent2 = nn.Linear(d_model, 1)
     self.linear = nn.Linear(2, 1)
     self.start_emb = torch.nn.Parameter(torch.rand(1, d_model))
Example #6
    def __init__(self,
                 max_length,
                 enc_vocab,
                 dec_vocab,
                 enc_emb_size,
                 dec_emb_size,
                 enc_units,
                 dec_units,
                 dropout_rate=0.1):
        super(Transformer, self).__init__()
        enc_vocab_size = len(enc_vocab.itos)
        dec_vocab_size = len(dec_vocab.itos)

        self.encoder_embedding = nn.Sequential(
            TransformerEmbedding(vocab_size=enc_vocab_size,
                                 padding_idx=enc_vocab.stoi["<pad>"],
                                 max_length=max_length,
                                 embedding_size=enc_emb_size),
            nn.Dropout(p=dropout_rate))
        self.decoder_embedding = nn.Sequential(
            TransformerEmbedding(vocab_size=dec_vocab_size,
                                 padding_idx=dec_vocab.stoi["<pad>"],
                                 max_length=max_length,
                                 embedding_size=dec_emb_size),
            nn.Dropout(p=dropout_rate))

        self.encoder = nn.Sequential(
            TransformerEncoder(enc_emb_size, enc_units),
            nn.Dropout(p=dropout_rate))
        self.decoder = TransformerDecoder(dec_emb_size, enc_emb_size,
                                          dec_units)
        self.decoder_drop = nn.Dropout(p=dropout_rate)

        self.output_layer = nn.Linear(in_features=enc_units[-1],
                                      out_features=dec_vocab_size)
        self.softmax = nn.Softmax(dim=-1)
Example #7
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 bert_from_extractive=None):
        super(MTLAbsSummarizer, self).__init__()
        self.args = args
        self.device = device

        # Initialize BERT
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        # Load ckpt from extractive model
        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(dict([
                (n[11:], p) for n, p in bert_from_extractive.items()
                if n.startswith('bert.model')
            ]),
                                            strict=True)

        # Baseline encoder: randomly initialized BERT configuration
        if args.encoder == 'baseline':
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        # The original BERT has only 512 positional embeddings; repeat the last one for positions beyond 512
        if (args.max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                    self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                    self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        # Initialize the Transformer decoder
        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)

        # Initialize the generator
        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        # Insert Adapter modules
        if (args.enc_adapter):
            enc_hidden_size = self.bert.model.embeddings.word_embeddings.weight.shape[1]
            config = AdapterConfig(
                hidden_size=enc_hidden_size,
                adapter_size=args.adapter_size,
                adapter_act=args.adapter_act,
                adapter_initializer_range=args.adapter_initializer_range)
            self.bert.model = add_enc_adapters(self.bert.model, config)
            self.bert.model = add_layer_norm(self.bert.model,
                                             d_model=enc_hidden_size,
                                             eps=args.layer_norm_eps)
        if (args.dec_adapter):
            config = AdapterConfig(
                hidden_size=args.dec_hidden_size,
                adapter_size=args.adapter_size,
                adapter_act=args.adapter_act,
                adapter_initializer_range=args.adapter_initializer_range)
            self.decoder = add_dec_adapters(self.decoder, config)
            self.decoder = add_layer_norm(self.decoder,
                                          d_model=args.dec_hidden_size,
                                          eps=args.layer_norm_eps)

            self.generator[0].weight.requires_grad = False
            self.generator[0].bias.requires_grad = False

        # Load ckpt
        def modify_ckpt_for_enc_adapter(ckpt):
            """Modifies no-adpter ckpt for adapter-equipped encoder. """
            keys_need_modified_enc = []
            for k in list(ckpt['model'].keys()):
                if ('output' in k):
                    keys_need_modified_enc.append(k)
            for mk in keys_need_modified_enc:
                ckpt['model'] = OrderedDict([
                    (mk.replace('output', 'output.self_output'),
                     v) if k == mk else (k, v)
                    for k, v in ckpt['model'].items()
                ])

        def modify_ckpt_for_dec_adapter(ckpt):
            """Modifies no-adpter ckpt for adapter-equipped decoder. """
            keys_need_modified_dec = []
            for k in list(ckpt['model'].keys()):
                if ('layers' in k):
                    keys_need_modified_dec.append(k)
            for mk in keys_need_modified_dec:
                p = mk.find('layers.')
                new_k = mk[:p + 8] + '.dec_layer' + mk[p + 8:]
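                # e.g. a key like 'decoder.transformer_layers.0.self_attn.linear_keys.weight'
                # becomes 'decoder.transformer_layers.0.dec_layer.self_attn.linear_keys.weight'
                # (this string surgery assumes single-digit layer indices)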
                ckpt['model'] = OrderedDict([(new_k, v) if k == mk else (k, v)
                                             for k, v in ckpt['model'].items()
                                             ])

        def identify_unmatched_keys(ckpt1, ckpt2):
            """Report the unmatched keys in ckpt1 for loading ckpt2 to ckpt1. (debug use) """
            fp = open("unmatched_keys.txt", 'w')
            num = 0
            ckpt1_keys = list(ckpt1.keys())
            ckpt2_keys = list(ckpt2.keys())
            for k in ckpt1_keys:
                if not (k in ckpt2_keys) and not ("var" in k) and not (
                        "feed_forward" in k):
                    # NOTE: since var and feed_forward use shared weights from other modules
                    fp.write(k + '\n')
                    print(k)
                    num += 1
            print("# of Unmatched Keys: {}".format(num))
            fp.close()

        if checkpoint is not None:
            if (self.args.enc_adapter and self.args.ckpt_from_no_adapter):
                modify_ckpt_for_enc_adapter(checkpoint)
            if (self.args.dec_adapter and self.args.ckpt_from_no_adapter):
                modify_ckpt_for_dec_adapter(checkpoint)

            # NOTE: not strict for load model
            #identify_unmatched_keys(self.state_dict(), checkpoint['model']) # DEBUG
            self.load_state_dict(checkpoint['model'], strict=False)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.model.config.hidden_size,
                    padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)
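
AdapterConfig, add_enc_adapters, add_dec_adapters and add_layer_norm come from this project's adapter utilities and are not shown here. Judging only from the constructor calls above, AdapterConfig is a plain container of adapter hyperparameters; a minimal sketch under that assumption:

from dataclasses import dataclass

@dataclass
class AdapterConfig:
    # Fields inferred from the call sites above; the real class may define more.
    hidden_size: int
    adapter_size: int
    adapter_act: str
    adapter_initializer_range: float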
Example #8
class AbsSummarizer(nn.Module):
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(dict([
                (n[11:], p) for n, p in bert_from_extractive.items()
                if n.startswith('bert.model')
            ]),
                                            strict=True)

        if (args.encoder == 'baseline'):
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if (args.max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if (self.args.share_emb):
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size,
            heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings,
            use_universal_transformer=args.dec_universal_trans)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.model.config.hidden_size,
                    padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        # Here src, segs, and mask_src go directly into the BERT model.
        # It is unclear what we can change here, or how to add our linguistic features to it,
        # so we focus on changing things in the decoder, since it is trained from scratch.
        top_vec = self.bert(src, segs, mask_src)
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        return decoder_outputs, None
Example #9
    def __init__(self, args, device, checkpoint=None, from_extractive=None, symbols=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.symbols = symbols

        # TODO: build the encoder according to whether args.encoder is bert or xlnet
        self.pre_model = pre_models(args.encoder)  # select the Bert or XLNet class
        self.encoder = self.pre_model(args, args.large, args.temp_dir, args.finetune_encoder, self.symbols)  # encoder is bert or xlnet
        # self.decoder = XLNet(args.large, args.temp_dir, args.finetune_encoder)  # decoder is xlnet

        if args.max_pos > 512:
            if args.encoder == 'bert':
                my_pos_embeddings = nn.Embedding(args.max_pos, self.encoder.model.config.hidden_size)
                my_pos_embeddings.weight.data[:512] = self.encoder.model.embeddings.position_embeddings.weight.data
                my_pos_embeddings.weight.data[512:] = self.encoder.model.embeddings.position_embeddings.weight.data[-1][
                                                      None, :].repeat(args.max_pos - 512, 1)
                self.encoder.model.embeddings.position_embeddings = my_pos_embeddings


        if from_extractive is not None:
            self.encoder.model.load_state_dict(
                dict([(n[11:], p) for n, p in from_extractive.items() if n.startswith(args.encoder + '.model')]), strict=True)

        self.vocab_size = self.encoder.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(self.encoder.model.word_embedding.weight)

        # TODO: create decoder, options: TransformerDecoder, XLNet, GPT-2
        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size, heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings)

        # TODO: create generator, options: GPT-2, XLNet
        self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_pre_emb:
                if args.encoder == 'bert':
                    tgt_embeddings = nn.Embedding(self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0)
                    tgt_embeddings.weight = copy.deepcopy(self.encoder.model.embeddings.word_embeddings.weight)
                if args.encoder == 'xlnet':
                    tgt_embeddings = nn.Embedding(self.vocab_size, self.encoder.model.config.d_model, padding_idx=0)
                    tgt_embeddings.weight = copy.deepcopy(self.encoder.model.word_embedding.weight)

                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)
Example #10
class MTLAbsSummarizer(nn.Module):
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 bert_from_extractive=None):
        super(MTLAbsSummarizer, self).__init__()
        self.args = args
        self.device = device

        # Initialize BERT
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        # Load ckpt from extractive model
        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(dict([
                (n[11:], p) for n, p in bert_from_extractive.items()
                if n.startswith('bert.model')
            ]),
                                            strict=True)

        # Baseline encoder: randomly initialized BERT configuration
        if args.encoder == 'baseline':
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        # The original BERT has only 512 positional embeddings; repeat the last one for positions beyond 512
        if (args.max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                    self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                    self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        # Initialize the Transformer decoder
        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)

        # Initialize the generator
        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        # Insert Adapter modules
        if (args.enc_adapter):
            enc_hidden_size = self.bert.model.embeddings.word_embeddings.weight.shape[1]
            config = AdapterConfig(
                hidden_size=enc_hidden_size,
                adapter_size=args.adapter_size,
                adapter_act=args.adapter_act,
                adapter_initializer_range=args.adapter_initializer_range)
            self.bert.model = add_enc_adapters(self.bert.model, config)
            self.bert.model = add_layer_norm(self.bert.model,
                                             d_model=enc_hidden_size,
                                             eps=args.layer_norm_eps)
        if (args.dec_adapter):
            config = AdapterConfig(
                hidden_size=args.dec_hidden_size,
                adapter_size=args.adapter_size,
                adapter_act=args.adapter_act,
                adapter_initializer_range=args.adapter_initializer_range)
            self.decoder = add_dec_adapters(self.decoder, config)
            self.decoder = add_layer_norm(self.decoder,
                                          d_model=args.dec_hidden_size,
                                          eps=args.layer_norm_eps)

            self.generator[0].weight.requires_grad = False
            self.generator[0].bias.requires_grad = False

        # Load ckpt
        def modify_ckpt_for_enc_adapter(ckpt):
            """Modifies no-adpter ckpt for adapter-equipped encoder. """
            keys_need_modified_enc = []
            for k in list(ckpt['model'].keys()):
                if ('output' in k):
                    keys_need_modified_enc.append(k)
            for mk in keys_need_modified_enc:
                ckpt['model'] = OrderedDict([
                    (mk.replace('output', 'output.self_output'),
                     v) if k == mk else (k, v)
                    for k, v in ckpt['model'].items()
                ])

        def modify_ckpt_for_dec_adapter(ckpt):
            """Modifies no-adpter ckpt for adapter-equipped decoder. """
            keys_need_modified_dec = []
            for k in list(ckpt['model'].keys()):
                if ('layers' in k):
                    keys_need_modified_dec.append(k)
            for mk in keys_need_modified_dec:
                p = mk.find('layers.')
                new_k = mk[:p + 8] + '.dec_layer' + mk[p + 8:]
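                # e.g. a key like 'decoder.transformer_layers.0.self_attn.linear_keys.weight'
                # becomes 'decoder.transformer_layers.0.dec_layer.self_attn.linear_keys.weight'
                # (this string surgery assumes single-digit layer indices)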
                ckpt['model'] = OrderedDict([(new_k, v) if k == mk else (k, v)
                                             for k, v in ckpt['model'].items()
                                             ])

        def identify_unmatched_keys(ckpt1, ckpt2):
            """Report the unmatched keys in ckpt1 for loading ckpt2 to ckpt1. (debug use) """
            fp = open("unmatched_keys.txt", 'w')
            num = 0
            ckpt1_keys = list(ckpt1.keys())
            ckpt2_keys = list(ckpt2.keys())
            for k in ckpt1_keys:
                if not (k in ckpt2_keys) and not ("var" in k) and not (
                        "feed_forward" in k):
                    # NOTE: since var and feed_forward use shared weights from other modules
                    fp.write(k + '\n')
                    print(k)
                    num += 1
            print("# of Unmatched Keys: {}".format(num))
            fp.close()

        if checkpoint is not None:
            if (self.args.enc_adapter and self.args.ckpt_from_no_adapter):
                modify_ckpt_for_enc_adapter(checkpoint)
            if (self.args.dec_adapter and self.args.ckpt_from_no_adapter):
                modify_ckpt_for_dec_adapter(checkpoint)

            # NOTE: not strict for load model
            #identify_unmatched_keys(self.state_dict(), checkpoint['model']) # DEBUG
            self.load_state_dict(checkpoint['model'], strict=False)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.model.config.hidden_size,
                    padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        """Forward process.

        Args:
            src (tensor(batch, max_src_len_batch)):
                Source token ids.
            tgt (tensor(batch, max_tgt_len_batch)):
                Target token ids.
            segs (tensor(batch, max_src_len_batch)):
                Segment id (0 or 1) to separate source sentences.
            clss (tensor(batch, max_cls_num_batch)):
                The positions of the [CLS] tokens.
            mask_src (tensor(batch, max_src_len_batch)):
                Mask (0 or 1) for source padding tokens.
            mask_tgt (tensor(batch, max_tgt_len_batch)):
                Mask (0 or 1) for target padding tokens.
            mask_cls (tensor(batch, max_cls_num_batch)):
                Mask (0 or 1) for [CLS] position.

        Returns:
            A tuple of variable:
                decoder_outputs (tensor(batch, max_tgt_len_batch, dec_hidden_dim)):
                    The hidden states from decoder.
                top_vec (tensor(batch, max_src_len_batch, enc_hidden_dim)):
                    The hidden states from encoder.
        """
        # top_vec -> tensor(batch, max_src_len_batch, enc_hidden_dim)
        top_vec = self.bert(src, segs, mask_src)

        # dec_state -> models.decoder.TransformerDecoderState
        dec_state = self.decoder.init_decoder_state(src, top_vec)

        # decoder_outputs -> tensor(batch, max_tgt_len_batch, dec_hidden_dim)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)

        return decoder_outputs, top_vec

    # [For Inner Loop]
    def _cascade_fast_weights_grad(self, fast_weights):
        """Sets fast-weight mode for adapter and layer norm modules. """
        offset = 0
        for name, sub_module in self.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                param_num = len(sub_module._parameters)
                setattr(sub_module, 'fast_weights_flag', True)
                delattr(sub_module, 'fast_weights')
                setattr(sub_module, 'fast_weights',
                        fast_weights[offset:offset + param_num])
                offset += param_num
        return offset

    # [For Outer Loop]
    def _clean_fast_weights_mode(self):
        """Cleans fast-weight mode for adapter and layer norm modules. """
        module_num = 0
        for name, sub_module in self.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                setattr(sub_module, 'fast_weights_flag', False)
                delattr(sub_module, 'fast_weights')
                setattr(sub_module, 'fast_weights', None)
                module_num += 1
        return module_num

    def _adapter_fast_weights(self):
        """Returns fast (task) weights from full model. """
        for name, sub_module in self.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.fast_weights:
                    yield param

    def _adapter_fast_weights_bert(self):
        """Returns fast (task) weights from encoder. """
        for name, sub_module in self.bert.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.fast_weights:
                    yield param

    def _adapter_fast_weights_dec(self):
        """Returns fast (task) weights from decoder. """
        for name, sub_module in self.decoder.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.fast_weights:
                    yield param

    def _adapter_vars(self):
        """Returns true (meta) parameters from full model. """
        for name, sub_module in self.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.vars:
                    yield param

    def _adapter_vars_bert(self):
        """Returns true (meta) parameters from encoder. """
        for name, sub_module in self.bert.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.vars:
                    yield param

    def _adapter_vars_dec(self):
        """Returns true (meta) parameters from decoder. """
        for name, sub_module in self.decoder.named_modules():
            if isinstance(
                    sub_module,
                (Adapter_func, LayerNorm_func)) and sub_module.trainable:
                for param in sub_module.vars:
                    yield param
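
The fast-weight helpers above suggest a MAML-style inner loop: compute a task loss, differentiate it with respect to the adapter/layer-norm meta parameters (vars), and install the updated copies as fast weights. A rough sketch under that assumption only; the loss, inner learning rate, and parameter ordering are placeholders, not the project's actual training code:

import torch

def inner_adaptation_step(model, task_loss, inner_lr=1e-3):
    # One hypothetical inner-loop step: gradients of the task loss w.r.t. the
    # adapter / layer-norm meta parameters, applied as fast weights.
    meta_params = list(model._adapter_vars())
    grads = torch.autograd.grad(task_loss, meta_params, create_graph=True)
    fast_weights = [p - inner_lr * g for p, g in zip(meta_params, grads)]
    model._cascade_fast_weights_grad(fast_weights)
    # model._clean_fast_weights_mode() restores normal behavior after the outer update.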
Example #11
    def __init__(self, args, device, vocab, checkpoint=None):
        super(RankAE, self).__init__()
        self.args = args
        self.device = device
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.beam_size = args.beam_size
        self.max_length = args.max_length
        self.min_length = args.min_length

        self.start_token = vocab['[unused1]']
        self.end_token = vocab['[unused2]']
        self.pad_token = vocab['[PAD]']
        self.mask_token = vocab['[MASK]']
        self.seg_token = vocab['[unused3]']
        self.cls_token = vocab['[CLS]']

        self.hidden_size = args.enc_hidden_size
        self.embeddings = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=0)

        if args.encoder == 'bert':
            self.encoder = Bert(args.bert_dir, args.finetune_bert)
            if(args.max_pos > 512):
                my_pos_embeddings = nn.Embedding(args.max_pos, self.encoder.model.config.hidden_size)
                my_pos_embeddings.weight.data[:512] = self.encoder.model.embeddings.position_embeddings.weight.data
                my_pos_embeddings.weight.data[512:] = self.encoder.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos-512, 1)
                self.encoder.model.embeddings.position_embeddings = my_pos_embeddings
            tgt_embeddings = nn.Embedding(self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0)
        else:
            self.encoder = TransformerEncoder(self.hidden_size, args.enc_ff_size, args.enc_heads,
                                              args.enc_dropout, args.enc_layers)
            tgt_embeddings = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=0)

        self.hier_encoder = TransformerEncoder(self.hidden_size, args.hier_ff_size, args.hier_heads,
                                               args.hier_dropout, args.hier_layers)
        self.cup_bilinear = nn.Bilinear(self.hidden_size, self.hidden_size, 1)
        self.pos_emb = PositionalEncoding(0., self.hidden_size)

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size, heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings)

        self.generator = Generator(self.vocab_size, self.args.dec_hidden_size, self.pad_token)

        self.generator.linear.weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            if args.encoder == "transformer":
                for module in self.encoder.modules():
                    self._set_parameter_tf(module)
                xavier_uniform_(self.embeddings.weight)
            for module in self.decoder.modules():
                self._set_parameter_tf(module)
            for module in self.hier_encoder.modules():
                self._set_parameter_tf(module)
            for p in self.generator.parameters():
                self._set_parameter_linear(p)
            for p in self.cup_bilinear.parameters():
                self._set_parameter_linear(p)
            if args.share_emb:
                if args.encoder == 'bert':
                    self.embeddings = self.encoder.model.embeddings.word_embeddings
                    tgt_embeddings = nn.Embedding(self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0)
                    tgt_embeddings.weight = copy.deepcopy(self.encoder.model.embeddings.word_embeddings.weight)
                else:
                    tgt_embeddings = self.embeddings
                self.decoder.embeddings = tgt_embeddings
                self.generator.linear.weight = self.decoder.embeddings.weight

        self.to(device)
Example #12
class AbsSummarizer(nn.Module):
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(dict([
                (n[11:], p) for n, p in bert_from_extractive.items()
                if n.startswith('bert.model')
            ]),
                                            strict=True)

        if (args.encoder == 'baseline'):
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if (args.max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if (self.args.share_emb):
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        self.topical_output = None  #get_topical_output(15, self.args.dec_hidden_size, device)
        # self.topical_output[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.model.config.hidden_size,
                    padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight
                # self.topical_output[0].weight = self.decoder.embeddings.weight

        self.to(device)

    # def lda_process(self, batch):
    #     result = np.zeros((len(batch), 512))
    #
    #     for i, b in enumerate(batch):
    #         src_txt = tokenizer.convert_ids_to_tokens(b.tolist())
    #         src_txt = preprocess(' '.join(src_txt))
    #
    #         bow_vector = tm_dictionary.doc2bow(preprocess(' '.join(src_txt)))
    #
    #         article_topic = sorted(lda_model[bow_vector], key=lambda tup: -1 * tup[1])  # [0]
    #
    #         for index, value in article_topic[:1]:
    #             result[i, index] = value
    #
    #     return result

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        top_vec = self.bert(src, segs, mask_src)

        # for i, b in enumerate(tgt):
        #     tgt_txt = tokenizer.convert_ids_to_tokens(b.tolist())
        #     print(tgt_txt)
        #     a = 1 + 2

        # # add small normal distributed noise
        # noise = torch.normal(torch.zeros(top_vec.shape), torch.ones(top_vec.shape) / 2)
        # noise = noise.cuda()
        # top_vec += noise

        # if self.args.use_topic_modelling:
        #     lda_res = self.lda_process(src)
        #
        #     for i1 in range(len(lda_res)):
        #         lda_res_tensor = torch.FloatTensor(lda_res[i1])
        #         for i2 in range(len(top_vec[i1])):
        #             try:
        #                 top_vec[i1, i2] += lda_res_tensor.cuda()
        #             except IndexError as err:
        #                 print(err)
        #                 print(top_vec.shape, lda_res.shape)
        #                 raise err

        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)

        # print('decoder', decoder_outputs.shape)

        return decoder_outputs, None
Example #13
class AbsSummarizer(nn.Module):
	def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
		super(AbsSummarizer, self).__init__()
		self.args = args
		self.device = device
		self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

		if bert_from_extractive is not None:
			self.bert.model.load_state_dict(
				dict([(n[11:], p) for n, p in bert_from_extractive.items() if n.startswith('bert.model')]), strict=True)

		if (args.encoder == 'baseline'):
			bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.enc_hidden_size,
			                         num_hidden_layers=args.enc_layers, num_attention_heads=8,
			                         intermediate_size=args.enc_ff_size,
			                         hidden_dropout_prob=args.enc_dropout,
			                         attention_probs_dropout_prob=args.enc_dropout)
			self.bert.model = BertModel(bert_config)

		if (args.max_pos > 512):
			my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
			my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
			my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,
			                                      :].repeat(args.max_pos - 512, 1)
			self.bert.model.embeddings.position_embeddings = my_pos_embeddings
		self.vocab_size = self.bert.model.config.vocab_size

		self.enc_out_size = self.args.dec_hidden_size
		if self.args.use_dep:
			self.enc_out_size += 2
		if self.args.use_frame:
			self.enc_frame = nn.Linear(1, 20)
			self.frame_attn = MultiHeadedAttention(1, 20, 0.1)
			self.enc_out_size += 20
		self.enc_out = nn.Linear(self.enc_out_size, self.args.dec_hidden_size)
		self.drop = nn.Dropout(self.args.enc_dropout)
		self.layer_norm = nn.LayerNorm(self.args.dec_hidden_size, eps=1e-6)

		tgt_embeddings = nn.Embedding(self.vocab_size, self.args.dec_hidden_size, padding_idx=0)
		if (self.args.share_emb):
			tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

		self.decoder = TransformerDecoder(
			self.args.dec_layers,
			self.args.dec_hidden_size, heads=self.args.dec_heads,
			d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings)

		self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
		self.generator[0].weight = self.decoder.embeddings.weight

		if checkpoint is not None:
			self.load_state_dict(checkpoint['model'], strict=True)
		else:
			for module in self.decoder.modules():
				if isinstance(module, (nn.Linear, nn.Embedding)):
					module.weight.data.normal_(mean=0.0, std=0.02)
				elif isinstance(module, nn.LayerNorm):
					module.bias.data.zero_()
					module.weight.data.fill_(1.0)
				if isinstance(module, nn.Linear) and module.bias is not None:
					module.bias.data.zero_()
			for p in self.generator.parameters():
				if p.dim() > 1:
					xavier_uniform_(p)
				else:
					p.data.zero_()
			if (args.use_bert_emb):
				tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
				tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
				self.decoder.embeddings = tgt_embeddings
				self.generator[0].weight = self.decoder.embeddings.weight

		self.to(device)

	def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, frame, dep):
		top_vec = self.bert(src, segs, mask_src)
		if self.args.use_dep:
			# dep_enmbeddings = self.enc_dep(dep[:, :, 0].unsqueeze(2).float())
			top_vec = torch.cat((top_vec, dep.float()), dim=2)
		if self.args.use_frame:
			frame_embeddings = self.enc_frame(frame.float().unsqueeze(-1))
			frame_embeddings = self.frame_attn(frame_embeddings, frame_embeddings, frame_embeddings, type="self")
			top_vec = torch.cat((top_vec, frame_embeddings), dim=2)
		top_vec = self.enc_out(top_vec)
		top_vec = self.layer_norm(top_vec)
		# top_vec = self.drop(top_vec)
		dec_state = self.decoder.init_decoder_state(src, top_vec)
		decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
		return decoder_outputs, None
Example #14
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, args.cased, args.finetune_bert)

        if (args.encoder == 'baseline'):
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=12,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if (args.max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if (self.args.share_emb):
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.model.config.hidden_size,
                    padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)

    def forward(self,
                src,
                tgt,
                segs,
                clss,
                mask_src,
                mask_tgt,
                mask_cls,
                likes=None):
        # print("attn: ",self.args.include_like_dist)
        top_vec = self.bert(src, segs, mask_src)
        # print("top_vec",top_vec.shape,top_vec.dtype)

        if self.args.include_like_dist and self.args.mode == "train":
            likes = torch.sqrt(likes.float())
            max_likes = torch.max(likes, dim=1).values.float()[:, None]
            norm_likes = (likes / max_likes)[:, :, None]
            # print("norm_likes",norm_likes.shape, norm_likes.dtype)
            top_vec = top_vec * norm_likes
        # print("new top_vec",top_vec.shape, top_vec.dtype)
        # exit
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(
            tgt[:, :-1], top_vec, dec_state)  # <-- Pass the graph vector here

        return decoder_outputs, None
Example #15
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert, args.bart)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_from_extractive.items() if n.startswith('bert.model')]), strict=True)

        if (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.enc_hidden_size,
                                     num_hidden_layers=args.enc_layers, num_attention_heads=8,
                                     intermediate_size=args.enc_ff_size,
                                     hidden_dropout_prob=args.enc_dropout,
                                     attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if(args.max_pos>512):
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
        if (self.args.share_emb):
            tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

        # set the multi_task decoder
        if self.args.multi_task:
            self.decoder_monolingual = TransformerDecoder(
                self.args.dec_layers,
                self.args.dec_hidden_size, heads=self.args.dec_heads,
                d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings,
                sep_dec=self.args.sep_decoder)
        # if not args.bart:
        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size, heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings, sep_dec=self.args.sep_decoder)

        self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)

        self.generator[0].weight = self.decoder.embeddings.weight



        # Initialize the parameters first, then load the checkpoint, to avoid loading errors.

        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        if self.args.multi_task:
            for module in self.decoder_monolingual.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()
        if(args.use_bert_emb):
            tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
            tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
            self.decoder.embeddings = tgt_embeddings
            self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            if args.few_shot and args.multi_task:
                # use one decoder to initialize two decoders
                new_states = OrderedDict()
                for each in checkpoint['model']:
                    if each.startswith('decoder'):
                        new_states[each] = copy.deepcopy(checkpoint['model'][each])
                        new_states[each.replace('decoder', 'decoder_monolingual')] = copy.deepcopy(checkpoint['model'][each])
                    else:
                        new_states[each] = copy.deepcopy(checkpoint['model'][each])
                self.load_state_dict(new_states, strict=True)

            else:
                self.load_state_dict(checkpoint['model'], strict=True)


        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, tgt_segs=None, tgt_eng=None):
        top_vec = self.bert(src, segs, mask_src)
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        if self.args.multi_task:
            tgt_eng_segs = torch.ones(tgt_eng.size()).long().cuda()
            mono_dec_state = self.decoder_monolingual.init_decoder_state(src, top_vec)
            mono_decoder_outputs, mono_state = self.decoder_monolingual(tgt_eng[:, :-1], top_vec, mono_dec_state,
                                                                        tgt_segs = tgt_eng_segs[:, :-1])
        else:
            mono_decoder_outputs = None
            mono_state = None

        if tgt_segs is None:
            decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        else:
            decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state, tgt_segs=tgt_segs[:,:-1])
        # print("decoder_outputs = ", decoder_outputs.size())
        # print(decoder_outputs)
        # exit()
        return decoder_outputs, None, mono_decoder_outputs
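A small sketch of the few-shot / multi-task checkpoint remapping above: every 'decoder.*' entry is duplicated under 'decoder_monolingual.*' so one pretrained decoder initializes both decoders (the keys and values below are illustrative):

import copy
from collections import OrderedDict

checkpoint = {'model': {'decoder.layers.0.weight': [0.1],
                        'generator.0.weight': [0.2]}}

new_states = OrderedDict()
for name, param in checkpoint['model'].items():
    new_states[name] = copy.deepcopy(param)
    if name.startswith('decoder'):
        new_states[name.replace('decoder', 'decoder_monolingual')] = copy.deepcopy(param)

print(list(new_states))
# ['decoder.layers.0.weight', 'decoder_monolingual.layers.0.weight', 'generator.0.weight']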
Ejemplo n.º 16
0
    def __init__(self, args, device, checkpoint=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.bert_model_path, args.large, args.temp_dir,
                         args.finetune_bert)

        max_pos = args.max_pos
        if (max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        # guide-tags
        self.tag_embeddings = TiedEmbedding(args.max_n_tags,
                                            self.bert.model.config.hidden_size,
                                            padding_idx=0)
        self.tag_drop = nn.Dropout(args.tag_dropout)

        # decoder
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.bert.model.config.hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings,
                                          tag_embeddings=self.tag_embeddings)

        # generator
        self.generator = get_generator(
            args,
            self.vocab_size,
            self.bert.model.config.hidden_size,
            gen_weight=self.decoder.embeddings.weight)

        # load checkpoint or initialize the parameters
        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            self.tag_embeddings.weight.data.normal_(mean=0.0, std=0.02)
            self.tag_embeddings.weight[
                self.tag_embeddings.padding_idx].data.fill_(0)
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                self.decoder.embeddings.weight.data.copy_(
                    self.bert.model.embeddings.word_embeddings.weight)

        self.to(device)
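The position-embedding extension used when max_pos > 512 (here and in several other examples) copies BERT's 512 learned positions and fills the remaining rows with the last learned embedding. A self-contained sketch, assuming hidden size 768:

import torch.nn as nn

max_pos, hidden = 1024, 768
old_pos = nn.Embedding(512, hidden)   # stands in for BERT's position table
new_pos = nn.Embedding(max_pos, hidden)
new_pos.weight.data[:512] = old_pos.weight.data
new_pos.weight.data[512:] = old_pos.weight.data[-1][None, :].repeat(max_pos - 512, 1)
print(new_pos.weight.shape)           # torch.Size([1024, 768])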
Ejemplo n.º 17
0
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.bert_model_path, args.large, args.temp_dir,
                         args.finetune_bert)

        max_pos = args.max_pos
        if (max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        # guide-tags
        self.tag_embeddings = TiedEmbedding(args.max_n_tags,
                                            self.bert.model.config.hidden_size,
                                            padding_idx=0)
        self.tag_drop = nn.Dropout(args.tag_dropout)

        # decoder
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.bert.model.config.hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings,
                                          tag_embeddings=self.tag_embeddings)

        # generator
        self.generator = get_generator(
            args,
            self.vocab_size,
            self.bert.model.config.hidden_size,
            gen_weight=self.decoder.embeddings.weight)

        # load checkpoint or initialize the parameters
        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            self.tag_embeddings.weight.data.normal_(mean=0.0, std=0.02)
            self.tag_embeddings.weight[
                self.tag_embeddings.padding_idx].data.fill_(0)
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                self.decoder.embeddings.weight.data.copy_(
                    self.bert.model.embeddings.word_embeddings.weight)

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls,
                tag_src, tag_tgt):
        segs_src = (1 - segs % 2) * mask_src.long()
        top_vec = self.bert(src, segs_src, mask_src)
        if self.training and self.args.sent_dropout > 0:
            idx = (torch.arange(clss.size(1), device=clss.device) +
                   1).unsqueeze(0).expand_as(clss)  # n x sents
            drop = torch.rand(
                clss.size(), dtype=torch.float,
                device=clss.device) < self.args.sent_dropout  # n x sents
            idx = idx * drop.long()
            msk_drop = torch.sum(
                (segs.unsqueeze(-2) == idx.unsqueeze(-1)).float(),
                dim=1)  # n x 512
            msk_tag = (torch.sum(tag_src, dim=2) > 0).float()  # n x 512
            msk_drop = msk_drop * (1 - msk_tag) * mask_src.float()
            top_vec = top_vec * (1 - msk_drop).unsqueeze(-1)
        tag_vec = self.tag_embeddings.matmul(tag_src)
        top_vec = top_vec + self.tag_drop(tag_vec)
        dec_state = self.decoder.init_decoder_state(src, top_vec)
        if self.training and self.args.word_dropout > 0:
            word_mask = 103
            drop = torch.rand(tgt.size(), dtype=torch.float,
                              device=tgt.device) < self.args.word_dropout
            drop = drop * mask_tgt
            tgt = torch.where(drop, tgt.new_full(tgt.size(), word_mask), tgt)
        decoder_outputs, state = self.decoder(tgt[:, :-1],
                                              top_vec,
                                              dec_state,
                                              tag=tag_tgt[:, :-1])
        return decoder_outputs, None
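A minimal sketch of the word-dropout step in the forward pass above: with probability word_dropout, non-padding target tokens are replaced by BERT's [MASK] id (103). Shapes and the rate are illustrative:

import torch

word_mask, word_dropout = 103, 0.3
tgt = torch.randint(1000, 2000, (2, 7))
mask_tgt = torch.ones_like(tgt, dtype=torch.bool)   # 1 = real token, 0 = padding

drop = torch.rand(tgt.size(), dtype=torch.float) < word_dropout
drop = drop & mask_tgt
tgt = torch.where(drop, tgt.new_full(tgt.size(), word_mask), tgt)
print(tgt)  # roughly 30% of the entries are now 103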
Ejemplo n.º 18
0
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir,
                         args.finetune_bert)  # large=False, temp_dir='../temp', finetune_bert=True
        # The input takes at most 512 tokens (excluding [CLS] and [SEP]), with at most two
        # sentences packed into one input; words and sentences beyond that have no corresponding
        # embeddings, and the pooler encodes the [CLS] position.
        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(dict([
                (n[11:], p) for n, p in bert_from_extractive.items()
                if n.startswith('bert.model')
            ]),
                                            strict=True)

        if (args.encoder == 'baseline'):  # default: bert
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)

        if (args.max_pos > 512):  # max_pos does not exceed 512 in this setting, so this branch is unused
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size  # vocab_size from bert.model.config: 21128
        tgt_embeddings = nn.Embedding(
            self.vocab_size, self.bert.model.config.hidden_size,
            padding_idx=0)  # hidden_size as above (768); embeds the target summary tokens
        if (self.args.share_emb):  #False
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        # BertModel acts as the feature extractor, i.e. the encoder here; a Transformer serves as the decoder
        self.decoder = TransformerDecoder(  # multi-head attention decoder
            self.args.dec_layers,
            self.args.dec_hidden_size,
            heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight  # tie the generator projection to the decoder embeddings (21128 rows)

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:  # no checkpoint: initialize the parameters for training
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.model.config.hidden_size,
                    padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)
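A sketch of the weight tying applied to the generator above, assuming (as the index generator[0] suggests) that get_generator builds an nn.Sequential whose first module is a Linear projection to the vocabulary:

import torch.nn as nn

vocab_size, hidden = 21128, 768
tgt_embeddings = nn.Embedding(vocab_size, hidden, padding_idx=0)
generator = nn.Sequential(nn.Linear(hidden, vocab_size), nn.LogSoftmax(dim=-1))
generator[0].weight = tgt_embeddings.weight   # same Parameter object, not a copy
assert generator[0].weight.data_ptr() == tgt_embeddings.weight.data_ptr()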
Ejemplo n.º 19
0
class RankAE(nn.Module):
    def __init__(self, args, device, vocab, checkpoint=None):
        super(RankAE, self).__init__()
        self.args = args
        self.device = device
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.beam_size = args.beam_size
        self.max_length = args.max_length
        self.min_length = args.min_length

        self.start_token = vocab['[unused1]']
        self.end_token = vocab['[unused2]']
        self.pad_token = vocab['[PAD]']
        self.mask_token = vocab['[MASK]']
        self.seg_token = vocab['[unused3]']
        self.cls_token = vocab['[CLS]']

        self.hidden_size = args.enc_hidden_size
        self.embeddings = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=0)

        if args.encoder == 'bert':
            self.encoder = Bert(args.bert_dir, args.finetune_bert)
            if(args.max_pos > 512):
                my_pos_embeddings = nn.Embedding(args.max_pos, self.encoder.model.config.hidden_size)
                my_pos_embeddings.weight.data[:512] = self.encoder.model.embeddings.position_embeddings.weight.data
                my_pos_embeddings.weight.data[512:] = self.encoder.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(args.max_pos-512, 1)
                self.encoder.model.embeddings.position_embeddings = my_pos_embeddings
            tgt_embeddings = nn.Embedding(self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0)
        else:
            self.encoder = TransformerEncoder(self.hidden_size, args.enc_ff_size, args.enc_heads,
                                              args.enc_dropout, args.enc_layers)
            tgt_embeddings = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=0)

        self.hier_encoder = TransformerEncoder(self.hidden_size, args.hier_ff_size, args.hier_heads,
                                               args.hier_dropout, args.hier_layers)
        self.cup_bilinear = nn.Bilinear(self.hidden_size, self.hidden_size, 1)
        self.pos_emb = PositionalEncoding(0., self.hidden_size)

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size, heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings)

        self.generator = Generator(self.vocab_size, self.args.dec_hidden_size, self.pad_token)

        self.generator.linear.weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            if args.encoder == "transformer":
                for module in self.encoder.modules():
                    self._set_parameter_tf(module)
                xavier_uniform_(self.embeddings.weight)
            for module in self.decoder.modules():
                self._set_parameter_tf(module)
            for module in self.hier_encoder.modules():
                self._set_parameter_tf(module)
            for p in self.generator.parameters():
                self._set_parameter_linear(p)
            for p in self.cup_bilinear.parameters():
                self._set_parameter_linear(p)
            if args.share_emb:
                if args.encoder == 'bert':
                    self.embeddings = self.encoder.model.embeddings.word_embeddings
                    tgt_embeddings = nn.Embedding(self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0)
                    tgt_embeddings.weight = copy.deepcopy(self.encoder.model.embeddings.word_embeddings.weight)
                else:
                    tgt_embeddings = self.embeddings
                self.decoder.embeddings = tgt_embeddings
                self.generator.linear.weight = self.decoder.embeddings.weight

        self.to(device)

    def _set_parameter_tf(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _set_parameter_linear(self, p):
        if p.dim() > 1:
            xavier_uniform_(p)
        else:
            p.data.zero_()

    def _rebuild_tgt(self, origin, index, sep_token=None):

        tgt_list = [torch.tensor([self.start_token], device=self.device)]
        selected = origin.index_select(0, index)
        for sent in selected:
            filted_sent = sent[sent != self.pad_token][1:]
            if sep_token is not None:
                filted_sent[-1] = sep_token
            else:
                filted_sent = filted_sent[:-1]
            tgt_list.append(filted_sent)
        new_tgt = torch.cat(tgt_list, 0)
        if sep_token is not None:
            new_tgt[-1] = self.end_token
        else:
            new_tgt = torch.cat([new_tgt, torch.tensor([self.end_token], device=self.device)], 0)
        return new_tgt
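A small worked example of _rebuild_tgt above on dummy token ids (start=1, end=2, sep=3, pad=0 are illustrative): selected sentences are stripped of padding and their leading start token, their trailing token is replaced by the separator, and the sequence is re-wrapped with start/end tokens:

import torch

start_token, end_token, sep_token, pad_token = 1, 2, 3, 0
origin = torch.tensor([[1, 11, 12, 2, 0, 0],
                       [1, 21, 22, 23, 2, 0],
                       [1, 31, 2, 0, 0, 0]])
index = torch.tensor([0, 2])

tgt_list = [torch.tensor([start_token])]
for sent in origin.index_select(0, index):
    filted_sent = sent[sent != pad_token][1:]   # drop padding and the leading start token
    filted_sent[-1] = sep_token                 # trailing end token -> separator
    tgt_list.append(filted_sent)
new_tgt = torch.cat(tgt_list, 0)
new_tgt[-1] = end_token                         # last separator -> end token
print(new_tgt)                                  # tensor([ 1, 11, 12,  3, 31,  2])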

    def _build_memory_window(self, ex_segs, keep_clss, replace_clss=None, mask=None, samples=None):
        keep_cls_list = torch.split(keep_clss, ex_segs)
        window_list = []
        for ex in keep_cls_list:
            ex_pad = F.pad(ex, (0, 0, self.args.win_size, self.args.win_size)).unsqueeze(1)
            ex_context = torch.cat([ex_pad[:ex.size(0)], ex.unsqueeze(1),
                                    ex_pad[self.args.win_size*2:]], 1)
            window_list.append(ex_context)
        memory = torch.cat(window_list, 0)
        if replace_clss is not None:
            replace_cls_list = torch.split(replace_clss, ex_segs)
            window_list = []
            for ex in replace_cls_list:
                ex_pad = F.pad(ex, (0, 0, self.args.win_size, self.args.win_size)).unsqueeze(1)
                ex_context = torch.cat([ex_pad[:ex.size(0)], ex.unsqueeze(1),
                                        ex_pad[self.args.win_size*2:]], 1)
                window_list.append(ex_context)
            origin_memory = torch.cat(window_list, 0)
            sample_list = torch.split(samples, ex_segs)
            sample_tensor_list = []
            for i in range(len(ex_segs)):
                sample_index_ = torch.randint(0, samples.size(-1), [mask.size(-1)], device=self.device)
                sample_index = torch.index_select(sample_list[i], 1, sample_index_)
                sample_tensor = replace_cls_list[i][sample_index]
                sample_tensor_list.append(sample_tensor)
            sample_memory = torch.cat(sample_tensor_list, 0)
            memory = memory * (mask == 2).unsqueeze(-1).float() + \
                sample_memory * (mask == 0).unsqueeze(-1).float() + \
                origin_memory * (mask == 1).unsqueeze(-1).float()
        return memory
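A sketch of the sliding-window stacking in _build_memory_window above (win_size=1): each sentence vector is concatenated with its left and right neighbours, with zero padding at the dialogue boundaries (sizes are illustrative):

import torch
import torch.nn.functional as F

win_size, hidden = 1, 4
ex = torch.arange(5 * hidden, dtype=torch.float).view(5, hidden)  # 5 sentence vectors
ex_pad = F.pad(ex, (0, 0, win_size, win_size)).unsqueeze(1)       # (5 + 2*win, 1, hidden)
ex_context = torch.cat([ex_pad[:ex.size(0)],                      # left neighbours
                        ex.unsqueeze(1),                          # the sentences themselves
                        ex_pad[win_size * 2:]], 1)                # right neighbours
print(ex_context.shape)                                           # torch.Size([5, 3, 4])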

    def _src_add_noise(self, sent, sampled_sent, expand_ratio=0.):
        role_emb = sent[1:2]
        filted_sent = sent[sent != self.pad_token][2:]
        # filted_sent = sent[sent != self.pad_token][1:]
        rand_size = sampled_sent.size(0)
        length = max(int(filted_sent.size(0)*(1+expand_ratio)), filted_sent.size(0)+1)
        while filted_sent.size(0) < length:
            target_length = length - filted_sent.size(0)
            rand_sent = sampled_sent[random.randint(0, rand_size-1)]
            rand_sent = rand_sent[rand_sent != self.pad_token][2:]  # remove cls and role embedding
            # rand_sent = rand_sent[rand_sent != self.pad_token][1:] # no role embedding
            start_point = random.randint(0, rand_sent.size(0)-1)
            end_point = random.randint(start_point, rand_sent.size(0))
            rand_segment = rand_sent[start_point:min(end_point, start_point+10, start_point+target_length)]
            insert_point = random.randint(0, filted_sent.size(0)-1)
            filted_sent = torch.cat([filted_sent[:insert_point],
                                    rand_segment,
                                    filted_sent[insert_point:]], 0)
        # return filted_sent
        return torch.cat([role_emb, filted_sent], 0)

    def _build_noised_src(self, src, ex_segs, samples, expand_ratio=0.):
        src_list = torch.split(src, ex_segs)
        new_src_list = []
        sample_list = torch.split(samples, ex_segs)

        for i, ex in enumerate(src_list):
            for j, sent in enumerate(ex):
                sampled_sent = ex.index_select(0, sample_list[i][j])
                expanded_sent = self._src_add_noise(sent, sampled_sent, expand_ratio)
                new_src = torch.cat([torch.tensor([self.cls_token], device=self.device), expanded_sent], 0)
                new_src_list.append(new_src)

        new_src = pad_sequence(new_src_list, batch_first=True, padding_value=self.pad_token)
        new_mask = new_src.data.ne(self.pad_token)
        new_segs = torch.zeros_like(new_src)
        return new_src, new_mask, new_segs

    def _build_context_tgt(self, tgt, ex_segs, win_size=1, modify=False, mask=None):

        tgt_list = torch.split(tgt, ex_segs)
        new_tgt_list = []
        if modify and mask is not None:
            # 1 means keeping the sentence
            mask_list = torch.split(mask, ex_segs)
        for i in range(len(tgt_list)):
            sent_num = tgt_list[i].size(0)
            for j in range(sent_num):
                if modify:
                    low = j-win_size
                    up = j+win_size+1
                    index = torch.arange(low, up, device=self.device)
                    index = index[mask_list[i][j] > 0]
                else:
                    low = max(0, j-win_size)
                    up = min(sent_num, j+win_size+1)
                    index = torch.arange(low, up, device=self.device)
                new_tgt_list.append(self._rebuild_tgt(tgt_list[i], index, self.seg_token))

        new_tgt = pad_sequence(new_tgt_list, batch_first=True, padding_value=self.pad_token)

        return new_tgt

    def _build_doc_tgt(self, tgt, vec, ex_segs, win_size=1, max_k=6, sigma=1.0):

        vec_list = torch.split(vec, ex_segs)
        tgt_list = torch.split(tgt, ex_segs)

        new_tgt_list = []
        index_list = []
        shift_list = []
        accum_index = 0
        for idx in range(len(ex_segs)):
            ex_vec = vec_list[idx]
            sent_num = ex_segs[idx]
            ex_tgt = tgt_list[idx]
            tgt_length = ex_tgt[:, 1:].ne(self.pad_token).sum(dim=1).float()
            topk_ids = self._centrality_rank(ex_vec, sent_num, tgt_length, win_size, max_k, sigma)
            new_tgt_list.append(self._rebuild_tgt(ex_tgt, topk_ids, self.seg_token))
            shift_list.append(topk_ids)
            index_list.append(topk_ids + accum_index)
            accum_index += sent_num
        new_tgt = pad_sequence(new_tgt_list, batch_first=True, padding_value=self.pad_token)
        return new_tgt, index_list, shift_list

    def _centrality_rank(self, vec, sent_num, tgt_length, win_size, max_k, sigma, eta=0.5, min_length=5):

        assert vec.size(0) == sent_num
        sim = torch.sigmoid(self.cup_bilinear(vec.unsqueeze(1).expand(sent_num, sent_num, -1).contiguous(),
                                              vec.unsqueeze(0).expand(sent_num, sent_num, -1).contiguous())
                            ).squeeze().detach()
        # sim = torch.sigmoid(torch.mm(vec, vec.transpose(0, 1)))
        # sim = torch.cosine_similarity(
        #    vec.unsqueeze(1).expand(sent_num, sent_num, -1).contiguous().view(sent_num * sent_num, -1),
        #    vec.unsqueeze(0).expand(sent_num, sent_num, -1).contiguous().view(sent_num * sent_num, -1)
        # ).view(sent_num, sent_num).detach()

        # calculate sim weight
        k = min(max(sent_num // (win_size*2+1), 1), max_k)
        var = sent_num / k * 1.
        x = torch.arange(sent_num, device=self.device, dtype=torch.float).unsqueeze(0).expand_as(sim)
        u = torch.arange(sent_num, device=self.device, dtype=torch.float).unsqueeze(1)
        weight = torch.exp(-(x-u)**2 / (2. * var**2)) * (1. - torch.eye(sent_num, device=self.device))
        # weight = 1. - torch.eye(sent_num, device=self.device)
        sim[tgt_length < min_length, :] = -1e20

        # Calculate centrality and select top k sentence.
        topk_ids = torch.empty(0, dtype=torch.long, device=self.device)
        mask = torch.zeros([sent_num, sent_num], dtype=torch.float, device=self.device)
        for _ in range(k):
            mean_score = torch.sum(sim * weight, dim=1) / max(sent_num-1, 1)
            max_v, _ = torch.max(sim * weight * mask, dim=1)
            centrality = eta*mean_score - (1-eta)*max_v
            _, top_id = torch.topk(centrality, 1, dim=0, sorted=False)
            topk_ids = torch.cat([topk_ids, top_id], 0)
            sim[topk_ids, :] = -1e20
            mask[:, topk_ids] = 1.
        topk_ids, _ = torch.sort(topk_ids)
        """
        centrality = torch.sum(sim * weight, dim=1)
        _, topk_ids = torch.topk(centrality, k, dim=0, sorted=False)
        topk_ids, _ = torch.sort(topk_ids)
        """
        return topk_ids
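A sketch of the Gaussian positional weight used in _centrality_rank above: sentence pairs that are close in the dialogue get a higher weight, the diagonal is zeroed, and the variance grows with sent_num / k (sizes are illustrative):

import torch

sent_num, k = 6, 2
var = sent_num / k
x = torch.arange(sent_num, dtype=torch.float).unsqueeze(0).expand(sent_num, sent_num)
u = torch.arange(sent_num, dtype=torch.float).unsqueeze(1)
weight = torch.exp(-(x - u) ** 2 / (2. * var ** 2)) * (1. - torch.eye(sent_num))
print(weight[0])  # weight of sentence 0 w.r.t. every other sentence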

    def _add_mask(self, src, mask_src):
        pm_index = torch.empty_like(mask_src).float().uniform_().le(self.args.mask_token_prob)
        ps_index = torch.empty_like(mask_src[:, 0]).float().uniform_().gt(self.args.select_sent_prob)
        pm_index[ps_index] = 0
        # Avoid mask [PAD]
        pm_index[(1-mask_src).byte()] = 0
        # Avoid mask [CLS]
        pm_index[:, 0] = 0
        # Avoid mask [SEG]
        pm_index[src == self.seg_token] = 0
        src[pm_index] = self.mask_token
        return src

    def _build_cup(self, bsz, ex_segs, win_size=1, negative_num=2):

        cup = torch.split(torch.arange(0, bsz, dtype=torch.long, device=self.device), ex_segs)
        tgt = torch.split(torch.ones(bsz), ex_segs)
        cup_list = []
        cup_origin_list = []
        tgt_list = []
        negative_list = []
        for i in range(len(ex_segs)):
            sent_num = ex_segs[i]
            cup_low = cup[i][0].item()
            cup_up = cup[i][sent_num-1].item()
            cup_index = cup[i].repeat(win_size*2*(negative_num+1))
            tgt_index = tgt[i].repeat(win_size*2*(negative_num+1))
            cup_origin_list.append(cup[i].repeat(win_size*2*(negative_num+1)))
            tgt_index[sent_num*win_size*2:] = 0
            for j in range(cup_index.size(0)):
                if tgt_index[j] == 1:
                    cup_temp = cup_index[j]
                    window_list = [t for t in range(max(cup_index[j]-win_size, cup_low),
                                                    min(cup_index[j]+win_size, cup_up)+1)
                                   if t != cup_index[j]]
                    cup_temp = window_list[(j // sent_num) % len(window_list)]
                else:
                    cand_list = [t for t in range(cup_low, max(cup_index[j]-win_size, cup_low))] + \
                                [t for t in range(min(cup_index[j]+win_size, cup_up), cup_up)]
                    cup_temp = cand_list[random.randint(0, len(cand_list)-1)]
                cup_index[j] = cup_temp
            negative_list.append((cup_index[sent_num*win_size*2:]-cup_low).
                                 view(negative_num*win_size*2, -1).transpose(0, 1))
            cup_list.append(cup_index)
            tgt_list.append(tgt_index)

        tgt = torch.cat(tgt_list, dim=0).float().to(self.device)
        cup_origin = torch.cat(cup_origin_list, dim=0)
        cup = torch.cat(cup_list, dim=0)
        negative_sample = torch.cat(negative_list, dim=0)

        return cup, cup_origin, tgt[cup != -1], negative_sample

    def _build_option_window(self, bsz, ex_segs, win_size=1, keep_ratio=0.1, replace_ratio=0.2):

        assert keep_ratio + replace_ratio <= 1.
        noise_ratio = 1 - keep_ratio - replace_ratio

        window_size = 2*win_size+1
        index = torch.split(torch.arange(1, bsz+1, dtype=torch.long, device=self.device), ex_segs)
        # 2 means noise addition, 1 means keep the memory, 0 means replacement
        tgt = torch.zeros([bsz, window_size], device=self.device, dtype=torch.int)
        prob = torch.empty([bsz, window_size], device=self.device).uniform_()
        tgt.masked_fill_(prob.lt(noise_ratio), 2)
        tgt.masked_fill_(prob.ge(1-keep_ratio), 1)
        tgt = torch.split(tgt, ex_segs)

        for i in range(len(ex_segs)):
            sent_num = ex_segs[i]
            index_pad = F.pad(index[i], (self.args.win_size, self.args.win_size))
            for j in range(sent_num):
                window = index_pad[j:j+window_size]
                # Avoiding that all elements are 0
                if torch.sum(tgt[i][j].byte()*(window > 0)) == 0:
                    tgt[i][j][win_size] = 2
                tgt[i][j][window == 0] = -1
        tgt = torch.cat(tgt, 0)
        return tgt

    def _fast_translate_batch(self, batch, memory_bank, max_length, init_tokens=None, memory_mask=None,
                              min_length=2, beam_size=3, ignore_mem_attn=False):

        batch_size = memory_bank.size(0)

        dec_states = self.decoder.init_decoder_state(batch.src, memory_bank, with_cache=True)

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        memory_bank = tile(memory_bank, beam_size, dim=0)
        init_tokens = tile(init_tokens, beam_size, dim=0)
        memory_mask = tile(memory_mask, beam_size, dim=0)

        batch_offset = torch.arange(
            batch_size, dtype=torch.long, device=self.device)
        beam_offset = torch.arange(
            0,
            batch_size * beam_size,
            step=beam_size,
            dtype=torch.long,
            device=self.device)

        alive_seq = torch.full(
            [batch_size * beam_size, 1],
            self.start_token,
            dtype=torch.long,
            device=self.device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                         device=self.device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = [[] for _ in range(batch_size)]  # noqa: F812

        for step in range(max_length):
            if step > 0:
                init_tokens = None
            # Decoder forward.
            decoder_input = alive_seq[:, -1].view(1, -1)
            decoder_input = decoder_input.transpose(0, 1)

            dec_out, dec_states, _ = self.decoder(decoder_input, memory_bank, dec_states, init_tokens, step=step,
                                                  memory_masks=memory_mask, ignore_memory_attn=ignore_mem_attn)

            # Generator forward.
            log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))

            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            if self.args.block_trigram:
                cur_len = alive_seq.size(1)
                if(cur_len > 3):
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        if(len(words) <= 3):
                            continue
                        trigrams = [(words[i-1], words[i], words[i+1]) for i in range(1, len(words)-1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            log_probs[i] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.args.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids // vocab_size  # integer beam index (avoids float division on newer PyTorch)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                    topk_beam_index
                    + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat(
                [alive_seq.index_select(0, select_indices),
                 topk_ids.view(-1, 1)], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition is top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((
                            topk_scores[i, j],
                            predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True)
                        _, pred = best_hyp[0]
                        results[b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            if memory_bank is not None:
                memory_bank = memory_bank.index_select(0, select_indices)
            if memory_mask is not None:
                memory_mask = memory_mask.index_select(0, select_indices)
            if init_tokens is not None:
                init_tokens = init_tokens.index_select(0, select_indices)

            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

        results = [t[0] for t in results]
        return results
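A standalone sketch of the trigram-blocking rule in _fast_translate_batch above: a beam is penalized (its log-probs set to -1e20) whenever the newest trigram of the hypothesis already appeared earlier in it:

def repeats_trigram(words):
    if len(words) <= 3:
        return False
    trigrams = [tuple(words[i - 1:i + 2]) for i in range(1, len(words) - 1)]
    return trigrams[-1] in trigrams[:-1]

print(repeats_trigram([1, 2, 3, 4, 1, 2, 3]))  # True: (1, 2, 3) occurs twice
print(repeats_trigram([1, 2, 3, 4, 5, 6]))     # False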

    def forward(self, batch):

        src = batch.src
        tgt = batch.tgt
        segs = batch.segs
        mask_src = batch.mask_src
        ex_segs = batch.ex_segs

        if self.training:
            # Sample some dialogue utterances to do auto-encoder
            ex_size = batch.src.size(0)
            ex_index = [i for i in range(ex_size)]
            random.shuffle(ex_index)
            ex_indexs = torch.tensor(ex_index, dtype=torch.long, device=self.device)
            ex_sample_indexs = ex_indexs[:max(int(ex_size * self.args.sample_ratio), 1)]

            # Get Context utterance training samples and targets
            cup_index, cup_original_index, cup_tgt, negative_samples = \
                self._build_cup(src.size(0), ex_segs, self.args.win_size, self.args.negative_sample_num)
            setattr(batch, 'cup_tgt', cup_tgt)

        option_mask = self._build_option_window(src.size(0), ex_segs, win_size=self.args.win_size,
                                                keep_ratio=self.args.ps if self.training else 1.,
                                                replace_ratio=self.args.pr if self.training else 0.)

        if self.training:
            # Build noised src
            noised_src, noised_src_mask, noised_src_segs = \
                self._build_noised_src(src, ex_segs, samples=negative_samples,
                                       expand_ratio=self.args.expand_ratio)
        # build context tgt
        context_tgt = self._build_context_tgt(tgt, ex_segs, self.args.win_size,
                                              modify=self.training, mask=option_mask)
        setattr(batch, 'context_tgt', context_tgt)

        # DAE: Randomly mask tokens
        if self.training:
            src = self._add_mask(src.clone(), mask_src)
            noised_src = self._add_mask(noised_src, noised_src_mask)

        if self.args.encoder == "bert":
            top_vec = self.encoder(src, segs, mask_src)
        else:
            src_emb = self.embeddings(src)
            top_vec = self.encoder(src_emb, 1-mask_src)
        clss = top_vec[:, 0, :]

        # Hierarchical encoder
        cls_list = torch.split(clss, ex_segs)
        cls_input = nn.utils.rnn.pad_sequence(cls_list, batch_first=True, padding_value=0.)
        cls_mask_list = [mask_src.new_zeros([length]) for length in ex_segs]
        cls_mask = nn.utils.rnn.pad_sequence(cls_mask_list, batch_first=True, padding_value=1)

        hier = self.hier_encoder(cls_input, cls_mask)
        hier = hier.view(-1, hier.size(-1))[(1-cls_mask.view(-1)).byte()]

        if self.training:

            # calculate cup score
            cup_tensor = torch.index_select(clss, 0, cup_index)
            origin_tensor = torch.index_select(clss, 0, cup_original_index)
            cup_score = torch.sigmoid(self.cup_bilinear(origin_tensor, cup_tensor)).squeeze()
            # cup_score = torch.sigmoid(origin_tensor.unsqueeze(1).bmm(cup_tensor.unsqueeze(-1)).squeeze())

            # noised src encode
            if self.args.encoder == "bert":
                noised_top_vec = self.encoder(noised_src, noised_src_segs, noised_src_mask)
            else:
                noised_src_emb = self.embeddings(noised_src)
                noised_top_vec = self.encoder(noised_src_emb, 1-noised_src_mask)
            noised_clss = noised_top_vec[:, 0, :]
            noised_cls_mem = self._build_memory_window(ex_segs, noised_clss, clss, option_mask, negative_samples)
            noised_cls_mem = self.pos_emb(noised_cls_mem)

            # sample training examples
            context_tgt_sample = torch.index_select(context_tgt, 0, ex_sample_indexs)
            noised_cls_mem_sample = torch.index_select(noised_cls_mem, 0, ex_sample_indexs)
            hier_sample = torch.index_select(hier, 0, ex_sample_indexs)
        else:
            cup_score = None

        if self.training:

            dec_state = self.decoder.init_decoder_state(noised_src, noised_cls_mem_sample)

            decode_context, _, _ = self.decoder(context_tgt_sample[:, :-1], noised_cls_mem_sample, dec_state,
                                                init_tokens=hier_sample)
            doc_data = None

            # For loss computation.
            if ex_sample_indexs is not None:
                batch.context_tgt = context_tgt_sample

        else:
            decode_context = None
            # Build paragraph tgt based on centrality rank.
            doc_tgt, doc_index, _ = self._build_doc_tgt(tgt, clss, ex_segs, self.args.win_size, self.args.ranking_max_k)
            centrality_segs = [len(iex) for iex in doc_index]
            centrality_index = [sum(centrality_segs[:i]) for i in range(len(centrality_segs)+1)]
            doc_index = torch.cat(doc_index, 0)
            setattr(batch, 'doc_tgt', doc_tgt)

            doc_hier_sample = torch.index_select(hier, 0, doc_index)

            # original cls mem
            cls_mem = self._build_memory_window(ex_segs, clss)
            cls_mem = self.pos_emb(cls_mem)
            doc_cls_mem = torch.index_select(cls_mem, 0, doc_index)

            # Context aware doc target
            context_doc_tgt = torch.index_select(context_tgt, 0, doc_index)
            setattr(batch, 'context_doc_tgt', context_doc_tgt)
            setattr(batch, 'doc_segs', centrality_index)

            doc_context_long = self._fast_translate_batch(batch, doc_cls_mem, self.max_length, init_tokens=doc_hier_sample,
                                                          min_length=2, beam_size=self.beam_size)
            doc_context_long = [torch.cat(doc_context_long[centrality_index[i]:centrality_index[i+1]], 0) for i in range(len(centrality_segs))]

            doc_data = doc_context_long

        return cup_score, decode_context, doc_data
Ejemplo n.º 20
0
class AbsSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.model_path, args.large, args.temp_dir, args.finetune_bert)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict(
                    [
                        (n[11:], p)
                        for n, p in bert_from_extractive.items()
                        if n.startswith("bert.model")
                    ]
                ),
                strict=True,
            )

        if args.encoder == "baseline":
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout,
            )
            self.bert.model = BertModel(bert_config)

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size
            )
            my_pos_embeddings.weight.data[
                :512
            ] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[
                512:
            ] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
                None, :
            ].repeat(
                args.max_pos - 512, 1
            )
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(
            self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
        )
        if self.args.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight
            )

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size,
            heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings,
        )

        self.generator = get_generator(
            self.vocab_size, self.args.dec_hidden_size, device
        )
        self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint["model"], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if args.use_bert_emb:
                tgt_embeddings = nn.Embedding(
                    self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
                )
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight
                )
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight

        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
        top_vec = self.bert(src, segs, mask_src)

        for i in range(1, top_vec.shape[1]):
            top_vec[0][i] = torch.zeros(top_vec.shape[2])

        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
        return decoder_outputs, None
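A sketch of the key remapping used when loading BERT weights from an extractive checkpoint in the examples above: keys such as 'bert.model.embeddings...' have the 11-character prefix 'bert.model.' stripped so they match BertModel's own state dict (the dict below is illustrative):

bert_from_extractive = {
    'bert.model.embeddings.word_embeddings.weight': 'W',
    'ext_layer.wo.weight': 'V',   # non-BERT key, dropped
}
remapped = dict(
    (n[11:], p) for n, p in bert_from_extractive.items() if n.startswith('bert.model')
)
print(remapped)  # {'embeddings.word_embeddings.weight': 'W'}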
Ejemplo n.º 21
0
class AbsSummarizer(nn.Module):
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(dict([
                (n[11:], p) for n, p in bert_from_extractive.items()
                if n.startswith('bert.model')
            ]),
                                            strict=True)

        if (args.encoder == 'baseline'):
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = BertModel(bert_config)
        self.graph_encoder = graph_encoder(args, self.bert.model.embeddings)

        if (args.max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:
                                          512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[
                512:] = self.bert.model.embeddings.position_embeddings.weight.data[
                    -1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if (self.args.share_emb):
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)
        #
        for name, param in self.decoder.named_parameters():
            if name == 'fix_top':
                xavier_uniform_(param)
        #
        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device,
                                       args.copy)
        self.generator.voc_gen[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.model.config.hidden_size,
                    padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator.voc_gen[
                    0].weight = self.decoder.embeddings.weight
        self.copy = args.copy
        self.to(device)

    def forward(self,
                src,
                tgt,
                segs,
                clss,
                mask_src,
                mask_tgt,
                mask_cls,
                batch=None):
        #
        gents, emask = self.graph_encoder(batch, self.bert.model.embeddings)
        #
        top_vec = self.bert(src, segs, mask_src)
        ent_top_vec = None
        if self.copy:
            ent_top_vec = self.bert(batch.ent_src, batch.ent_seg_ids,
                                    batch.mask_ent_src)
            # sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
            # sents_vec = sents_vec * mask_cls[:, :, None].float()

        dec_state = self.decoder.init_decoder_state(src, top_vec)
        decoder_outputs, state, src_context, graph_context = self.decoder(
            tgt[:, :-1], top_vec, dec_state, gents=gents, emask=emask)

        return decoder_outputs, None, src_context, graph_context, top_vec, ent_top_vec, emask