Пример #1
0
    def __init__(self, args, device, checkpoint):
        """Build an extractive summarizer: a BERT encoder topped by a
        sentence-scoring layer, optionally restored from a checkpoint.

        Args:
            args: hyper-parameter namespace (encoder choice, ext_* sizes,
                max_pos, param_init settings, ...).
            device: target device the whole module is moved to at the end.
            checkpoint: dict holding a 'model' state dict, or None to
                freshly initialize the scoring layer.
        """
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        # Default scoring head: inter-sentence Transformer over sentence vectors.
        self.ext_layer = ExtTransformerEncoder(self.bert.model.config.hidden_size, args.ext_ff_size, args.ext_heads,
                                               args.ext_dropout, args.ext_layers)
        # 'baseline' swaps the pretrained encoder for a small randomly
        # initialized BertModel with a plain linear classifier head.
        if (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.ext_hidden_size,
                                     num_hidden_layers=args.ext_layers, num_attention_heads=args.ext_heads, intermediate_size=args.ext_ff_size)
            self.bert.model = BertModel(bert_config)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        # BERT ships only 512 learned position embeddings; for longer inputs
        # extend the table by repeating the last embedding row.
        if(args.max_pos>512):
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings


        if checkpoint is not None:
            # Restore all weights; strict=True requires exact name match.
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            # Fresh run: initialize only the scoring layer's parameters.
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        self.to(device)
class ExtSummarizer(nn.Module):
    """Extractive summarizer: a (possibly position-extended) BERT encoder
    whose per-sentence vectors are scored by an inter-sentence layer."""

    def __init__(self, args, device, checkpoint):
        """Build the model, restore *checkpoint* if given (otherwise
        initialize the scoring layer), then move everything to *device*."""
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.model_path, args.large, args.temp_dir, args.finetune_bert)

        # Default scoring head: Transformer over sentence vectors.
        self.ext_layer = ExtTransformerEncoder(
            self.bert.model.config.hidden_size,
            args.ext_ff_size,
            args.ext_heads,
            args.ext_dropout,
            args.ext_layers,
        )
        # 'baseline': small randomly initialized BERT + linear classifier.
        if args.encoder == "baseline":
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.ext_hidden_size,
                num_hidden_layers=args.ext_layers,
                num_attention_heads=args.ext_heads,
                intermediate_size=args.ext_ff_size,
            )
            self.bert.model = BertModel(bert_config)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        # BERT only has 512 learned position embeddings; extend the table
        # for longer inputs by repeating the last embedding row.
        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size
            )
            my_pos_embeddings.weight.data[
                :512
            ] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[
                512:
            ] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
                None, :
            ].repeat(
                args.max_pos - 512, 1
            )
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        if checkpoint is not None:
            # Restore every weight; strict=True requires exact name match.
            self.load_state_dict(checkpoint["model"], strict=True)
        else:
            # Fresh run: initialize only the scoring layer.
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        self.to(device)

    def forward(self, src, segs, clss, mask_src, mask_cls):
        """Score every sentence of the batch.

        Args:
            src: token ids. segs: segment ids. clss: position of each
                sentence's leading token in *src*. mask_src / mask_cls:
                padding masks over tokens / sentences.

        Returns:
            Tuple ``(sent_scores, mask_cls)`` with one score per sentence.
        """
        top_vec = self.bert(src, segs, mask_src)
        # Gather the vector at each sentence's clss position; zero padded ones.
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.ext_layer(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls
Пример #3
0
class Summarizer(nn.Module):
    """Extractive summarizer with a selectable sentence encoder
    ('classifier' / 'transformer' / 'rnn' / 'baseline') on top of BERT."""

    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None):
        """Build BERT plus the encoder named by ``args.encoder``,
        initialize the encoder's weights, and move to *device*."""
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        if args.encoder == "classifier":
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif args.encoder == "transformer":
            self.encoder = TransformerInterEncoder(
                self.bert.model.config.hidden_size,
                args.ff_size,
                args.heads,
                args.dropout,
                args.inter_layers,
            )
        elif args.encoder == "rnn":
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=self.bert.model.config.hidden_size,
                hidden_size=args.rnn_size,
                dropout=args.dropout,
            )
        elif args.encoder == "baseline":
            # Replace pretrained BERT with a small randomly initialized model.
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.hidden_size,
                num_hidden_layers=6,
                num_attention_heads=8,
                intermediate_size=args.ff_size,
            )
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Initialize only the encoder head (BERT keeps its own weights).
        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)

    def load_cp(self, pt):
        """Load a checkpoint dict's 'model' state dict (strict match)."""
        self.load_state_dict(pt["model"], strict=True)

    def forward(self, x, segs, clss, mask, mask_cls, sentence_range=None):
        """Score each sentence; returns ``(sent_scores, mask_cls)``."""
        top_vec = self.bert(x, segs, mask)
        # Gather each sentence's vector at its clss position; zero padding.
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)

        return sent_scores, mask_cls
    def __init__(self, args, device, load_pretrained_bert=False, bert_config=None):
        """Assemble a BERT-based extractive summarizer.

        Selects the sentence encoder named by ``args.encoder``
        ('classifier' / 'transformer' / 'rnn' / 'baseline'), initializes
        the encoder's weights, and moves the module to *device*.
        """
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)

        hidden = self.bert.model.config.hidden_size
        encoder_kind = args.encoder
        if encoder_kind == 'classifier':
            self.encoder = Classifier(hidden)
        elif encoder_kind == 'transformer':
            self.encoder = TransformerInterEncoder(
                hidden, args.ff_size, args.heads, args.dropout, args.inter_layers)
        elif encoder_kind == 'rnn':
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=hidden,
                hidden_size=args.rnn_size,
                dropout=args.dropout)
        elif encoder_kind == 'baseline':
            # Swap in a small, randomly initialized BERT plus a linear head.
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.hidden_size,
                num_hidden_layers=6,
                num_attention_heads=8,
                intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Only the encoder head gets (re-)initialized here.
        if args.param_init != 0.0:
            bound = args.param_init
            for weight in self.encoder.parameters():
                weight.data.uniform_(-bound, bound)
        if args.param_init_glorot:
            for weight in self.encoder.parameters():
                if weight.dim() > 1:
                    xavier_uniform_(weight)

        self.to(device)
