Example #1
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
# Bert and ExtTransformerEncoder are assumed to come from the surrounding project
# (e.g. a PreSumm-style models package).


class ExtSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint):
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.bert_model_path, args.large, args.temp_dir,
                         args.finetune_bert)
        self.ext_layer = ExtTransformerEncoder(
            self.bert.model.config.hidden_size, args.ext_ff_size,
            args.ext_heads, args.ext_dropout, args.ext_layers)

        if args.max_pos > 512:
            # BERT only has 512 learned position embeddings; build a larger table,
            # keep the original 512 rows and fill the extra positions with copies
            # of the last learned embedding.
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = (
                self.bert.model.embeddings.position_embeddings.weight.data)
            my_pos_embeddings.weight.data[512:] = (
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :]
                .repeat(args.max_pos - 512, 1))
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        self.to(device)

    def forward(self, src, segs, clss, mask_src, mask_cls):
        # map segment ids to binary token_type_ids and zero them at padding positions
        segs = (1 - segs % 2) * mask_src.long()
        top_vec = self.bert(src, segs, mask_src)
        # gather the hidden vector at every [CLS] position: [batch_size, n_sents, hidden_size]
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        # one score per sentence: [batch_size, n_sents]
        sent_scores = self.ext_layer(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls
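The args.max_pos > 512 branch above stretches BERT's 512 learned position embeddings to a longer input limit by reusing the last learned row for every extra position. Below is a minimal standalone sketch of the same idea; the sizes are made up and a randomly initialized embedding stands in for BERT's position_embeddings table.

import torch
import torch.nn as nn

hidden_size, max_pos = 768, 640                  # hypothetical sizes
old_pos_emb = nn.Embedding(512, hidden_size)     # stand-in for BERT's position embeddings

new_pos_emb = nn.Embedding(max_pos, hidden_size)
new_pos_emb.weight.data[:512] = old_pos_emb.weight.data
new_pos_emb.weight.data[512:] = old_pos_emb.weight.data[-1][None, :].repeat(max_pos - 512, 1)

# every position beyond 511 now starts from a copy of the last learned embedding
assert torch.equal(new_pos_emb.weight.data[600], old_pos_emb.weight.data[511])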
Example #2
# Assumes the same torch and project-level imports as Example #1.
class ExtSummarizer(nn.Module):
    def __init__(self, args, device, ckpt):
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.bert_path, args.finetune_bert)
        self.ext_layer = ExtTransformerEncoder()
        """
        if (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.ext_hidden_size,
            num_hidden_layers=args.ext_layers, num_attention_heads=args.ext_heads, intermediate_size=args.ext_ff_size)
            self.bert.model = BertModel(bert_config)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        if(args.max_pos>512):
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        """
        if ckpt is not None:
            self.load_state_dict(
                ckpt['model'], strict=True
            )  # note: strict=True requires the keys of the model and the checkpoint to match exactly; False relaxes this check
        else:
            if args.param_init != 0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            elif args.param_init_glorot:  # i.e. use Xavier uniform initialization
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        self.to(device)  # important: remember to move the whole module to the target device

    """
    some points:
        self.tgt_bos = '[unused0]'
        self.tgt_eos = '[unused1]'
        self.tgt_sent_split = '[unused2]'
        to id: 1,2,3
        此外,抽取式摘要需要对原文的每个句子是否被选中做一个0,1的标签标记,称为src_sent_labels
        为什么clss的pad用-1?因为clss是cls的token的位置,不是token的值
    """

    def forward(self, x, segs, clss, mask_src, mask_cls):
        """
        inputs:
            x: input_ids [batch_size, seq_len], with seq_len <= args.max_pos
            segs: token_type_ids [batch_size, seq_len]
            clss: positions of the [CLS] tokens [batch_size, clss_len], each entry < args.max_pos
            mask_src: 0/1 mask over tokens [batch_size, seq_len]
            mask_cls: 0/1 mask over sentences [batch_size, clss_len]
        """
        top_vec = self.bert(x, segs, mask_src)
        # gather the hidden vector at every [CLS] position: [batch_size, clss_len, hidden_size]
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        # zero out the vectors of padded sentences: [batch_size, clss_len, hidden_size]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.ext_layer(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls
        """
            如何得到[batch_size, clss_token, hidden_size]?当二维的tensor用作索引时,需要注意,第一维的tensor[a,b,。。。]中的每一个元素都要用来提取[:, seq_len, hiddensize],在此处,需要让第二维的tensor的num_dim = 第一维的tensor的num_dim.如果想在seq_len维度上抽取多个,则必须要batch_size的基础上unsqueeze(1)。
            mask[:,:,None]将原来的mask的最后一维度扩充,等同于unsqueeze(2),两个向量前面的维度相同时,等同于最后一个维度(只可能长度相同,或者其中一个长度为1)每个数字两两相乘
            squeeze 即判断最后一个维度是否为1,是的话去除
        """

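The note that [unused0], [unused1] and [unused2] map to ids 1, 2, 3 can be checked with a tokenizer; this sketch assumes the HuggingFace transformers package and the standard bert-base-uncased vocabulary (the ids would differ for another vocab).

from transformers import BertTokenizer  # assumed package, not part of the examples above

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(tokenizer.convert_tokens_to_ids(['[unused0]', '[unused1]', '[unused2]']))
# expected with the standard vocab: [1, 2, 3]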