Example #1
class HybridSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint=None, checkpoint_ext=None, checkpoint_abs=None):
        super(HybridSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.extractor = ExtSummarizer(args, device, checkpoint_ext)
        # self.abstractor = PGTransformers(modules, consts, options)
        self.abstractor = AbsSummarizer(args, device, checkpoint_abs)
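        # Cross-attention over the encoder states; need_distribution=True makes
        # the module also return its attention weights, which serve as the copy
        # distribution in forward().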
        self.context_attn = MultiHeadedAttention(head_count=self.args.dec_heads,
                                                 model_dim=self.args.dec_hidden_size,
                                                 dropout=self.args.dec_dropout,
                                                 need_distribution=True)

        self.v = nn.Parameter(torch.Tensor(1, self.args.dec_hidden_size * 3))
        self.bv = nn.Parameter(torch.Tensor(1))
        self.attn_lin = nn.Linear(self.args.dec_hidden_size, self.args.dec_hidden_size)
        if self.args.hybrid_loss:
            self.ext_loss_fun = torch.nn.BCELoss(reduction='none')
        if self.args.hybrid_connector:
            self.p_sen = nn.Linear(self.args.dec_hidden_size, 1)

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint is loading !")
        else:
            self.attn_lin.weight.data.normal_(mean=0.0, std=0.02)

            nn.init.xavier_uniform_(self.v)
            nn.init.constant_(self.bv, 0)
            if self.args.hybrid_connector:
                for module in self.p_sen.modules():
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(mean=0.0, std=0.02)
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    if isinstance(module, nn.Linear) and module.bias is not None:
                        module.bias.data.zero_()

            for module in self.context_attn.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, labels=None):

        if labels is not None and self.args.oracle:
            # Oracle mode: map gold 0/1 labels to soft extractive scores.
            ext_scores = ((labels.float() + 0.1) / 1.3) * mask_cls.float()
        else:
            if labels is None:
                with torch.no_grad():
                    ext_scores, _, sent_vec = self.extractor(src, segs, clss, mask_src, mask_cls)
            else:
                ext_scores, _, sent_vec = self.extractor(src, segs, clss, mask_src, mask_cls)
                ext_loss = self.ext_loss_fun(ext_scores, labels.float())
                ext_loss = ext_loss * mask_cls.float()

        # decoder_outputs: [batch_size, tgt_len - 1, hidden_size]; it is later
        # projected into a probability distribution over the vocabulary.
        decoder_outputs, encoder_state, y_emb = self.abstractor(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
        src_pad_mask = (1 - mask_src).unsqueeze(1).repeat(1, tgt.size(1) - 1, 1)
        context_vector, attn_dist = self.context_attn(encoder_state, encoder_state, decoder_outputs, mask=src_pad_mask, type="context")
        if self.args.hybrid_connector:
            sorted_scores, sorted_scores_idx = torch.sort(ext_scores, dim=1, descending=True)
            # Select the top-3 sentences (fewer if the document has fewer).
            select_num = min(3, mask_cls.size(1))
            # Weight each selected sentence vector by its score so the
            # weighted vectors can be summed over the selection below.
            selected_sent_vec = tuple(
                (sorted_scores[i][:select_num].unsqueeze(0).transpose(0, 1) *
                 sent_vec[i, tuple(sorted_scores_idx[i][:select_num])]
                 ).unsqueeze(0) for i, _ in enumerate(sorted_scores_idx))
            selected_sent_vec = torch.cat(selected_sent_vec, dim=0)
            selected_sent_vec = selected_sent_vec.sum(dim=1)
            E_sel = self.p_sen(selected_sent_vec)
            ext_scores = ext_scores * E_sel

        # Generation gate g: sigmoid over the concatenated decoder state,
        # target embedding, and attention context.
        g = torch.sigmoid(F.linear(torch.cat([decoder_outputs, y_emb, context_vector], -1), self.v, self.bv))
        xids = src.unsqueeze(0).repeat(tgt.size(1) - 1, 1, 1).transpose(0,1)
        xids = xids * mask_tgt.unsqueeze(2)[:,:-1,:].long()

        # Compute per-sentence token lengths from the [CLS] start offsets;
        # padded sentences are zeroed out by mask_cls.
        len0 = src.size(1)
        len0 = torch.Tensor([[len0]]).repeat(src.size(0), 1).long().to(self.device)
        clss_up = torch.cat((clss, len0), dim=1)
        sent_len = (clss_up[:, 1:] - clss) * mask_cls.long()
        for i in range(mask_cls.size(0)):
            for j in range(mask_cls.size(1)):
                if sent_len[i][j] < 0:
                    sent_len[i][j] += src.size(1)
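        # Expand the sentence-level scores to token level: one score per source token.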
        ext_scores_0 = ext_scores.unsqueeze(1).transpose(1,2).repeat(1,1, src.size(1))
        for i in range(clss.size(0)):
            tmp_vec = ext_scores_0[i, 0, :sent_len[i][0].int()]

            for j in range(1, clss.size(1)):
                tmp_vec = torch.cat((tmp_vec, ext_scores_0[i, j, :sent_len[i][j].int()]), dim=0)
            if i == 0:
                ext_scores_new = tmp_vec.unsqueeze(0)
            else:
                ext_scores_new = torch.cat((ext_scores_new, tmp_vec.unsqueeze(0)), dim=0)
        ext_scores_new = ext_scores_new * mask_src.float()
        attn_dist = attn_dist * (ext_scores_new + 1).unsqueeze(1)
        # Renormalize so the reweighted attention distribution sums to 1.
        attn_dist = attn_dist / attn_dist.sum(dim=2).unsqueeze(2)

        ext_dist = torch.zeros(tgt.size(0), tgt.size(1) - 1,
                               self.abstractor.bert.model.config.vocab_size,
                               device=self.device)
        ext_vocab_prob = ext_dist.scatter_add(2, xids, (1 - g) * mask_tgt.unsqueeze(2)[:,:-1,:].float() * attn_dist) * mask_tgt.unsqueeze(2)[:,:-1,:].float()
        if self.args.hybrid_loss:
            return decoder_outputs, None, (ext_vocab_prob, g, ext_loss)
        else:
            return decoder_outputs, None, (ext_vocab_prob, g)
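
The forward pass returns the copy distribution ext_vocab_prob together with the generation gate g; note that the (1 - g) copy weight is already folded into ext_vocab_prob by the scatter_add above. Below is a minimal sketch of how a downstream loss might combine these with a softmax over projected decoder states; the generator projection and all names are hypothetical, not part of the class above:

import torch
import torch.nn.functional as F

def mixture_log_probs(decoder_outputs, ext_vocab_prob, g, generator):
    # generator: assumed nn.Linear(dec_hidden_size, vocab_size) projection;
    # it is not defined in HybridSummarizer itself.
    gen_prob = F.softmax(generator(decoder_outputs), dim=-1)  # [B, T-1, V]
    # Pointer-generator mixture: g weights generation; the (1 - g) copy
    # weight was already applied inside forward() via scatter_add.
    mixed = g * gen_prob + ext_vocab_prob
    return torch.log(mixed + 1e-12)  # guard against log(0) on masked steps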
Example #2
class HybridSummarizer(nn.Module):
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 checkpoint_ext=None,
                 checkpoint_abs=None):
        super(HybridSummarizer, self).__init__()
        self.args = args
        self.device = device

        self.extractor = ExtSummarizer(args, device, checkpoint_ext)
        self.abstractor = AbsSummarizer(args, device, checkpoint_abs)
        self.context_attn = MultiHeadedAttention(
            head_count=self.args.dec_heads,
            model_dim=self.args.dec_hidden_size,
            dropout=self.args.dec_dropout,
            need_distribution=True)
        self.v = nn.Parameter(torch.Tensor(1, self.args.dec_hidden_size * 3))
        self.bv = nn.Parameter(torch.Tensor(1))
        self.attn_lin = nn.Linear(self.args.dec_hidden_size,
                                  self.args.dec_hidden_size)
        if self.args.hybrid_loss:
            self.ext_loss_fun = torch.nn.BCELoss(reduction='none')
        if self.args.hybrid_connector:
            self.p_sen = nn.Linear(self.args.dec_hidden_size, 1)

        # At test time the full checkpoint is loaded directly.
        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint loaded!")
        else:
            self.attn_lin.weight.data.normal_(mean=0.0, std=0.02)
            nn.init.xavier_uniform_(self.v)
            nn.init.constant_(self.bv, 0)
            if self.args.hybrid_connector:
                for module in self.p_sen.modules():
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(mean=0.0, std=0.02)
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    if isinstance(module,
                                  nn.Linear) and module.bias is not None:
                        module.bias.data.zero_()

            for module in self.context_attn.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        self.to(device)

    def forward(self,
                src,
                tgt,
                segs,
                clss,
                mask_src,
                mask_tgt,
                mask_cls,
                labels=None):

        if labels is not None and self.args.oracle:
            # Oracle mode: map gold 0/1 labels to soft extractive scores.
            ext_scores = ((labels.float() + 0.1) / 1.3) * mask_cls.float()
        else:
            if labels is None:
                with torch.no_grad():
                    ext_scores, _, sent_vec = self.extractor(
                        src, segs, clss, mask_src, mask_cls)
            else:
                ext_scores, _, sent_vec = self.extractor(
                    src, segs, clss, mask_src, mask_cls)
                ext_loss = self.ext_loss_fun(ext_scores, labels.float())
                ext_loss = ext_loss * mask_cls.float()
        decoder_outputs, encoder_state, y_emb = self.abstractor(
            src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
        src_pad_mask = (1 - mask_src).unsqueeze(1).repeat(1, tgt.size(1) - 1, 1)
        context_vector, attn_dist = self.context_attn(encoder_state,
                                                      encoder_state,
                                                      decoder_outputs,
                                                      mask=src_pad_mask,
                                                      type="context")
        if self.args.hybrid_connector:
            sorted_scores, sorted_scores_idx = torch.sort(ext_scores,
                                                          dim=1,
                                                          descending=True)
            select_num = min(3, mask_cls.size(1))
            # Weight each selected sentence vector by its score before summing.
            selected_sent_vec = tuple(
                (sorted_scores[i][:select_num].unsqueeze(0).transpose(0, 1) *
                 sent_vec[i, tuple(sorted_scores_idx[i][:select_num])]
                 ).unsqueeze(0) for i, _ in enumerate(sorted_scores_idx))

            selected_sent_vec = torch.cat(selected_sent_vec, dim=0)
            selected_sent_vec = selected_sent_vec.sum(dim=1)
            E_sel = self.p_sen(selected_sent_vec)
            ext_scores = ext_scores * E_sel

        # Debug guard: check the first element of each intermediate tensor for
        # NaN (a cheap canary) and, on failure, dump the whole batch and abort.
        # segs marks which sentence each token belongs to; clss marks the
        # start position of each sentence.
        def _dump_if_nan(name, tensor):
            if not torch.isnan(tensor.reshape(-1)[0]):
                return
            print("ops, %s!" % name)
            dumps = [("src", src), ("tgt", tgt), ("segs", segs),
                     ("clss", clss), ("mask_src", mask_src),
                     ("mask_cls", mask_cls),
                     ("decoder_outputs", decoder_outputs),
                     ("y_emb", y_emb), ("context_vector", context_vector)]
            if name not in dict(dumps):
                dumps.append((name, tensor))
            for label, t in dumps:
                print(label, "=", t.size())
                print(t)
            exit()

        _dump_if_nan("decoder_outputs", decoder_outputs)
        _dump_if_nan("y_emb", y_emb)
        _dump_if_nan("context_vector", context_vector)

        # Generation gate g: sigmoid over the concatenated decoder state,
        # target embedding, and attention context.
        g = torch.sigmoid(
            F.linear(torch.cat([decoder_outputs, y_emb, context_vector], -1),
                     self.v, self.bv))
        _dump_if_nan("g", g)

        xids = src.unsqueeze(0).repeat(tgt.size(1) - 1, 1, 1).transpose(0, 1)
        xids = xids * mask_tgt.unsqueeze(2)[:, :-1, :].long()
        len0 = src.size(1)
        len0 = torch.Tensor([[len0]]).repeat(src.size(0), 1).long().to(self.device)
        clss_up = torch.cat((clss, len0), dim=1)
        sent_len = (clss_up[:, 1:] - clss) * mask_cls.long()
        for i in range(mask_cls.size(0)):
            for j in range(mask_cls.size(1)):
                if sent_len[i][j] < 0:
                    sent_len[i][j] += src.size(1)
        ext_scores_0 = ext_scores.unsqueeze(1).transpose(1, 2).repeat(
            1, 1, src.size(1))
        for i in range(clss.size(0)):
            tmp_vec = ext_scores_0[i, 0, :sent_len[i][0].int()]

            for j in range(1, clss.size(1)):
                tmp_vec = torch.cat(
                    (tmp_vec, ext_scores_0[i, j, :sent_len[i][j].int()]),
                    dim=0)
            if i == 0:
                ext_scores_new = tmp_vec.unsqueeze(0)
            else:
                ext_scores_new = torch.cat(
                    (ext_scores_new, tmp_vec.unsqueeze(0)), dim=0)
        ext_scores_new = ext_scores_new * mask_src.float()
        attn_dist = attn_dist * (ext_scores_new + 1).unsqueeze(1)
        attn_dist = attn_dist / attn_dist.sum(dim=2).unsqueeze(2)
        ext_dist = torch.zeros(tgt.size(0),
                               tgt.size(1) - 1,
                               self.abstractor.bert.model.config.vocab_size,
                               device=self.device)
        ext_vocab_prob = ext_dist.scatter_add(
            2, xids, (1 - g) * mask_tgt.unsqueeze(2)[:, :-1, :].float() *
            attn_dist) * mask_tgt.unsqueeze(2)[:, :-1, :].float()

        if self.args.hybrid_loss:
            return decoder_outputs, None, (ext_vocab_prob, g, ext_loss)
        else:
            return decoder_outputs, None, (ext_vocab_prob, g)
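
The nested Python loops in both examples that expand sentence scores to token level (building sent_len and then ext_scores_new row by row) can be replaced by a single gather. Here is a sketch under the assumption that clss holds per-sentence start offsets, mask_cls marks valid sentences, and mask_src marks valid source tokens; the function name is hypothetical, and the result matches the loops for well-formed inputs:

import torch

def expand_sentence_scores(ext_scores, clss, mask_cls, mask_src):
    # ext_scores: [B, S] sentence scores; clss: [B, S] sentence start offsets;
    # mask_cls: [B, S] valid-sentence mask; mask_src: [B, L] valid-token mask.
    B, L = mask_src.size()
    # Push padded sentence starts past the sequence end so no token matches them.
    starts = clss.masked_fill(mask_cls == 0, L)
    positions = torch.arange(L, device=ext_scores.device).view(1, 1, L)
    # A token at position p belongs to the last sentence whose start is <= p.
    sent_idx = (positions >= starts.unsqueeze(2)).long().sum(dim=1) - 1  # [B, L]
    sent_idx = sent_idx.clamp(min=0, max=ext_scores.size(1) - 1)
    # One score per source token; zero out padding positions.
    return ext_scores.gather(1, sent_idx) * mask_src.float()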