Пример #5
0
    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None):
        """Build BERT plus the encoder named by ``args.encoder``; can also
        freeze the first ``args.freeze_initial`` BERT layers."""
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        # Optionally freeze the first N BERT layers (no gradient updates).
        if (args.freeze_initial > 0):
            for param in self.bert.model.encoder.layer[
                    0:args.freeze_initial].parameters():
                param.requires_grad = False
            print("*" * 80)
            print("*" * 80)
            print("Initial Layers of BERT is frozen, ie first ",
                  args.freeze_initial, "Layers")
            print(self.bert.model.encoder.layer[0:args.freeze_initial])
            print("*" * 80)
            print("*" * 80)

        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'multi_layer_classifier'):
            self.encoder = MultiLayerClassifier(
                self.bert.model.config.hidden_size, 32)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=self.bert.model.config.hidden_size,
                hidden_size=args.rnn_size,
                dropout=args.dropout)
        elif (args.encoder == 'baseline'):
            # Small randomly initialized BERT instead of the pretrained one.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Initialize only the encoder head.
        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)
Пример #6
0
class ExtSummarizer(nn.Module):
    """Extractive summarizer: BERT encoder plus a sentence-scoring layer."""

    def __init__(self, args, device, checkpoint):
        """Construct the model, restore *checkpoint* if given (otherwise
        initialize the scoring layer), and move it to *device*."""
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        self.ext_layer = ExtTransformerEncoder(
            self.bert.model.config.hidden_size,
            args.ext_ff_size,
            args.ext_heads,
            args.ext_dropout,
            args.ext_layers)
        if args.encoder == 'baseline':
            # Replace pretrained BERT with a small random model + linear head.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            # Fresh run: initialize only the scoring layer.
            if args.param_init != 0.0:
                bound = args.param_init
                for weight in self.ext_layer.parameters():
                    weight.data.uniform_(-bound, bound)
            if args.param_init_glorot:
                for weight in self.ext_layer.parameters():
                    if weight.dim() > 1:
                        xavier_uniform_(weight)

        self.to(device)

    def forward(self, src, segs, clss, mask_src, mask_cls):
        """Score each sentence; returns ``(sent_scores, mask_cls)``."""
        top_vec = self.bert(src, segs, mask_src)
        # Pick each sentence's vector at its clss position, zero out padding.
        batch_index = torch.arange(top_vec.size(0)).unsqueeze(1)
        sents_vec = top_vec[batch_index, clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.ext_layer(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls
Пример #7
0
 def __init__(self,
              args,
              device,
              load_pretrained_bert=False,
              bert_config=None):
     """Build BERT plus a 'classifier' or 'dnn' sentence encoder and move
     the module to *device*."""
     super(Summarizer, self).__init__()
     self.device = device
     self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
     # NOTE(review): these are independent `if`s, not `elif` — if
     # args.encoder is neither 'classifier' nor 'dnn', self.encoder is
     # never assigned and later use would raise AttributeError. Confirm
     # callers only pass these two values.
     if args.encoder == "classifier":
         self.encoder = Classifier(self.bert.model.config.hidden_size)
     if args.encoder == "dnn":
         self.encoder = DNNEncoder(self.bert.model.config.hidden_size,
                                   args.num_units, args.num_layers)
     self.to(device)
Пример #8
0
def get_param_stamp_from_args(args):
    '''To get param-stamp a bit quicker.

    Reconstructs the label stream, dataset configuration and classifier
    name implied by *args* (without building actual datasets or weights)
    and returns the stamp produced by ``get_param_stamp``.

    Raises:
        ValueError: for 'splitMNIST' with more than 10 tasks.
        NotImplementedError: for an unknown ``args.stream`` value.
    '''
    if args.experiment == "splitMNIST" and args.tasks > 10:
        raise ValueError(
            "Experiment 'splitMNIST' cannot have more than 10 tasks!")
    # permMNIST uses all 10 digits per task; splitMNIST divides them up.
    classes_per_task = 10 if args.experiment == "permMNIST" else int(
        np.floor(10 / args.tasks))
    if args.stream == "task-based":
        # Direct boolean instead of `True if ... else False` with `not ==`.
        labels_per_batch = (args.scenario != "class"
                            or classes_per_task == 1)
        label_stream = TaskBasedStream(
            n_tasks=args.tasks,
            iters_per_task=(args.iters if labels_per_batch
                            else args.iters * args.batch),
            labels_per_task=(classes_per_task
                             if args.scenario == "class" else 1))
    elif args.stream == "random":
        label_stream = RandomStream(
            labels=(args.tasks * classes_per_task
                    if args.scenario == "class" else args.tasks))
    else:
        raise NotImplementedError(
            "Stream type '{}' not currently implemented.".format(args.stream))
    # Only the dataset *configuration* is needed for the stamp, not the data.
    config = prepare_datasets(
        name=args.experiment,
        n_labels=label_stream.n_labels,
        classes=(args.scenario == "class"),
        classes_per_task=classes_per_task,
        dir=args.d_dir,
        exception=(args.seed == 1),
        only_config=True,
    )
    # Size of the softmax layer depends on the continual-learning scenario.
    softmax_classes = label_stream.n_labels if args.scenario == "class" else (
        classes_per_task if
        (args.scenario == "domain" or args.singlehead) else classes_per_task *
        label_stream.n_labels)
    # Only the model's *name* matters for the stamp.
    model_name = Classifier(
        image_size=config['size'],
        image_channels=config['channels'],
        classes=softmax_classes,
        fc_layers=args.fc_lay,
        fc_units=args.fc_units,
    ).name
    param_stamp = get_param_stamp(args, model_name, verbose=False)
    return param_stamp
Пример #9
0
class Summarizer(nn.Module):
    """Extractive summarizer over BERT with a selectable sentence encoder
    and optional freezing of the first BERT layers."""

    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None):
        """Build BERT plus the encoder named by ``args.encoder``; freeze
        the first ``args.freeze_initial`` BERT layers if requested."""
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        # Optionally freeze the first N BERT layers (no gradient updates).
        if (args.freeze_initial > 0):
            for param in self.bert.model.encoder.layer[
                    0:args.freeze_initial].parameters():
                param.requires_grad = False
            print("*" * 80)
            print("*" * 80)
            print("Initial Layers of BERT is frozen, ie first ",
                  args.freeze_initial, "Layers")
            print(self.bert.model.encoder.layer[0:args.freeze_initial])
            print("*" * 80)
            print("*" * 80)

        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'multi_layer_classifier'):
            self.encoder = MultiLayerClassifier(
                self.bert.model.config.hidden_size, 32)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=self.bert.model.config.hidden_size,
                hidden_size=args.rnn_size,
                dropout=args.dropout)
        elif (args.encoder == 'baseline'):
            # Small randomly initialized BERT instead of the pretrained one.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Initialize only the encoder head.
        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)

    def load_cp(self, pt):
        """Load a checkpoint dict's 'model' state dict (strict match)."""
        self.load_state_dict(pt['model'], strict=True)

    def forward(self, x, segs, clss, mask, mask_cls, sentence_range=None):
        """Score each sentence; returns ``(sent_scores, mask_cls)``.

        Note that BERT is queried with ``self.args.out_layer``, i.e. a
        configurable output layer of the encoder.
        """
        top_vec = self.bert(x, segs, mask, self.args.out_layer)
        # Gather each sentence's vector at its clss position; zero padding.
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls
Пример #10
0
    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None,
                 topic_num=10):
        """Build a topic-aware summarizer: BERT, memory modules, a topic
        predictor and topic embeddings, plus the sentence encoder named by
        ``args.encoder``.

        Args:
            topic_num: number of topics handed to the topic predictor.
        """
        super(Summarizer, self).__init__()
        # Tokenizer kept on the model for later use; special tokens are
        # protected from splitting.
        self.tokenizer = BertTokenizer.from_pretrained(
            'bert-base-uncased',
            do_lower_case=True,
            never_split=('[SEP]', '[CLS]', '[PAD]', '[unused0]', '[unused1]',
                         '[unused2]', '[UNK]'),
            no_word_piece=True)
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        # NOTE(review): Memory / Key_memory / Topic_predictor are project
        # modules; their exact semantics are not visible in this file.
        self.memory = Memory(device, 1, self.bert.model.config.hidden_size)
        self.key_memory = Key_memory(device, 1,
                                     self.bert.model.config.hidden_size,
                                     args.dropout)
        self.topic_predictor = Topic_predictor(
            self.bert.model.config.hidden_size,
            device,
            topic_num,
            d_ex_type=args.d_ex_type)
        # todo transform to normal weight not embedding
        # Topic embeddings are derived from BERT's embedding table and made
        # trainable, then moved to the target device.
        self.topic_embedding, self.topic_word, self.topic_word_emb = self.get_embedding(
            self.bert.model.embeddings)
        self.topic_embedding.requires_grad = True
        self.topic_word_emb.requires_grad = True
        self.topic_embedding = self.topic_embedding.to(device)
        self.topic_word_emb = self.topic_word_emb.to(device)
        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=self.bert.model.config.hidden_size,
                hidden_size=args.rnn_size,
                dropout=args.dropout)
        elif (args.encoder == 'baseline'):
            # Small randomly initialized BERT instead of the pretrained one.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Initialize only the encoder head.
        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)
