Example No. 1
class Summarizer(nn.Module):
    def __init__(self, device, args):
        super(Summarizer, self).__init__()
        self.device = device
        self.bert = Bert()
        self.encoder = TransformerInterEncoder(self.bert.model.config.hidden_size,
                                               args.ff_size, args.heads,
                                               args.dropout, args.inter_layers)

        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)

        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
        self.to(device)

    def load_cp(self, pt):
        self.load_state_dict(pt, strict=True)

    def forward(self, x, segs, clss, mask, mask_cls, sentence_range=None):
        # top_vec: (batch_size, seq_len, hidden) token-level BERT output
        top_vec = self.bert(x, segs, mask)
        # gather the vector at each [CLS] position to get one embedding per sentence
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls
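
A minimal usage sketch for this variant; the argument values, device choice, and tensor shapes below are illustrative assumptions, not part of the original example.

# Hypothetical usage: instantiate the model and score sentences.
import torch
from argparse import Namespace

args = Namespace(ff_size=2048, heads=8, dropout=0.1, inter_layers=2,
                 param_init=0.0, param_init_glorot=True)  # values are assumptions
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Summarizer(device, args)

# x: token ids, segs: segment ids, clss: [CLS] positions per sentence,
# mask: token mask, mask_cls: sentence mask (here: 2 docs, 512 tokens, 3 sentences)
x = torch.randint(0, 30000, (2, 512), device=device)
segs = torch.zeros_like(x)
mask = torch.ones_like(x)
clss = torch.tensor([[0, 120, 300], [0, 90, 250]], device=device)
mask_cls = torch.ones(2, 3, dtype=torch.bool, device=device)

sent_scores, _ = model(x, segs, clss, mask, mask_cls)  # (2, 3) sentence scores
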
Example No. 2
    def __init__(self, args, device, load_pretrained_bert=False, bert_config=None):
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(self.bert.model.config.hidden_size, args.ff_size, args.heads,
                                                   args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(bidirectional=True, num_layers=1,
                                      input_size=self.bert.model.config.hidden_size, hidden_size=args.rnn_size,
                                      dropout=args.dropout)
        elif (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.hidden_size,
                                     num_hidden_layers=6, num_attention_heads=8, intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)
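
The args.encoder flag selects the sentence-level encoder. A sketch of the fields such an args object would need; the names follow the constructor above, the values are assumptions.

# Illustrative argument set covering the encoder branches above.
from argparse import Namespace

args = Namespace(
    temp_dir='/tmp',          # cache directory handed to Bert()
    encoder='transformer',    # 'classifier' | 'transformer' | 'rnn' | 'baseline'
    ff_size=2048, heads=8, dropout=0.1, inter_layers=2,
    rnn_size=768,             # only read when encoder == 'rnn'
    hidden_size=128,          # only read when encoder == 'baseline'
    param_init=0.0, param_init_glorot=True,
)
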
Example No. 3
    def __init__(self,
                 args,
                 word_padding_idx,
                 vocab_size,
                 device,
                 checkpoint=None,
                 multigpu=False):
        super(Summarizer, self).__init__()
        self.multigpu = multigpu
        self.vocab_size = vocab_size
        self.device = device

        src_embeddings = torch.nn.Embedding(self.vocab_size,
                                            args.emb_size,
                                            padding_idx=word_padding_idx)
        if (args.structured):
            self.encoder = StructuredEncoder(args.hidden_size, args.ff_size,
                                             args.heads, args.dropout,
                                             src_embeddings, args.local_layers,
                                             args.inter_layers)
        else:
            self.encoder = TransformerInterEncoder(
                args.hidden_size, args.ff_size, args.heads, args.dropout,
                src_embeddings, args.local_layers, args.inter_layers)
        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for p in self.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)
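
A sketch of passing a checkpoint in; the constructor expects the state dict under checkpoint['model'], while the file name and hyperparameter values here are illustrative assumptions.

# Hypothetical: build the model and resume from a saved checkpoint.
import torch
from argparse import Namespace

args = Namespace(emb_size=512, hidden_size=512, ff_size=2048, heads=8,
                 dropout=0.1, local_layers=6, inter_layers=2, structured=False)
checkpoint = torch.load('model_step_50000.pt', map_location='cpu')  # path is illustrative
model = Summarizer(args, word_padding_idx=0, vocab_size=50000,
                   device='cpu', checkpoint=checkpoint)
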
Example No. 4
    def __init__(self, device, args):
        super(Summarizer, self).__init__()
        self.device = device
        self.bert = Bert()
        self.encoder = TransformerInterEncoder(self.bert.model.config.hidden_size,
                                               args.ff_size, args.heads,
                                               args.dropout, args.inter_layers)

        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)

        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
        self.to(device)
Example No. 5
    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None):
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        if (args.freeze_initial > 0):
            for param in self.bert.model.encoder.layer[
                    0:args.freeze_initial].parameters():
                param.requires_grad = False
            print("*" * 80)
            print("*" * 80)
            print("Initial Layers of BERT is frozen, ie first ",
                  args.freeze_initial, "Layers")
            print(self.bert.model.encoder.layer[0:args.freeze_initial])
            print("*" * 80)
            print("*" * 80)

        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'multi_layer_classifier'):
            self.encoder = MultiLayerClassifier(
                self.bert.model.config.hidden_size, 32)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=self.bert.model.config.hidden_size,
                hidden_size=args.rnn_size,
                dropout=args.dropout)
        elif (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)
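
With args.freeze_initial > 0 the first BERT layers stop receiving gradients. A quick sanity check after construction (a sketch; it assumes a model built as above, and the optimizer setup is not part of the original example).

# Count trainable vs. frozen parameters after freezing the first layers.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable:,}  frozen: {frozen:,}")

# Typically only the trainable parameters are handed to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=2e-3)  # lr is illustrative
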
Example No. 6
    def __init__(self, opt, lang):
        super(Summarizer, self).__init__()
        self.langfac = LangFactory(lang)
        if lang == 'jp':
            dirname = 'Japanese'
        elif lang == 'en':
            dirname = 'English'
        else:
            dirname = 'Others'
        temp_dir = os.path.join('/model', dirname)
        self.bert = Bert(self.langfac.toolkit.bert_model, temp_dir=temp_dir)
        self.encoder = TransformerInterEncoder(
            self.bert.model.config.hidden_size, opt['ff_size'], opt['heads'],
            opt['dropout'], opt['inter_layers'])
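
Here the hyperparameters arrive as a plain dict rather than an argparse namespace. An illustrative opt mapping; the keys follow the constructor above, the values are assumptions.

# Hypothetical construction for the English setting.
opt = {'ff_size': 2048, 'heads': 8, 'dropout': 0.1, 'inter_layers': 2}
model = Summarizer(opt, lang='en')  # resolves to the /model/English cache directory
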
Example No. 7
    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None):
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)

        self.transformer_encoder = TransformerInterEncoder(
            self.bert.model.config.hidden_size, args.ff_size, args.heads,
            args.dropout, args.inter_layers)
        if (args.model_name == "seq"):
            self.encoder = TransformerDecoderSeq(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers, args.use_doc)
        else:
            #if ('ctx' in args.model_name or 'base' in args.model_name):
            self.encoder = PairwiseMLP(self.bert.model.config.hidden_size,
                                       args)

        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
            for p in self.transformer_encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)

        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
            for p in self.transformer_encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)
Example No. 8
class Summarizer(nn.Module):
    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None):
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)

        self.transformer_encoder = TransformerInterEncoder(
            self.bert.model.config.hidden_size, args.ff_size, args.heads,
            args.dropout, args.inter_layers)
        if (args.model_name == "seq"):
            self.encoder = TransformerDecoderSeq(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers, args.use_doc)
        else:
            #if ('ctx' in args.model_name or 'base' in args.model_name):
            self.encoder = PairwiseMLP(self.bert.model.config.hidden_size,
                                       args)

        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
            for p in self.transformer_encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)

        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
            for p in self.transformer_encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)

    def load_cp(self, pt, strict=True):
        self.load_state_dict(pt['model'], strict=strict)

    def infer_sentences(self, batch, num_sent, stats=None):
        with torch.no_grad():
            src, labels, segs = batch.src, batch.labels, batch.segs
            clss, mask, mask_cls = batch.clss, batch.mask, batch.mask_cls
            #group_idxs, pair_masks = batch.test_groups, batch.test_pair_masks
            group_idxs = batch.groups
            #shouldn't use this hit_map and mask; these are for randomly selected indices
            #should compute a new hit_map

            sel_sent_idxs = torch.LongTensor(
                [[] for _ in range(batch.batch_size)]).to(labels.device)
            sel_sent_masks = torch.LongTensor(
                [[] for _ in range(batch.batch_size)]).to(labels.device)
            candi_masks = mask_cls.clone().detach()
            top_vec = self.bert(src, segs, mask)
            sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1),
                                clss]
            raw_sents_vec = sents_vec
            doc_emb, sents_vec = self.transformer_encoder(
                raw_sents_vec, mask_cls)

            ngram_segs = [int(x) for x in self.args.ngram_seg_count.split(',')]
            for sent_id in range(num_sent):
                hit_map = None  # initially None: no sentences have been selected yet
                if sent_id > 0 and self.args.model_name == 'ctx':
                    hit_map = du.get_hit_ngram(batch.src_str, sel_sent_idxs,
                                               sel_sent_masks, ngram_segs)
                sent_scores = self.encoder(doc_emb,
                                           sents_vec,
                                           sel_sent_idxs,
                                           sel_sent_masks,
                                           group_idxs,
                                           candi_masks,
                                           is_test=True,
                                           raw_sent_embs=raw_sents_vec,
                                           sel_sent_hit_map=hit_map)

                sent_scores[candi_masks == False] = float('-inf')
                # use -inf rather than a finite sentinel so masked sentences can never win the argmax
                sent_scores = sent_scores.cpu().data.numpy()
                #print(sent_scores)
                sorted_ids = np.argsort(-sent_scores, 1)
                #batch_size, sorted_sent_ids
                cur_selected_ids = torch.tensor(
                    sorted_ids[:, 0]).unsqueeze(-1).to(labels.device)
                cur_masks = torch.ones(batch.batch_size,
                                       1).long().to(labels.device)

                sel_sent_idxs = torch.cat([sel_sent_idxs, cur_selected_ids],
                                          dim=1)
                sel_sent_masks = torch.cat([sel_sent_masks, cur_masks], dim=1)
                du.set_selected_sent_to_value(candi_masks, sel_sent_idxs,
                                              sel_sent_masks, False)

            return sel_sent_idxs, sel_sent_masks

    def forward(self,
                x,
                mask,
                segs,
                clss,
                mask_cls,
                group_idxs,
                sel_sent_idxs=None,
                sel_sent_masks=None,
                candi_sent_masks=None,
                is_test=False,
                sel_sent_hit_map=None):
        top_vec = self.bert(x, segs, mask)
        #top_vec is batch_size, sequence_length, embedding_size
        #get the embedding of each CLS symbol in the batch
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        raw_sents_vec = sents_vec
        doc_emb, sents_vec = self.transformer_encoder(raw_sents_vec, mask_cls)
        sent_scores = self.encoder(doc_emb,
                                   sents_vec,
                                   sel_sent_idxs,
                                   sel_sent_masks,
                                   group_idxs,
                                   candi_sent_masks,
                                   is_test,
                                   raw_sent_embs=raw_sents_vec,
                                   sel_sent_hit_map=sel_sent_hit_map)
        #batch_size, max_sent_count
        return sent_scores, mask_cls
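
infer_sentences selects sentences greedily: at each step it re-scores the remaining candidates with the already selected indices as context, takes the argmax, and removes it from the candidate mask. A sketch of calling it at test time; the batch construction and num_sent value are assumptions.

# Hypothetical inference call: pick 3 sentences per document.
# `batch` is assumed to carry the fields read inside infer_sentences
# (src, labels, segs, clss, mask, mask_cls, groups, src_str, batch_size).
model.eval()
sel_idxs, sel_masks = model.infer_sentences(batch, num_sent=3)
# sel_idxs: (batch_size, 3) chosen sentence indices, in selection order.
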
Example No. 9
    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None,
                 topic_num=10):
        super(Summarizer, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(
            'bert-base-uncased',
            do_lower_case=True,
            never_split=('[SEP]', '[CLS]', '[PAD]', '[unused0]', '[unused1]',
                         '[unused2]', '[UNK]'),
            no_word_piece=True)
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        self.memory = Memory(device, 1, self.bert.model.config.hidden_size)
        self.key_memory = Key_memory(device, 1,
                                     self.bert.model.config.hidden_size,
                                     args.dropout)
        self.topic_predictor = Topic_predictor(
            self.bert.model.config.hidden_size,
            device,
            topic_num,
            d_ex_type=args.d_ex_type)
        # self.topic_embedding = nn.Embedding(topic_num, self.bert.model.config.hidden_size)
        # todo transform to normal weight not embedding
        self.topic_embedding, self.topic_word, self.topic_word_emb = self.get_embedding(
            self.bert.model.embeddings)
        self.topic_embedding.requires_grad = True
        self.topic_word_emb.requires_grad = True
        self.topic_embedding = self.topic_embedding.to(device)
        self.topic_word_emb = self.topic_word_emb.to(device)
        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=self.bert.model.config.hidden_size,
                hidden_size=args.rnn_size,
                dropout=args.dropout)
        elif (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)