Example #1
    def __init__(self, params):
        super(BiencoderRanker, self).__init__()
        self.params = params
        # self.ctxt_bert=BertModel_new.from_pretrained(params['context_bert_model'],config=BertConfig_new.from_pretrained(params['context_bert_model']))
        self.ctxt_bert = load_model('./model/bert-large-uncased',
                                    params['context_bert_model'])
        self.cand_encoder = wiki_encoder.WikiEncoderModule(params)
        # self.load_cand_encoder_state()
        self.context_encoder = BertEncoder(
            self.ctxt_bert,
            params["out_dim"],
            layer_pulled=params["pull_from_layer"],
            add_linear=params["add_linear"],
        )
        self.mention_score = MentionScoresHead(bert_output_dim=768,
                                               score_method='qa_linear',
                                               max_mention_length=10)
        self.mention_loss = MentionLoss()

        # self.change_mention_embedding_dim = nn.Linear(768,1024)
        # self.change_mention_embedding_dim=nn.Sequential(
        #     nn.Linear(768, 768),
        #     nn.ReLU(),
        #     nn.Dropout(0.1),
        #     nn.Linear(768, 1024),
        # )
        # candidate encoder parameters stay trainable here (not frozen)
        for param in self.cand_encoder.parameters():
            param.requires_grad = True
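
Note: unlike the variants below, Example #1 leaves the candidate encoder trainable. A minimal standalone sketch (nn.Linear is a hypothetical stand-in for self.cand_encoder) of how to verify which parameters will actually receive gradients:

import torch.nn as nn

encoder = nn.Linear(768, 1024)  # hypothetical stand-in for self.cand_encoder
for param in encoder.parameters():
    param.requires_grad = True  # same toggle as above

trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in encoder.parameters() if not p.requires_grad)
print(trainable, frozen)  # 787456 0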
Example #2
    def __init__(self, params):
        super(WikiEncoderModule, self).__init__()
        cand_bert = BertModel.from_pretrained(params['bert_model'])
        self.cand_encoder = BertEncoder(
            cand_bert,
            params["out_dim"],
            layer_pulled=params["pull_from_layer"],
            add_linear=params["add_linear"],
        )
        self.config = cand_bert.config
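
Example #2 is the candidate-side module on its own: it wraps a pretrained BERT in the same BertEncoder used for the context side. A minimal sketch of the params dict it reads; the keys come from the code above, the values are hypothetical:

params = {
    'bert_model': 'bert-large-uncased',  # hypothetical checkpoint
    'out_dim': 1024,                     # hypothetical output dimension
    'pull_from_layer': -1,               # hypothetical: pool from the last layer
    'add_linear': False,                 # hypothetical: no extra projection head
}
# cand_encoder = WikiEncoderModule(params)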
Example #3
    def __init__(self, params):
        super(BiencoderRanker, self).__init__()
        self.params = params
        self.ctxt_bert = BertModel_new.from_pretrained(
            params['context_bert_model'],
            config=BertConfig_new.from_pretrained(
                params['context_bert_model']))
        self.cand_encoder = wiki_encoder.WikiEncoderModule(params)
        self.load_cand_encoder_state()
        self.context_encoder = BertEncoder(
            self.ctxt_bert,
            params["out_dim"],
            layer_pulled=params["pull_from_layer"],
            add_linear=params["add_linear"],
        )
        self.mention_score = mention_detection.MentionScoresHead(
            bert_output_dim=768, score_method='qa_mlp', max_mention_length=10)
        self.mention_loss = mention_detection.MentionLoss()

        self.change_mention_embedding_dim = nn.Linear(768, 1024)
        # freeze the candidate encoder parameters
        for param in self.cand_encoder.parameters():
            param.requires_grad = False
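
Example #3 freezes the candidate encoder so only the context side is updated. A common follow-up (a sketch with toy modules, not taken from this repo) is to hand the optimizer only the parameters that still require gradients:

import torch
import torch.nn as nn

# toy stand-ins: the context encoder stays trainable, the candidate encoder is frozen
ctxt = nn.Linear(768, 1024)
cand = nn.Linear(768, 1024)
for param in cand.parameters():
    param.requires_grad = False

optimizer = torch.optim.Adam(
    (p for m in (ctxt, cand) for p in m.parameters() if p.requires_grad),
    lr=1e-5)
print(sum(len(g['params']) for g in optimizer.param_groups))  # 2 (ctxt weight + bias)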
Example #4
class BiencoderRanker(nn.Module):
    def __init__(self, params):
        super(BiencoderRanker, self).__init__()
        self.params = params
        self.ctxt_bert = BertModel_new.from_pretrained(
            params['context_bert_model'],
            config=BertConfig_new.from_pretrained(
                params['context_bert_model']))
        self.cand_encoder = wiki_encoder.WikiEncoderModule(params)
        self.load_cand_encoder_state()
        self.context_encoder = BertEncoder(
            self.ctxt_bert,
            params["out_dim"],
            layer_pulled=params["pull_from_layer"],
            add_linear=params["add_linear"],
        )
        self.mention_score = mention_detection.MentionScoresHead(
            bert_output_dim=768, score_method='qa_mlp', max_mention_length=10)
        self.mention_loss = mention_detection.MentionLoss()

        self.change_mention_embedding_dim = nn.Linear(768, 1024)
        # freeze the candidate encoder parameters
        for param in self.cand_encoder.parameters():
            param.requires_grad = False

    def load_cand_encoder_state(self):
        # load pre-trained candidate encoder weights from the path in params
        self.cand_encoder.load_state_dict(
            torch.load(self.params['cand_encoder']))

    def get_raw_embedding_ctxt(self, input_ids, segment_type, input_mask):
        # return the embedding of every token
        raw_ctxt_embedding, _ = self.context_encoder.bert_model(
            input_ids, segment_type, input_mask)
        # DIM: (bs, seq_len, output_dim)
        return raw_ctxt_embedding

    def get_embedding_mention(self,
                              bert_output,
                              mention_bounds,
                              method='average_all'):
        # only the 'average_all' pooling method is implemented for now
        # normalize inverted bounds so that start <= end before pooling
        mention_bounds2 = mention_bounds.clone()
        for i in range(mention_bounds2.size(0)):
            for j in range(mention_bounds2.size(1)):
                if mention_bounds2[i, j, 1] < mention_bounds2[i, j, 0]:
                    t = mention_bounds2[i, j, 1].item()
                    mention_bounds2[i, j, 1] = mention_bounds2[i, j, 0]
                    mention_bounds2[i, j, 0] = t
        '''
        example:
            bert_output=torch.ones(1,6,8)
            mention_bounds=torch.tensor([[[1,2],[3,4],[4,2]]])
            
        '''
        start_pos = mention_bounds2[:, :, 0]
        end_pos = mention_bounds2[:, :, 1]
        mention_embedding = torch.zeros(mention_bounds2.size(0),
                                        mention_bounds2.size(1),
                                        bert_output.size(2),
                                        dtype=bert_output.dtype,
                                        device=bert_output.device)
        for i in range(mention_bounds2.size(0)):
            for j in range(mention_bounds2.size(1)):
                cur_start = start_pos[i, j]
                cur_end = end_pos[i, j]
                # mean-pool the token embeddings over the inclusive span
                mention_embedding[i, j, :] = bert_output[
                    i, cur_start:cur_end + 1, :].mean(dim=0)

        del mention_bounds2
        return mention_embedding

    def get_embedding_cand(self, input_ids, segment_type, input_mask):
        # return the sentence-level embedding of the candidate
        # DIM: (bs, output_dim)
        return self.cand_encoder(input_ids, segment_type, input_mask)

    def get_mention_scores(self, bert_output, input_mask):
        # returns (mention_scores, mention_bounds)
        return self.mention_score(bert_output, input_mask.bool())

    def forward_ctxt(self, input_ids, segment_type, input_mask, EL=False):

        bert_output_ctxt = self.get_raw_embedding_ctxt(input_ids, segment_type,
                                                       input_mask)
        mention_score, mention_bounds = self.get_mention_scores(
            bert_output_ctxt, input_mask.bool())
        if EL:
            mention_embedding = self.get_embedding_mention(
                bert_output_ctxt, mention_bounds, method='average_all')
        else:
            mention_embedding = None
        return bert_output_ctxt, mention_score, mention_bounds, mention_embedding

    def prune_mention(self, mention_scores, mention_bounds, mention_embedding,
                      gold_mention_bounds, gold_mention_bounds_mask,
                      gold_entity_local_id):
        '''
        example:
        Here only mentions with scores == -float('inf') or end_pos < start_pos are filtered out; no further threshold-score filtering yet, keeping things simple to get the logic running.
            mention_embedding=torch.ones(2,7,10)
            mention_scores=torch.tensor([[2,4,0.5,3,-float('inf'),8,-float('inf')],[2,4,0.5,3,9,8,-float('inf')]])
            mention_bounds=torch.tensor([[[1,2],[3,1],[2,3],[2,5],[3,4],[4,1],[4,5]],[[1,2],[3,1],[2,3],[2,5],[3,4],[4,1],[4,5]]])


            gold_mention_bounds=torch.tensor([[[2,4],[-1,-1],[-1,-1]],[[2,4],[-1,-1],[-1,-1]]])
            gold_mention_bounds_mask=torch.tensor([[1,0,0],[1,0,0]]).bool()

            gold_entity_local_id=torch.tensor([[200,-1,-1],[200,-1,-1]])
        '''

        # shift gold end indices to be inclusive; note: mutates the caller's tensor in place
        gold_mention_bounds[:, :, 1] -= 1
        gold_mention_bounds[~(gold_mention_bounds_mask.bool()
                              )] = -1  # ensure masked spans are never selected

        # map each predicted span to its gold entity id (-1 if it matches no gold span)
        gold_entity_local_id_extend = torch.zeros_like(
            mention_scores, device=gold_entity_local_id.device)
        for i in range(mention_bounds.size(0)):
            for j in range(mention_bounds.size(1)):
                link = False
                for m in range(gold_mention_bounds.size(1)):
                    if mention_bounds[i, j, 0] == gold_mention_bounds[
                            i, m, 0] and mention_bounds[
                                i, j, 1] == gold_mention_bounds[i, m, 1]:
                        gold_entity_local_id_extend[
                            i, j] = gold_entity_local_id[i, m]
                        link = True
                if not link:
                    gold_entity_local_id_extend[i, j] = -1

        mask = (mention_scores != -float('inf')) & (mention_bounds[:, :, 1] >=
                                                    mention_bounds[:, :, 0])
        #DIM:(all_pred_mention_in_batch, 1)
        mention_scores = mention_scores[mask].view(-1, 1)
        #DIM:(all_pred_mention_in_batch, 2)
        mention_bounds = mention_bounds[mask]
        #DIM: (all_pred_mention_in_batch, output_dim)
        mention_embedding = mention_embedding[mask]
        #DIM: (all_pred_mention_in_batch, 1)
        gold_entity_local_id_extend = gold_entity_local_id_extend[mask].view(
            -1, 1)

        return mention_scores, mention_bounds, mention_embedding, gold_entity_local_id_extend

    def forward_cand(self,
                     input_ids,
                     segment_type,
                     input_mask,
                     pre_computed_cand_embedding=None):
        # pre_computed_cand_embedding: embeddings pre-computed with BLINK's wiki_encoder in this code
        if pre_computed_cand_embedding is not None:
            embedding = pre_computed_cand_embedding
        else:
            embedding = self.get_embedding_cand(input_ids, segment_type,
                                                input_mask)
        return embedding

    def forward(self,
                input_ids_ctxt,
                segment_type_ctxt,
                input_mask_ctxt,
                gold_mention_bounds,
                gold_mention_bounds_mask,
                input_ids_cand=None,
                segment_type_cand=None,
                input_mask_cand=None,
                return_forward_ctxt=True,
                pre_trained_cand=None,
                all_mention_embedding=None,
                candidate_label=None,
                EL=False):
        if return_forward_ctxt:
            bert_output_ctxt, mention_scores, mention_bounds, mention_embedding = self.forward_ctxt(
                input_ids_ctxt, segment_type_ctxt, input_mask_ctxt, EL)

            # compute mention loss
            ment_loss = self.mention_loss(gold_mention_bounds,
                                          gold_mention_bounds_mask.bool(),
                                          mention_scores, mention_bounds)
            # cache the mention loss; it is added to the EL loss on the candidate pass
            self.ment_loss = ment_loss
            return bert_output_ctxt, mention_scores, mention_bounds, mention_embedding, ment_loss

        if pre_trained_cand is None:
            cand_embedding = self.forward_cand(input_ids_cand,
                                               segment_type_cand,
                                               input_mask_cand)
        else:
            cand_embedding = pre_trained_cand

        # compute scores for the candidate entities
        scores = self.score_cand(all_mention_embedding, cand_embedding)

        # compute EL loss
        # DIM: (all_pred_mention * top_k, 1)
        scores = scores.view(-1, 1)
        '''
            candidate_label=torch.tensor([[0,1],[1,0],[0,0]])
        '''
        candidate_label = candidate_label.view(-1, 1)
        loss_fct = nn.BCEWithLogitsLoss(reduction="mean")
        el_loss = loss_fct(scores.float(), candidate_label.float())
        return el_loss + self.ment_loss

    def score_cand(self, mention_embedding, cand_embedding):
        '''
        mention_embedding=torch.tensor([[1,2,3,4,5,6],[7,8,9,10,11,12],[13,14,15,16,17,18]])
        cand_embedding=torch.tensor([[[1,1,1,1,1,1],[1,1,1,1,1,1]],[[1,1,1,1,1,1],[1,1,1,1,1,1]],[[1,1,1,1,1,1],[1,1,1,1,1,1]]])
        scores=torch.bmm(cand_embedding, mention_embedding.unsqueeze(2))

        :param mention_embedding:
        :param cand_embedding:
        :return:
        '''
        # DIM: (all_pred_mention, top_k, 1)
        scores = torch.bmm(cand_embedding, mention_embedding.unsqueeze(2))

        return scores
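
A standalone toy run of the pieces in Example #4, reusing the tensors from its docstrings. The vectorized min/max normalization here is an assumption that mirrors what get_embedding_mention's swap loop does; shapes follow the DIM comments above:

import torch
import torch.nn as nn

# mean-pooled span embeddings, as in get_embedding_mention('average_all')
bert_output = torch.ones(1, 6, 8)
mention_bounds = torch.tensor([[[1, 2], [3, 4], [4, 2]]])
lo = torch.minimum(mention_bounds[..., 0], mention_bounds[..., 1])
hi = torch.maximum(mention_bounds[..., 0], mention_bounds[..., 1])
mention_embedding = torch.stack([
    torch.stack([bert_output[i, int(lo[i, j]):int(hi[i, j]) + 1].mean(dim=0)
                 for j in range(mention_bounds.size(1))])
    for i in range(mention_bounds.size(0))
])
print(mention_embedding.shape)  # torch.Size([1, 3, 8])

# candidate scoring, as in score_cand: one bmm per mention against its top-k candidates
mention_embedding = torch.randn(3, 8)  # (all_pred_mention, dim)
cand_embedding = torch.randn(3, 2, 8)  # (all_pred_mention, top_k, dim)
scores = torch.bmm(cand_embedding, mention_embedding.unsqueeze(2))
print(scores.shape)  # torch.Size([3, 2, 1])

# EL loss, as at the end of forward, with candidate_label from the docstring
candidate_label = torch.tensor([[0, 1], [1, 0], [0, 0]])
loss_fct = nn.BCEWithLogitsLoss(reduction='mean')
el_loss = loss_fct(scores.view(-1, 1).float(),
                   candidate_label.view(-1, 1).float())
print(el_loss.item())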