class Summarizer(nn.Module):
    """Extractive summarizer over BERT with a selectable sentence encoder
    ('classifier' / 'transformer' / 'rnn' / 'classifierDummy' / 'gnn' /
    'baseline')."""

    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None):
        """Build BERT plus the encoder named by ``args.encoder``,
        initialize the encoder's weights, and move to *device*."""
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=self.bert.model.config.hidden_size,
                hidden_size=args.rnn_size,
                dropout=args.dropout)
        elif args.encoder == 'classifierDummy':
            self.encoder = ClassifierDummy(self.bert.model.config.hidden_size)

        elif args.encoder == 'gnn':
            self.encoder = Gnn(self.bert.model.config.hidden_size)

        elif (args.encoder == 'baseline'):
            # Small randomly initialized BERT instead of the pretrained one.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Initialize only the encoder head.
        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)

    def load_cp(self, pt):
        """Load a checkpoint dict's 'model' state dict (strict match)."""
        self.load_state_dict(pt['model'], strict=True)

    def forward(self, x, segs, clss, mask, mask_cls, sentence_range=None):
        """Score each sentence; returns ``(sent_scores, mask_cls)``."""
        top_vec = self.bert(x, segs, mask)
        # Gather each sentence's vector at its clss position; zero padding.
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls
Пример #12
0
    def __init__(self, args, device, checkpoint, lamb=0.8):
        """Build an extractive summarizer with extra learnable scoring
        matrices (content / similarity / relevance / novelty, per the
        parameter names) on top of a BERT encoder.

        Args:
            lamb: damping factor used by the score-propagation step
                (defined elsewhere in the class).
        """
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.lamb = lamb
        # bert
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        # Extraction layer.
        self.ext_layer = ExtTransformerEncoder(self.bert.model.config.hidden_size, args.ext_ff_size, args.ext_heads,
                                               args.ext_dropout, args.ext_layers)
        # 'baseline': small randomly initialized BERT + linear classifier.
        if (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.ext_hidden_size,
                                     num_hidden_layers=args.ext_layers, num_attention_heads=args.ext_heads, intermediate_size=args.ext_ff_size)
            self.bert.model = BertModel(bert_config)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        # Extend BERT's 512 position embeddings for longer inputs by
        # repeating the last embedding row.
        if(args.max_pos>512):
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        # initial the parameter for infor\rel\novel.
        self.W_cont = nn.Parameter(torch.Tensor(1 ,self.bert.model.config.hidden_size))
        self.W_sim = nn.Parameter(torch.Tensor(self.bert.model.config.hidden_size, self.bert.model.config.hidden_size))
        self.Sim_layer= nn.Linear(self.bert.model.config.hidden_size,self.bert.model.config.hidden_size)
        self.W_rel = nn.Parameter(torch.Tensor(self.bert.model.config.hidden_size, self.bert.model.config.hidden_size))
        self.Rel_layer= nn.Linear(self.bert.model.config.hidden_size,self.bert.model.config.hidden_size)
        self.W_novel = nn.Parameter(torch.Tensor(self.bert.model.config.hidden_size, self.bert.model.config.hidden_size))
        self.b_matrix = nn.Parameter(torch.Tensor(1, 1))

        # Scalar biases and a query transform used by the scoring step.
        self.q_transform = nn.Linear(100, 1)
        self.bq = nn.Parameter(torch.Tensor(1, 1))
        self.brel = nn.Parameter(torch.Tensor(1, 1))
        self.bsim = nn.Parameter(torch.Tensor(1, 1))
        self.bcont = nn.Parameter(torch.Tensor(1, 1))

        if checkpoint is not None:
            # Restore every weight; strict=True requires exact name match.
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint loaded! ")
        else:
            # Fresh run: initialize scoring layers and extra parameters.
            # (nn.Parameter tensors above are uninitialized until here.)
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)
                for p in self.Rel_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)
                for p in self.Sim_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)
            nn.init.xavier_uniform_(self.bq)
            nn.init.xavier_uniform_(self.W_cont)
            nn.init.xavier_uniform_(self.W_sim)
            nn.init.xavier_uniform_(self.W_rel)
            nn.init.xavier_uniform_(self.W_novel)
            nn.init.xavier_uniform_(self.b_matrix)
            nn.init.xavier_uniform_(self.bcont)
            nn.init.xavier_uniform_(self.brel)
            nn.init.xavier_uniform_(self.bsim)
        self.to(device)
Пример #13
0
class ExtSummarizer(nn.Module):
    """Extractive summarizer that scores sentences with learnable
    content / similarity / relevance / novelty terms followed by an
    iterative score-propagation step (``cal_matrix0``)."""

    def __init__(self, args, device, checkpoint, lamb=0.8):
        """Build the model; *lamb* is the damping factor used by the
        score propagation in ``cal_matrix0``."""
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.lamb = lamb
        # bert
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        # Extraction layer.
        self.ext_layer = ExtTransformerEncoder(self.bert.model.config.hidden_size, args.ext_ff_size, args.ext_heads,
                                               args.ext_dropout, args.ext_layers)
        # 'baseline': small randomly initialized BERT + linear classifier.
        if (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.ext_hidden_size,
                                     num_hidden_layers=args.ext_layers, num_attention_heads=args.ext_heads, intermediate_size=args.ext_ff_size)
            self.bert.model = BertModel(bert_config)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        # Extend BERT's 512 position embeddings for longer inputs by
        # repeating the last embedding row.
        if(args.max_pos>512):
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        # initial the parameter for infor\rel\novel.
        self.W_cont = nn.Parameter(torch.Tensor(1 ,self.bert.model.config.hidden_size))
        self.W_sim = nn.Parameter(torch.Tensor(self.bert.model.config.hidden_size, self.bert.model.config.hidden_size))
        self.Sim_layer= nn.Linear(self.bert.model.config.hidden_size,self.bert.model.config.hidden_size)
        self.W_rel = nn.Parameter(torch.Tensor(self.bert.model.config.hidden_size, self.bert.model.config.hidden_size))
        self.Rel_layer= nn.Linear(self.bert.model.config.hidden_size,self.bert.model.config.hidden_size)
        self.W_novel = nn.Parameter(torch.Tensor(self.bert.model.config.hidden_size, self.bert.model.config.hidden_size))
        self.b_matrix = nn.Parameter(torch.Tensor(1, 1))

        # Scalar biases and a query transform used by the scoring step.
        self.q_transform = nn.Linear(100, 1)
        self.bq = nn.Parameter(torch.Tensor(1, 1))
        self.brel = nn.Parameter(torch.Tensor(1, 1))
        self.bsim = nn.Parameter(torch.Tensor(1, 1))
        self.bcont = nn.Parameter(torch.Tensor(1, 1))

        if checkpoint is not None:
            # Restore every weight; strict=True requires exact name match.
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint loaded! ")
        else:
            # Fresh run: initialize scoring layers and the raw nn.Parameter
            # tensors created above (uninitialized until this point).
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)
                for p in self.Rel_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)
                for p in self.Sim_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)
            nn.init.xavier_uniform_(self.bq)
            nn.init.xavier_uniform_(self.W_cont)
            nn.init.xavier_uniform_(self.W_sim)
            nn.init.xavier_uniform_(self.W_rel)
            nn.init.xavier_uniform_(self.W_novel)
            nn.init.xavier_uniform_(self.b_matrix)
            nn.init.xavier_uniform_(self.bcont)
            nn.init.xavier_uniform_(self.brel)
            nn.init.xavier_uniform_(self.bsim)
        self.to(device)

    def cal_matrix0(self, sent_vec, mask_cls):
        """Combine content/similarity/relevance/novelty scores per batch
        item, then propagate them through a damped matrix-inversion step
        (resembles a LexRank/PageRank-style propagation — TODO confirm).

        Returns a (batch, max_sent_num) tensor of final sentence scores.

        NOTE(review): nn.functional.tanh / nn.functional.sigmoid below are
        deprecated aliases of torch.tanh / torch.sigmoid.
        """
        mask_cls = mask_cls.unsqueeze(1).float()
        # Pairwise validity mask: (i, j) is 1 iff both sentences are real.
        mask_my_own = torch.bmm(mask_cls.transpose(1, 2), mask_cls)
        sent_num = mask_cls.sum(dim=2).squeeze(1)
        # Document representation: mean over sentence vectors.
        d_rep = sent_vec.mean(dim=1).unsqueeze(1).transpose(1, 2)
        score_gather = torch.zeros(1, sent_vec.size(1)).to(self.device)
        #  for each of bach.
        for i in range(sent_vec.size(0)):
            # Content score: W_cont · s_j for each sentence j.
            Score_Cont = torch.mm(self.W_cont, sent_vec[i].transpose(0, 1))

            # Pairwise similarity score: S W_sim Sᵀ, masked to real pairs.
            tmp_Sim = torch.mm(sent_vec[i], self.W_sim)
            Score_Sim = torch.mm(tmp_Sim, sent_vec[i].transpose(0, 1)) * mask_my_own[i]

            # Relevance to the document representation.
            tmp_rel = torch.mm(sent_vec[i], self.W_rel)
            Score_rel = torch.mm(tmp_rel, d_rep[i]).transpose(0, 1)

            q = Score_rel + Score_Cont + Score_Sim + self.b_matrix
            q = q * mask_my_own[i]
            # Sequential novelty term: each row is penalized/boosted by an
            # accumulation over the previously processed sentences.
            tmp_nov = torch.mm(sent_vec[i][0].unsqueeze(0), self.W_novel)

            accumulation = torch.mm(tmp_nov, nn.functional.tanh(
                ((q[0].sum() / sent_num[i]) * sent_vec[i][0]).unsqueeze(0).transpose(0, 1)))
            for j, each_row in enumerate(q):
                if j == 0:
                    continue
                q[j] = (q[j] + accumulation) * mask_cls[i]
                tmp_nov = torch.mm(sent_vec[i][j].unsqueeze(0), self.W_novel)
                accumulation += torch.mm(tmp_nov, nn.functional.tanh(
                    ((q[j].sum() / sent_num[i]) * sent_vec[i][j]).unsqueeze(0).transpose(0, 1)))
            q = nn.functional.sigmoid(q) * mask_my_own[i]

            # Damped propagation restricted to the true (unpadded) block:
            # (1 - λ) (I - λ Q D⁻¹)⁻¹ y, with y uniform over sentences.
            sum_vec = q.sum(dim=0)
            D = torch.diag_embed(sum_vec)
            true_dim = int(sent_num[i])
            tmp_D = D[:true_dim, :true_dim]
            tmp_q = q[:true_dim, :true_dim]
            D_ = torch.inverse(tmp_D)
            I = torch.eye(true_dim).to(self.device)
            y = torch.ones(true_dim, 1).to(self.device) * (1.0 / true_dim)
            Final_score = torch.mm((1 - self.lamb) * torch.inverse(I - self.lamb * torch.mm(tmp_q, D_)), y).transpose(0,1)
            # Pad scores back to the full sentence dimension with zeros.
            len_ = D.size(0) - true_dim
            tmp_zeros = torch.zeros(1, len_).to(self.device)
            Final_score = torch.cat((Final_score, tmp_zeros), dim=1)

            if i == 0:
                score_gather += Final_score
            else:
                score_gather = torch.cat((score_gather, Final_score), 0)

        return score_gather

    def forward(self, src, segs, clss, mask_src, mask_cls):
        """Score each sentence; for the 'hybrid' task the sentence vectors
        are also returned for a downstream abstractor."""
        # first bert layer get the top_vector, first dim is batch_size.
        # [batch * max_length]
        top_vec = self.bert(src, segs, mask_src)
        # get the vector of sentences.
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        # [batchsize * sentencenum * dim]

        sents_vec = self.ext_layer(sents_vec, mask_cls).squeeze(-1)
        sent_scores = self.cal_matrix0(sents_vec, mask_cls)
        # get the score of sentences, for abstractor.
        # [batchsize * sentencenum]
        if self.args.task == "ext":
            return sent_scores, mask_cls
        elif self.args.task == "hybrid":
            return sent_scores, mask_cls, sents_vec
class ExtSummarizer(nn.Module):
    """Extractive summarizer: a (pretrained) BERT encoder followed by a
    sentence-level scoring layer.

    Args:
        args: hyper-parameter namespace (encoder, ext_* sizes, max_pos,
            param_init, param_init_glorot, ...).
        device: torch device the model is moved to.
        checkpoint: optional dict whose ``'model'`` entry is a state-dict
            to restore; when ``None`` the scoring layer is freshly initialized.
    """

    def __init__(self, args, device, checkpoint):
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        self.ext_layer = ExtTransformerEncoder(
            self.bert.model.config.hidden_size, args.ext_ff_size,
            args.ext_heads, args.ext_dropout, args.ext_layers)

        if args.encoder == 'baseline':
            # Non-pretrained baseline: replace BERT with a small randomly
            # initialized transformer and score sentences with a linear layer.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.ext_hidden_size,
                                     num_hidden_layers=args.ext_layers,
                                     num_attention_heads=args.ext_heads,
                                     intermediate_size=args.ext_ff_size)
            self.bert.model = BertModel(bert_config)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        if args.max_pos > 512:
            # BERT only ships 512 position embeddings; extend the table by
            # copying the pretrained rows and repeating the last row for
            # every position beyond 512.
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :] \
                .repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            # Fresh run: initialize only the scoring layer; BERT keeps its
            # pretrained weights.
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        self.to(device)

    def forward(self, src, segs, clss, mask_src, mask_cls):
        """Score each sentence of each document in the batch.

        Returns:
            (sent_scores, mask_cls): per-sentence scores of shape
            [batch, n_sents] and the sentence mask.
        """
        top_vec = self.bert(src, segs, mask_src)
        # Gather the encoder output at each sentence's <CLS> position and
        # zero out the padded sentence slots.
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.ext_layer(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls
Пример #15
0
def run(args, verbose=False):
    """Run one continual-learning experiment end to end.

    Prepares the label/data stream, builds the classifier and optimizer,
    trains, evaluates, writes results to ``args.r_dir`` and, if requested,
    generates plots in ``args.p_dir``.

    Args:
        args: parsed command-line namespace (experiment, scenario, stream,
            seed, directories, optimizer settings, ...).
        verbose: if True, print progress information to the screen.
    """

    # Create plots- and results-directories, if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    # If only want param-stamp, get it printed to screen and exit
    if utils.checkattr(args, "get_stamp"):
        print(get_param_stamp_from_args(args=args))
        exit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # Set random seeds (for reproducibility across runs with the same seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)


    #-------------------------------------------------------------------------------------------------#

    #-----------------------#
    #----- DATA-STREAM -----#
    #-----------------------#

    # Find number of classes per task
    if args.experiment == "splitMNIST" and args.tasks > 10:
        raise ValueError("Experiment 'splitMNIST' cannot have more than 10 tasks!")
    classes_per_task = 10 if args.experiment == "permMNIST" else int(np.floor(10 / args.tasks))

    # Print information on data-stream to screen
    if verbose:
        print("\nPreparing the data-stream...")
        print(" --> {}-incremental learning".format(args.scenario))
        ti = "{} classes".format(args.tasks * classes_per_task) if args.stream == "random" and args.scenario == "class" else (
            "{} tasks, with {} classes each".format(args.tasks, classes_per_task)
        )
        print(" --> {} data stream: {}\n".format(args.stream, ti))

    # Set up the stream of labels (i.e., classes, domain or tasks) to use
    if args.stream == "task-based":
        labels_per_batch = True if ((not args.scenario == "class") or classes_per_task == 1) else False
        # -in Task- & Domain-IL scenario, each label is always for entire batch
        # -in Class-IL scenario, each label is always just for single observation
        #    (but if there is just 1 class per task, setting `label-per_batch` to ``True`` is more efficient)
        label_stream = TaskBasedStream(
            n_tasks=args.tasks, iters_per_task=args.iters if labels_per_batch else args.iters * args.batch,
            labels_per_task=classes_per_task if args.scenario == "class" else 1
        )
    elif args.stream == "random":
        label_stream = RandomStream(labels=args.tasks * classes_per_task if args.scenario == "class" else args.tasks)
    else:
        raise NotImplementedError("Stream type '{}' not currently implemented.".format(args.stream))

    # Load the data-sets
    (train_datasets, test_datasets), config, labels_per_task = prepare_datasets(
        name=args.experiment, n_labels=label_stream.n_labels, classes=(args.scenario == "class"),
        classes_per_task=classes_per_task, dir=args.d_dir, exception=(args.seed < 10)
    )

    # Set up the data-stream to be presented to the network
    data_stream = DataStream(
        train_datasets, label_stream, batch=args.batch, return_task=(args.scenario == "task"),
        per_batch=labels_per_batch if (args.stream == "task-based") else args.labels_per_batch,
    )


    #-------------------------------------------------------------------------------------------------#

    #-----------------#
    #----- MODEL -----#
    #-----------------#

    # Define model
    # -how many units in the softmax output layer? (e.g., multi-headed or not?)
    softmax_classes = label_stream.n_labels if args.scenario == "class" else (
        classes_per_task if (args.scenario == "domain" or args.singlehead) else classes_per_task * label_stream.n_labels
    )
    # -set up model and move to correct device
    model = Classifier(
        image_size=config['size'], image_channels=config['channels'],
        classes=softmax_classes, fc_layers=args.fc_lay, fc_units=args.fc_units,
    ).to(device)
    # -if using a multi-headed output layer, set the "label-per-task"-list as attribute of the model
    model.multi_head = labels_per_task if (args.scenario == "task" and not args.singlehead) else None


    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- OPTIMIZER -----#
    #---------------------#

    # Define optimizer (only include parameters that "requires_grad")
    optim_list = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': args.lr}]
    if not args.cs:
        # Use the chosen 'standard' optimizer
        if args.optimizer == "sgd":
            model.optimizer = optim.SGD(optim_list, weight_decay=args.decay)
        elif args.optimizer == "adam":
            model.optimizer = optim.Adam(optim_list, betas=(0.9, 0.999), weight_decay=args.decay)
    else:
        # Use the "complex synapse"-version of the chosen optimizer
        if args.optimizer == "sgd":
            model.optimizer = cs.ComplexSynapse(optim_list, n_beakers=args.beakers, alpha=args.alpha, beta=args.beta,
                                                verbose=verbose)
        elif args.optimizer == "adam":
            model.optimizer = cs.AdamComplexSynapse(optim_list, betas=(0.9, 0.999), n_beakers=args.beakers,
                                                    alpha=args.alpha, beta=args.beta, verbose=verbose)


    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp (and print on screen)
    if verbose:
        print("\nParameter-stamp...")
    param_stamp = get_param_stamp(args, model.name, verbose=verbose)

    # Print some model-characteristics on the screen
    if verbose:
        utils.print_model_info(model, title="MAIN MODEL")

    # Prepare for keeping track of performance during training for storing & for later plotting in pdf
    metrics_dict = evaluate.initiate_metrics_dict(n_labels=label_stream.n_labels, classes=(args.scenario == "class"))

    # Prepare for plotting in visdom
    if args.visdom:
        env_name = "{exp}-{scenario}".format(exp=args.experiment, scenario=args.scenario)
        graph_name = "CS" if args.cs else "Normal"
        visdom = {'env': env_name, 'graph': graph_name}
    else:
        visdom = None


    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    # Callbacks for reporting on and visualizing loss
    loss_cbs = [
        cb.def_loss_cb(
            log=args.loss_log, visdom=visdom, tasks=label_stream.n_tasks,
            iters_per_task=args.iters if args.stream == "task-based" else None,
            task_name="Episode" if args.scenario == "class" else ("Task" if args.scenario == "task" else "Domain")
        )
    ]

    # Callbacks for reporting and visualizing accuracy
    eval_cbs = [
        cb.def_eval_cb(log=args.eval_log, test_datasets=test_datasets, scenario=args.scenario,
                       iters_per_task=args.iters if args.stream == "task-based" else None,
                       classes_per_task=classes_per_task, metrics_dict=metrics_dict, test_size=args.eval_n,
                       visdom=visdom, provide_task_info=(args.scenario == "task"))
    ]
    # -evaluate accuracy before any training
    for eval_cb in eval_cbs:
        if eval_cb is not None:
            eval_cb(model, 0)


    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- TRAINING -----#
    #--------------------#

    # Keep track of training-time
    if args.time:
        start = time.time()
    # Train model
    if verbose:
        print("\nTraining...")
    train_stream(model, data_stream, iters=args.iters * args.tasks if args.stream == "task-based" else args.iters,
                 eval_cbs=eval_cbs, loss_cbs=loss_cbs)
    # Get total training-time in seconds, and write to file and screen
    if args.time:
        training_time = time.time() - start
        # -use a context manager so the file is closed even if writing fails
        with open("{}/time-{}.txt".format(args.r_dir, param_stamp), 'w') as time_file:
            time_file.write('{}\n'.format(training_time))
        if verbose:
            print("=> Total training time = {:.1f} seconds\n".format(training_time))


    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- EVALUATION -----#
    #----------------------#

    if verbose:
        print("\n\nEVALUATION RESULTS:")

    # Evaluate precision of final model on full test-set
    precs = [evaluate.validate(
        model, test_datasets[i], verbose=False, test_size=None, task=i + 1 if args.scenario == "task" else None,
    ) for i in range(len(test_datasets))]
    average_precs = sum(precs) / len(test_datasets)
    # -print to screen
    if verbose:
        print("\n Precision on test-set:")
        for i in range(len(test_datasets)):
            print(" - {} {}: {:.4f}".format(args.scenario, i + 1, precs[i]))
        print('=> Average precision over all {} {}{}s: {:.4f}\n'.format(
            len(test_datasets), args.scenario, "e" if args.scenario == "class" else "", average_precs
        ))


    #-------------------------------------------------------------------------------------------------#

    #------------------#
    #----- OUTPUT -----#
    #------------------#

    # Average precision on full test set (context manager guarantees close)
    with open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w') as output_file:
        output_file.write('{}\n'.format(average_precs))
    # -metrics-dict
    file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
    utils.save_object(metrics_dict, file_name)


    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # If requested, generate pdf
    if args.pdf:
        # -open pdf
        plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp)
        pp = visual_plt.open_pdf(plot_name)

        # -show metrics reflecting progression during training
        figure_list = []  #-> create list to store all figures to be plotted

        # -generate all figures (and store them in [figure_list])
        key = "class" if args.scenario == 'class' else "task"
        plot_list = []
        for i in range(label_stream.n_labels):
            plot_list.append(metrics_dict["acc_per_{}".format(key)]["{}_{}".format(key, i + 1)])
        figure = visual_plt.plot_lines(
            plot_list, x_axes=metrics_dict["iters"], xlabel="Iterations", ylabel="Accuracy",
            line_names=['{} {}'.format(args.scenario, i + 1) for i in range(label_stream.n_labels)]
        )
        figure_list.append(figure)
        figure = visual_plt.plot_lines(
            [metrics_dict["ave_acc"]], x_axes=metrics_dict["iters"], xlabel="Iterations", ylabel="Accuracy",
            line_names=['average (over all {}{}s)'.format(args.scenario, "e" if args.scenario == "class" else "")],
            ylim=(0, 1)
        )
        figure_list.append(figure)
        figure = visual_plt.plot_lines(
            [metrics_dict["ave_acc_so_far"]], x_axes=metrics_dict["iters"], xlabel="Iterations", ylabel="Accuracy",
            line_names=['average (over all {}{}s so far)'.format(args.scenario, "e" if args.scenario == "class" else "")],
            ylim=(0, 1)
        )
        figure_list.append(figure)

        # -add figures to pdf (and close this pdf).
        for figure in figure_list:
            pp.savefig(figure)

        # -close pdf
        pp.close()

        # -print name of generated plot on screen
        if verbose:
            print("\nGenerated plot: {}\n".format(plot_name))
Пример #16
0
class ExtSummarizer(nn.Module):
    """Extractive summarizer (BertSum): a BERT encoder plus a sentence
    scoring layer over the <CLS> positions."""

    def __init__(self, args, device, checkpoint):
        """Build the model.

        Args:
            args: hyper-parameter namespace (encoder, ext_* sizes, max_pos, ...).
            device: torch device to move the model to.
            checkpoint: optional dict whose 'model' entry is a state-dict.
        """
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        self.ext_layer = ExtTransformerEncoder(
            self.bert.model.config.hidden_size,
            args.ext_ff_size,
            args.ext_heads,
            args.ext_dropout,
            args.ext_layers,
        )

        if args.encoder == "baseline":
            # Baseline: a small randomly initialized BERT plus a linear classifier.
            bert_config = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.ext_hidden_size,
                num_hidden_layers=args.ext_layers,
                num_attention_heads=args.ext_heads,
                intermediate_size=args.ext_ff_size,
            )
            self.bert.model = BertModel(bert_config)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        if args.max_pos > 512:
            # Extend position_embeddings beyond BERT's 512-position limit:
            # copy the pretrained rows, repeat the last row for the rest.
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(
                args.max_pos - 512, 1
            )
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
            self.bert.model.embeddings.register_buffer(
                "position_ids",
                torch.arange(args.max_pos).expand(
                    (1, -1))  # position_ids needs no training, so keep it in a buffer
            )

        if checkpoint is not None:
            self.load_state_dict(checkpoint["model"], strict=True)
        else:
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        self.to(device)

    def forward(self, src, segs, clss, mask_src, mask_cls):
        """BertSum forward pass: encode the tokens, then score each sentence.

        Data layout (see the Batch class in data_loader.py): each document is
        split at punctuation into short sub-sentences, e.g.
        "你好,更换全车油水,机油。..." becomes
        [["你", "好"], ["更", "换", "全", "车", "油", "水"], ...].
        Each sub-sentence is then wrapped in <CLS> ... <SEP> markers and the
        pieces are concatenated into one long token sequence; ``clss`` holds
        the position of each sub-sentence's <CLS> token in that long sequence,
        and ``mask_cls`` marks which entries of ``clss`` are real sentences.
        """
        top_vec = self.bert(src, segs, mask_src)

        # Run ext_layer on the <CLS>-position vectors to decide which
        # sentences are suitable for the summary.

        # Indexing note: in top_vec[a, b], a selects rows and b columns;
        # when b is high-dimensional, a's first dim must match b's first dim
        # (e.g. b of shape [5, 20] needs a of shape [5, 1]).
        sent_clss = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1),
                            clss]  # [batch_size, sub_sents, d_model]
        sents_vec = sent_clss * mask_cls[:, :, None].float(
        )  # clss is padded with index 0, which collides with real position 0,
           # so the padded slots must be masked out to zero vectors
        sent_scores = self.ext_layer(sents_vec,
                                     mask_cls)  # [batch_size, sub_sents]

        return sent_scores, mask_cls
Пример #17
0
    def __init__(self, args, device, checkpoint):
        """Build the extractive summarizer around a pretrained encoder.

        Args:
            args: hyper-parameter namespace (model_name, pretrained_name,
                encoder, max_pos, max_model_pos, bert_baseline, ext_* sizes,
                param_init, param_init_glorot, ...).
            device: torch device to move the model to.
            checkpoint: optional dict whose 'model' entry is a state-dict
                to restore; when None, only the scoring layer is initialized.
        """
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.model_name, args.pretrained_name, args.temp_dir,
                         args.finetune_bert)

        self.ext_layer = ExtTransformerEncoder(
            self.bert.model.config.hidden_size, args.ext_ff_size,
            args.ext_heads, args.ext_dropout, args.ext_layers)
        if args.encoder == 'baseline':
            # Baseline keeps the pretrained encoder and only swaps the
            # scoring layer for a simple linear classifier.  (A dead block
            # that re-initialized the encoder per model_name — kept here as a
            # runtime-evaluated triple-quoted string — has been removed.)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        if args.model_name == 'bert':
            if args.max_pos > args.max_model_pos:
                if args.bert_baseline == 1:
                    # Baseline: simply cap the sequence length at the model's
                    # native maximum instead of extending the embeddings.
                    args.max_pos = args.max_model_pos
                    self.bert.model.config.max_position_embeddings = args.max_model_pos
                else:
                    # Tile the pretrained position-embedding table to cover
                    # max_pos positions; the final partial tile repeats the
                    # last pretrained row.
                    my_pos_embeddings = nn.Embedding(
                        args.max_pos, self.bert.model.config.hidden_size)
                    offset = 0
                    while offset < args.max_pos:
                        if offset + args.max_model_pos < args.max_pos:
                            my_pos_embeddings.weight.data[offset:offset + args.max_model_pos] \
                                = self.bert.model.embeddings.position_embeddings.weight.data[:args.max_model_pos].contiguous()
                        else:
                            my_pos_embeddings.weight.data[offset:] = \
                                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :] \
                                .repeat(args.max_pos - offset, 1)
                        offset += args.max_model_pos
                    self.bert.model.embeddings.position_embeddings = my_pos_embeddings
                    self.bert.model.config.max_position_embeddings = args.max_pos
        elif args.model_name == 'bert_lstm':
            # Keep chunked processing at the model's native window size:
            # the embedding module reads max_position_embeddings, each
            # encoder layer reads chunk_size.
            self.bert.model.config.max_position_embeddings = args.max_model_pos
            self.bert.model.embeddings.max_position_embeddings = args.max_model_pos
            for layer_i in range(len(self.bert.model.encoder.layer)):
                self.bert.model.encoder.layer[
                    layer_i].chunk_size = args.max_model_pos

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            # Fresh run: initialize only the scoring layer; the encoder
            # keeps its pretrained weights.
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        self.to(device)
Пример #18
0
class Summarizer(nn.Module):
    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None,
                 topic_num=10):
        """Build the topic-aware summarizer.

        Args:
            args: hyper-parameter namespace (encoder, ff_size, heads, dropout,
                rnn_size, param_init, d_ex_type, ...).
            device: torch device to move the model to.
            load_pretrained_bert: forwarded to the Bert wrapper.
            bert_config: optional config forwarded to the Bert wrapper.
            topic_num: number of LDA topics used for the topic embeddings.
        """
        super(Summarizer, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(
            'bert-base-uncased',
            do_lower_case=True,
            never_split=('[SEP]', '[CLS]', '[PAD]', '[unused0]', '[unused1]',
                         '[unused2]', '[UNK]'),
            no_word_piece=True)
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        self.memory = Memory(device, 1, self.bert.model.config.hidden_size)
        self.key_memory = Key_memory(device, 1,
                                     self.bert.model.config.hidden_size,
                                     args.dropout)
        self.topic_predictor = Topic_predictor(
            self.bert.model.config.hidden_size,
            device,
            topic_num,
            d_ex_type=args.d_ex_type)
        # self.topic_embedding = nn.Embedding(topic_num, self.bert.model.config.hidden_size)
        # todo transform to normal weight not embedding
        # Topic vectors derived from a saved LDA model (see get_embedding).
        self.topic_embedding, self.topic_word, self.topic_word_emb = self.get_embedding(
            self.bert.model.embeddings)
        # NOTE(review): requires_grad is set before .to(device); .to() returns
        # a new tensor, so the moved tensors are non-leaf — confirm these are
        # actually trained as intended (they are plain tensors, not Parameters,
        # so they are not registered with the optimizer via parameters()).
        self.topic_embedding.requires_grad = True
        self.topic_word_emb.requires_grad = True
        self.topic_embedding = self.topic_embedding.to(device)
        self.topic_word_emb = self.topic_word_emb.to(device)
        # Choose the sentence-scoring encoder on top of BERT.
        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=self.bert.model.config.hidden_size,
                hidden_size=args.rnn_size,
                dropout=args.dropout)
        elif (args.encoder == 'baseline'):
            # Baseline: small randomly initialized BERT + linear classifier.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Initialize only the scoring encoder; BERT keeps its own weights.
        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)

    def load_cp(self, pt):
        """Restore model weights from checkpoint dict ``pt`` (strict key match)."""
        state = pt['model']
        self.load_state_dict(state, strict=True)

    def forward(self, x, segs, clss, mask, mask_cls, sentence_range=None):
        """Encode the batch with BERT and score each sentence.

        Returns:
            (sent_scores, mask_cls, top_vec): per-sentence scores, the
            sentence mask, and the full token-level encoder output.
        """

        # NOTE(review): this mutates *global* torch state on every forward
        # call, forcing CUDA float tensors as the process-wide default; it
        # breaks CPU-only execution and the API is deprecated in recent
        # PyTorch. Confirm nothing downstream relies on it before removing.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        top_vec = self.bert(x, segs, mask)

        # Gather the encoder output at each sentence's <CLS> index and zero
        # out padded sentence slots.
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()

        sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)

        return sent_scores, mask_cls, top_vec

    # todo
    # todo
    def get_embedding(self, embedding):
        """Build topic vectors from a saved LDA model.

        Loads a 10-topic LDA model from disk, pulls each topic's words out of
        its description string, looks them up in the BERT word-embedding table
        and averages them into one vector per topic.

        Args:
            embedding: BERT word-embedding module (token id -> vector).

        Returns:
            (topic_embeddings, topic_words, topic_word_embeddings) as
            (FloatTensor [n_topics, dim], list[str], FloatTensor).
        """
        lda_model_tfidf = models.ldamodel.LdaModel.load(
            '../models/sum_topic_10.model')
        topic = lda_model_tfidf.show_topics()
        embed_list = []
        topic_word = []
        topic_word_emb = []
        # Topic descriptions look like '0.016*"word" + ...'; grab the quoted words.
        back = re.compile(r'["](.*?)["]')
        for meta_topic in topic:
            tmp_embedding = None
            text = meta_topic[1]
            word_list = re.findall(back, text)
            count = 0
            for i, word in enumerate(word_list):
                if word not in self.tokenizer.vocab:
                    continue  # skip out-of-vocabulary topic words
                topic_word.append(word)
                if tmp_embedding is None:
                    tmp_embedding = embedding(
                        torch.LongTensor([[self.tokenizer.vocab[word]]
                                          ])).squeeze().detach().numpy()
                    topic_word_emb.append(tmp_embedding)
                    count += 1
                else:
                    # NOTE(review): += mutates tmp_embedding in place, so every
                    # entry appended to topic_word_emb for this topic aliases
                    # the same running-sum array — confirm this is intended
                    # (per-word embeddings may have been the goal).
                    tmp_embedding += embedding(
                        torch.LongTensor([[self.tokenizer.vocab[word]]
                                          ])).squeeze().detach().numpy()
                    topic_word_emb.append(tmp_embedding)
                    count += 1
            # NOTE(review): if no word of a topic is in the vocab,
            # tmp_embedding is still None here and this division raises — TODO confirm.
            tmp_embedding = tmp_embedding / count
            embed_list.append(tmp_embedding)
        return torch.FloatTensor(embed_list), topic_word, torch.FloatTensor(
            topic_word_emb)

    def sent_kmax(self, sents_vec, clss, clss_fw, device, doc_len):
        """Pool each sentence's span of token vectors into a single vector.

        For document i, the span starts at token index ``clss[i]`` and has
        length ``clss_fw[i]``; degenerate spans (length <= 0 or equal to
        ``doc_len``) become zero vectors. Each span is convolved and then
        mean-pooled over its rows.

        Args:
            sents_vec: [batch, tokens, dim] token-level vectors.
            clss: per-document span start indices.
            clss_fw: per-document span lengths.
            device: device for the zero-vector placeholders.
            doc_len: length value that marks a degenerate span.

        Returns:
            Tensor of shape [batch, 1, width] of pooled span vectors.
        """
        spans = []
        for vec, start, length in zip(sents_vec, clss, clss_fw):
            if length <= 0 or length == doc_len:
                # Empty / degenerate span: use a zero-vector placeholder.
                spans.append(torch.zeros(vec.size(1)).unsqueeze(0).to(device))
            else:
                spans.append(vec.narrow(0, start, length))

        # NOTE(review): this Conv2d is re-created with random weights on every
        # call and never registered as a submodule, so it is untrained and
        # nondeterministic — confirm whether that is intentional.
        convs = nn.ModuleList([nn.Conv2d(1, 1, (3, 9), padding=(2, 4))])
        spans = [
            F.relu(conv(x.unsqueeze(0).unsqueeze(0))) for conv in convs
            for x in spans
        ]

        # Mean-pool each convolved span over its rows -> one vector per span.
        spans = [
            torch.mean(x.squeeze(0).squeeze(0).type(torch.float),
                       dim=0).unsqueeze(0) for x in spans
        ]
        sent_pool = torch.cat(spans, dim=0).unsqueeze(1)
        return sent_pool