示例#1
0
class BertEmbedModel(nn.Module):
    """Embedding layer backed by a `BertEncoder`.

    `forward` does not return the representations; it writes them back into
    the batch dictionary it receives.
    """
    def __init__(self, cfg, vocab):
        """Build the BERT encoder described by `cfg`.

        Arguments:
            cfg {object} -- config parameters; `bert_model_name`, `fine_tune`,
                `bert_output_size` and `bert_dropout` are read as attributes
            vocab {Vocabulary} -- vocabulary (accepted but not used here)
        """

        super().__init__()
        self.activation = nn.GELU()

        encoder_kwargs = dict(bert_model_name=cfg.bert_model_name,
                              trainable=cfg.fine_tune,
                              output_size=cfg.bert_output_size,
                              activation=self.activation,
                              dropout=cfg.bert_dropout)
        self.bert_encoder = BertEncoder(**encoder_kwargs)
        self.encoder_output_size = self.bert_encoder.get_output_dims()

    def forward(self, batch_inputs):
        """Run the encoder and store its outputs in `batch_inputs`.

        Reads `batch_inputs['wordpiece_tokens']`, the optional
        `batch_inputs['wordpiece_segment_ids']` and
        `batch_inputs['wordpiece_tokens_index']`. Writes
        `batch_inputs['seq_encoder_reprs']` and `batch_inputs['seq_cls_repr']`.
        Nothing is returned.

        Arguments:
            batch_inputs {dict} -- batch input data, modified in place
        """

        # Segment ids are optional; pass them positionally only when present.
        encoder_args = [batch_inputs['wordpiece_tokens']]
        if 'wordpiece_segment_ids' in batch_inputs:
            encoder_args.append(batch_inputs['wordpiece_segment_ids'])
        wordpiece_reprs, cls_repr = self.bert_encoder(*encoder_args)

        # Select the wordpiece representations addressed by the token index.
        token_reprs = batched_index_select(wordpiece_reprs,
                                           batch_inputs['wordpiece_tokens_index'])

        batch_inputs['seq_encoder_reprs'] = token_reprs
        batch_inputs['seq_cls_repr'] = cls_repr

    def get_hidden_size(self):
        """Return the encoder output dimension recorded at construction.

        Returns:
            int -- embedding dimensions
        """

        return self.encoder_output_size
示例#2
0
class PretrainedSpanEncoder(nn.Module):
    """PretrainedSpanEncoder encodes spans into vectors.

    Besides the span mention representation, `forward` computes a masked
    token prediction loss and a span classification loss unless the instance
    was built as a momentum encoder.
    """

    # Sizes used when building the layers below.
    # NOTE(review): 28996 is presumably the wordpiece vocabulary size of the
    # BERT checkpoint in use -- confirm before switching checkpoints.
    MASKED_TOKEN_VOCAB_SIZE = 28996
    GLOBAL_POSITION_NUM = 150
    GLOBAL_POSITION_DIMS = 200
    RELATIVE_POSITION_NUM = 7
    SPAN_LABEL_SIZE = 6

    def __init__(self, cfg, momentum=False):
        """This function constructs `PretrainedSpanEncoder` components

        Arguments:
            cfg {object} -- config parameters for constructing multiple models
                (read as attributes, e.g. `cfg.ent_output_size`)

        Keyword Arguments:
            momentum {bool} -- whether this encoder is momentum encoder;
                a momentum encoder returns from `forward` before any loss
                is computed (default: {False})
        """

        super().__init__()
        self.ent_output_size = cfg.ent_output_size
        self.span_batch_size = cfg.span_batch_size
        self.position_embedding_dims = cfg.position_embedding_dims
        self.att_size = cfg.att_size
        self.momentum = momentum
        self.activation = nn.GELU()
        self.device = cfg.device

        self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                        trainable=cfg.fine_tune,
                                        output_size=cfg.bert_output_size,
                                        activation=self.activation)
        encoder_dims = self.bert_encoder.get_output_dims()

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=encoder_dims,
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(input_size=self.entity_span_extractor.get_output_dims(),
                                         output_size=self.ent_output_size,
                                         activation=self.activation,
                                         dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            # nn.Identity instead of a lambda: it has no parameters (so the
            # state dict is unchanged) and, unlike a lambda attribute, it
            # does not prevent the module from being pickled.
            self.ent2hidden = nn.Identity()

        self.entity_span_mlp = BertLinear(input_size=self.ent_output_size,
                                          output_size=self.ent_output_size,
                                          activation=self.activation,
                                          dropout=cfg.dropout)
        self.entity_span_decoder = VanillaSoftmaxDecoder(hidden_size=self.ent_output_size,
                                                         label_size=self.SPAN_LABEL_SIZE)

        self.global_position_embedding = nn.Embedding(self.GLOBAL_POSITION_NUM,
                                                      self.GLOBAL_POSITION_DIMS)
        self.global_position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        # The masked token MLP consumes the wordpiece representation
        # concatenated with its global position embedding.
        self.masked_token_mlp = BertLinear(input_size=encoder_dims + self.GLOBAL_POSITION_DIMS,
                                           output_size=encoder_dims,
                                           activation=self.activation,
                                           dropout=cfg.dropout)
        self.masked_token_decoder = nn.Linear(encoder_dims,
                                              self.MASKED_TOKEN_VOCAB_SIZE,
                                              bias=False)
        self.masked_token_decoder.weight.data.normal_(mean=0.0, std=0.02)
        self.masked_token_decoder_bias = nn.Parameter(
            torch.zeros(self.MASKED_TOKEN_VOCAB_SIZE))

        self.position_embedding = nn.Embedding(self.RELATIVE_POSITION_NUM,
                                               self.position_embedding_dims)
        self.position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.attention_encoder = PosAwareAttEncoder(self.ent_output_size,
                                                    encoder_dims,
                                                    2 * self.position_embedding_dims,
                                                    self.att_size,
                                                    activation=self.activation,
                                                    dropout=cfg.dropout)

        self.mlp_head1 = BertLinear(self.ent_output_size,
                                    encoder_dims,
                                    activation=self.activation,
                                    dropout=cfg.dropout)
        self.mlp_head2 = BertLinear(encoder_dims,
                                    encoder_dims,
                                    activation=self.activation,
                                    dropout=cfg.dropout)

        self.masked_token_loss = nn.CrossEntropyLoss()

    def _to_device(self, tensor):
        """Move `tensor` to CUDA device `self.device` when it is > -1."""

        if self.device > -1:
            return tensor.cuda(device=self.device, non_blocking=True)
        return tensor

    def _zero_loss(self):
        """Return a one-element zero tensor that requires grad."""

        zero_loss = torch.Tensor([0])
        zero_loss.requires_grad = True
        return self._to_device(zero_loss)

    def forward(self, batch_inputs):
        """This function propagates forwardly

        Arguments:
            batch_inputs {dict} -- batch inputs; reads `wordpiece_tokens`,
                `wordpiece_tokens_index` and `span_mention`, plus
                `masked_index`, `masked_position`, `masked_label`, `spans`
                and `spans_label` when this is not a momentum encoder

        Returns:
            dict -- `span_mention_repr`, and for non-momentum encoders also
                `masked_token_loss` and `span_loss`
        """

        batch_seq_wordpiece_tokens_repr, batch_seq_cls_repr = self.bert_encoder(
            batch_inputs['wordpiece_tokens'])
        batch_seq_tokens_repr = batched_index_select(batch_seq_wordpiece_tokens_repr,
                                                     batch_inputs['wordpiece_tokens_index'])

        results = {}

        entity_feature = self.ent2hidden(
            self.entity_span_extractor(batch_seq_tokens_repr, batch_inputs['span_mention']))

        # Offsets are shifted by 3 so every index is non-negative and falls
        # inside the relative position embedding table.
        subj_pos = self._to_device(torch.LongTensor([-1, 0, 1, 2, 3]) + 3)
        obj_pos = self._to_device(torch.LongTensor([-3, -2, -1, 0, 1]) + 3)

        subj_pos_emb = self.position_embedding(subj_pos)
        obj_pos_emb = self.position_embedding(obj_pos)
        # Same position feature for every sequence in the batch.
        batch_size = batch_inputs['wordpiece_tokens_index'].size()[0]
        pos_emb = torch.cat([subj_pos_emb, obj_pos_emb],
                            dim=1).unsqueeze(0).repeat(batch_size, 1, 1)

        span_mention_attention_repr = self.attention_encoder(inputs=entity_feature,
                                                             query=batch_seq_cls_repr,
                                                             feature=pos_emb)
        results['span_mention_repr'] = self.mlp_head2(self.mlp_head1(span_mention_attention_repr))

        if self.momentum:
            return results

        # One shared fallback tensor for whichever loss has nothing to train on.
        zero_loss = self._zero_loss()

        results['masked_token_loss'] = self._masked_token_loss(
            batch_inputs, batch_seq_wordpiece_tokens_repr, zero_loss)
        results['span_loss'] = self._span_loss(batch_inputs, batch_seq_tokens_repr, zero_loss)

        return results

    def _masked_token_loss(self, batch_inputs, batch_seq_wordpiece_tokens_repr, zero_loss):
        """Compute the masked token prediction loss over the whole batch.

        Returns `zero_loss` when no sequence has a masked wordpiece.
        """

        if sum(len(masked_index) for masked_index in batch_inputs['masked_index']) == 0:
            return zero_loss

        masked_wordpiece_tokens_repr = []
        all_masked_label = []
        for masked_index, masked_position, masked_label, seq_wordpiece_tokens_repr in zip(
                batch_inputs['masked_index'], batch_inputs['masked_position'],
                batch_inputs['masked_label'], batch_seq_wordpiece_tokens_repr):
            masked_index_tensor = self._to_device(torch.LongTensor(masked_index))
            masked_position_tensor = self._to_device(torch.LongTensor(masked_position))

            # Representation of each masked wordpiece followed by the
            # embedding of its global position.
            masked_wordpiece_tokens_repr.append(
                torch.cat([
                    seq_wordpiece_tokens_repr[masked_index_tensor],
                    self.global_position_embedding(masked_position_tensor)
                ], dim=1))
            all_masked_label.extend(masked_label)

        masked_wordpiece_tokens_input = torch.cat(masked_wordpiece_tokens_repr, dim=0)
        masked_wordpiece_tokens_output = self.masked_token_decoder(
            self.masked_token_mlp(masked_wordpiece_tokens_input)) + self.masked_token_decoder_bias

        all_masked_label_tensor = self._to_device(torch.LongTensor(all_masked_label))
        return self.masked_token_loss(masked_wordpiece_tokens_output, all_masked_label_tensor)

    def _span_loss(self, batch_inputs, batch_seq_tokens_repr, zero_loss):
        """Compute the span classification loss over every span in the batch.

        Returns `zero_loss` when the batch contains no span. When
        `span_batch_size` is positive the spans are processed in chunks of
        that size and the chunk losses are averaged.
        """

        all_spans = []
        all_spans_label = []
        all_seq_tokens_reprs = []
        for spans, spans_label, seq_tokens_repr in zip(batch_inputs['spans'],
                                                       batch_inputs['spans_label'],
                                                       batch_seq_tokens_repr):
            all_spans.extend(spans)
            all_spans_label.extend(spans_label)
            # Each span is paired with the representation of its own sequence.
            all_seq_tokens_reprs.extend(seq_tokens_repr for _ in range(len(spans)))

        assert len(all_spans) == len(all_seq_tokens_reprs) and len(all_spans) == len(
            all_spans_label)

        if len(all_spans) == 0:
            return zero_loss

        if self.span_batch_size > 0:
            step = self.span_batch_size
            chunk_losses = [
                self._span_chunk_loss(all_spans[start:start + step],
                                      all_spans_label[start:start + step],
                                      all_seq_tokens_reprs[start:start + step])
                for start in range(0, len(all_spans), step)
            ]
            return sum(chunk_losses) / len(chunk_losses)

        return self._span_chunk_loss(all_spans, all_spans_label, all_seq_tokens_reprs)

    def _span_chunk_loss(self, spans, spans_label, seq_tokens_reprs):
        """Return the span decoder loss for one non-empty chunk of spans."""

        spans_tensor = self._to_device(torch.LongTensor(spans).unsqueeze(1))
        stacked_seq_tokens_reprs = torch.stack(seq_tokens_reprs)

        spans_feature = self.ent2hidden(
            self.entity_span_extractor(stacked_seq_tokens_reprs, spans_tensor).squeeze(1))

        spans_label_tensor = self._to_device(torch.LongTensor(spans_label))

        span_outputs = self.entity_span_decoder(self.entity_span_mlp(spans_feature),
                                                spans_label_tensor)
        return span_outputs['loss']
示例#3
0
class JointREPretrainedModel(nn.Module):
    """This class contains four pretraining tasks: entity typing, masked
    token prediction, entity mention permutation prediction and confused
    entity mention context rank loss.
    """
    def __init__(self, cfg):
        """This function decides `JointREPretrainedModel` components

        Arguments:
            cfg {object} -- config parameters for constructing multiple models
                (read as attributes, e.g. `cfg.ent_output_size`)
        """

        super().__init__()
        self.ent_output_size = cfg.ent_output_size
        self.context_output_size = cfg.context_output_size
        self.output_size = cfg.ent_mention_output_size
        self.ent_batch_size = cfg.ent_batch_size
        self.permutation_batch_size = cfg.permutation_batch_size
        self.permutation_samples_num = cfg.permutation_samples_num
        self.confused_batch_size = cfg.confused_batch_size
        self.confused_samples_num = cfg.confused_samples_num
        self.activation = gelu
        self.device = cfg.device

        self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                        trainable=cfg.fine_tune,
                                        output_size=cfg.bert_output_size,
                                        activation=self.activation)
        encoder_dims = self.bert_encoder.get_output_dims()

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=encoder_dims,
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        # NOTE(review): the identity fallbacks below are lambdas stored on
        # the module, which cannot be pickled; consider nn.Identity() once
        # the minimum supported torch version is confirmed.
        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(
                input_size=self.entity_span_extractor.get_output_dims(),
                output_size=self.ent_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = lambda x: x

        self.context_span_extractor = CNNSpanExtractor(
            input_size=encoder_dims,
            num_filters=cfg.context_cnn_output_channels,
            ngram_filter_sizes=cfg.context_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.context_output_size > 0:
            self.context2hidden = BertLinear(
                input_size=self.context_span_extractor.get_output_dims(),
                output_size=self.context_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.context_output_size = self.context_span_extractor.get_output_dims()
            self.context2hidden = lambda x: x

        # An entity mention feature concatenates two entity spans and the
        # three context spans around them (see get_entity_mention_feature).
        mention_feature_size = 2 * self.ent_output_size + 3 * self.context_output_size
        if self.output_size > 0:
            self.mlp = BertLinear(input_size=mention_feature_size,
                                  output_size=self.output_size,
                                  activation=self.activation,
                                  dropout=cfg.dropout)
        else:
            self.output_size = mention_feature_size
            self.mlp = lambda x: x

        self.entity_pretrained_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.ent_output_size, label_size=18)

        self.masked_token_mlp = BertLinear(input_size=encoder_dims,
                                           output_size=encoder_dims,
                                           activation=self.activation)

        # The decoder predicts over the encoder's own word embedding table.
        word_embeddings = self.bert_encoder.bert_model.embeddings.word_embeddings
        self.token_vocab_size = word_embeddings.weight.size()[0]
        self.masked_token_decoder = nn.Linear(encoder_dims,
                                              self.token_vocab_size,
                                              bias=False)
        self.masked_token_decoder.weight.data.normal_(mean=0.0, std=0.02)
        self.masked_token_decoder_bias = nn.Parameter(
            torch.zeros(self.token_vocab_size))

        clone_weights(self.masked_token_decoder, word_embeddings)

        self.masked_token_loss = nn.CrossEntropyLoss()

        # NOTE(review): 120 presumably equals 5! orderings -- confirm.
        self.permutation_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.output_size, label_size=120)

        self.confused_context_decoder = nn.Linear(self.output_size, 1)
        self.confused_context_decoder.weight.data.normal_(mean=0.0, std=0.02)
        self.confused_context_decoder.bias.data.zero_()

        self.entity_mention_index_tensor = self._to_device(
            torch.LongTensor([2, 0, 3, 1, 4]))

    def _to_device(self, tensor):
        """Move `tensor` to CUDA device `self.device` when it is > -1."""

        if self.device > -1:
            return tensor.cuda(device=self.device, non_blocking=True)
        return tensor

    def _zero_loss(self):
        """Return a one-element zero tensor that requires grad."""

        zero_loss = torch.Tensor([0])
        zero_loss.requires_grad = True
        return self._to_device(zero_loss)

    def forward(self, batch_inputs, pretrain_task=''):
        """This function propagates forwardly

        Arguments:
            batch_inputs {dict} -- batch inputs

        Keyword Arguments:
            pretrain_task {str} -- pretraining task, one of
                'masked_entity_typing', 'masked_entity_token_prediction',
                'entity_mention_permutation' or 'confused_context'
                (default: {''})

        Returns:
            dict -- results of the selected task, or None when
                `pretrain_task` names none of the tasks above
        """

        task_handlers = {
            'masked_entity_typing': self.masked_entity_typing,
            'masked_entity_token_prediction': self.masked_entity_token_prediction,
            'entity_mention_permutation': self.permutation_prediction,
            'confused_context': self.confused_context_prediction,
        }
        handler = task_handlers.get(pretrain_task)
        if handler is None:
            return None
        return handler(batch_inputs)

    def seq_decoder(self, seq_inputs, seq_mask=None, seq_labels=None):
        """This function decodes token predictions for every position.

        Arguments:
            seq_inputs {tensor} -- sequence representations; the vocabulary
                axis of the decoded output is dim 2

        Keyword Arguments:
            seq_mask {tensor} -- positions equal to 1 contribute to the loss
                (default: {None})
            seq_labels {tensor} -- gold token ids (default: {None})

        Returns:
            dict -- `predict`, plus `loss` when `seq_labels` is given
        """

        results = {}
        seq_outputs = self.masked_token_decoder(
            seq_inputs) + self.masked_token_decoder_bias
        seq_log_probs = F.log_softmax(seq_outputs, dim=2)
        results['predict'] = seq_log_probs.argmax(dim=2)

        if seq_labels is None:
            return results

        flat_outputs = seq_outputs.view(-1, self.token_vocab_size)
        flat_labels = seq_labels.view(-1)
        if seq_mask is not None:
            # Average the loss over the masked-in positions only.
            active_loss = seq_mask.view(-1) == 1
            flat_outputs = flat_outputs[active_loss]
            flat_labels = flat_labels[active_loss]
        results['loss'] = self.masked_token_loss(flat_outputs, flat_labels)

        return results

    def masked_entity_typing(self, batch_inputs):
        """This function pretrains masked entity typing task.

        Stores `seq_wordpiece_tokens_reprs` and `seq_tokens_reprs` in
        `batch_inputs` as a side effect. When the batch contains no entity
        span the loss is a zero tensor, whatever `ent_batch_size` is.

        Arguments:
            batch_inputs {dict} -- batch inputs

        Returns:
            dict -- `loss` of the entity typing task
        """

        seq_wordpiece_tokens_reprs, _ = self.bert_encoder(
            batch_inputs['tokens_id'])
        batch_inputs['seq_wordpiece_tokens_reprs'] = seq_wordpiece_tokens_reprs
        batch_inputs['seq_tokens_reprs'] = batched_index_select(
            seq_wordpiece_tokens_reprs, batch_inputs['tokens_index'])

        all_ents = []
        all_ents_labels = []
        all_seq_tokens_reprs = []
        for ent_spans, ent_labels, seq_tokens_reprs in zip(
                batch_inputs['ent_spans'], batch_inputs['ent_labels'],
                batch_inputs['seq_tokens_reprs']):
            # The span end is decremented by one before extraction.
            all_ents.extend([span[0], span[1] - 1] for span in ent_spans)
            all_ents_labels.extend(ent_labels)
            # Each entity is paired with the representation of its sequence.
            all_seq_tokens_reprs.extend(seq_tokens_reprs
                                        for _ in range(len(ent_spans)))

        if not all_ents:
            # Nothing to classify. Without this guard the unbatched path
            # would call torch.stack on an empty list, which raises.
            return {'loss': self._zero_loss()}

        if self.ent_batch_size > 0:
            step = self.ent_batch_size
            chunk_losses = [
                self._entity_typing_chunk_loss(
                    all_ents[start:start + step],
                    all_ents_labels[start:start + step],
                    all_seq_tokens_reprs[start:start + step])
                for start in range(0, len(all_ents), step)
            ]
            entity_typing_loss = sum(chunk_losses) / len(chunk_losses)
        else:
            entity_typing_loss = self._entity_typing_chunk_loss(
                all_ents, all_ents_labels, all_seq_tokens_reprs)

        return {'loss': entity_typing_loss}

    def _entity_typing_chunk_loss(self, ents, ents_labels, seq_tokens_reprs):
        """Return the entity typing loss for one non-empty chunk."""

        ents_tensor = self._to_device(torch.LongTensor(ents).unsqueeze(1))
        stacked_seq_tokens_reprs = torch.stack(seq_tokens_reprs)

        ents_feature = self.ent2hidden(
            self.entity_span_extractor(stacked_seq_tokens_reprs,
                                       ents_tensor).squeeze(1))

        ents_labels_tensor = self._to_device(torch.LongTensor(ents_labels))

        entity_typing_outputs = self.entity_pretrained_decoder(
            ents_feature, ents_labels_tensor)
        return entity_typing_outputs['loss']

    def masked_entity_token_prediction(self, batch_inputs):
        """This function pretrains masked entity tokens prediction task.

        Arguments:
            batch_inputs {dict} -- batch inputs; reads `tokens_id`,
                `masked_index` and `tokens_label`

        Returns:
            dict -- `loss`; a zero tensor when `masked_index` sums to zero
        """

        masked_seq_wordpiece_tokens_reprs, _ = self.bert_encoder(
            batch_inputs['tokens_id'])
        masked_seq_wordpiece_tokens_reprs = self.masked_token_mlp(
            masked_seq_wordpiece_tokens_reprs)

        if batch_inputs['masked_index'].sum() != 0:
            masked_entity_token_outputs = self.seq_decoder(
                seq_inputs=masked_seq_wordpiece_tokens_reprs,
                seq_mask=batch_inputs['masked_index'],
                seq_labels=batch_inputs['tokens_label'])
            masked_entity_token_loss = masked_entity_token_outputs['loss']
        else:
            masked_entity_token_loss = self._zero_loss()

        return {'loss': masked_entity_token_loss}

    def permutation_prediction(self, batch_inputs):
        """This function pretrains entity mention permutation prediction task.

        Arguments:
            batch_inputs {dict} -- batch inputs

        Returns:
            dict -- `loss` of the permutation decoder
        """

        all_permutation_feature = self.get_entity_mention_feature(
            batch_inputs['tokens_id'], batch_inputs['tokens_index'],
            batch_inputs['ent_mention'], batch_inputs['tokens_index_lens'])
        permutation_outputs = self.permutation_decoder(
            all_permutation_feature, batch_inputs['ent_mention_label'])

        return {'loss': permutation_outputs['loss']}

    def confused_context_prediction(self, batch_inputs):
        """This function pretrains confused context prediction task.

        Arguments:
            batch_inputs {dict} -- batch inputs

        Returns:
            dict -- `loss`, the mean of relu(5 - |confused score - truth score|)
        """

        all_confused_context_feature = self.get_entity_mention_feature(
            batch_inputs['confused_tokens_id'],
            batch_inputs['confused_tokens_index'],
            batch_inputs['confused_ent_mention'],
            batch_inputs['confused_tokens_index_lens'])
        all_truth_context_feature = self.get_entity_mention_feature(
            batch_inputs['origin_tokens_id'],
            batch_inputs['origin_tokens_index'],
            batch_inputs['origin_ent_mention'],
            batch_inputs['origin_tokens_index_lens'])
        confused_context_score = self.confused_context_decoder(
            all_confused_context_feature)
        truth_context_score = self.confused_context_decoder(
            all_truth_context_feature)
        score_gap = torch.abs(confused_context_score - truth_context_score)
        rank_loss = torch.mean(torch.relu(5.0 - score_gap))

        return {'loss': rank_loss}

    def get_entity_mention_feature(self, batch_wordpiece_tokens,
                                   batch_wordpiece_tokens_index,
                                   batch_entity_mentions, batch_seq_lens):
        """This function extracts entity mention feature using CNN.

        Arguments:
            batch_wordpiece_tokens {tensor} -- batch wordpiece tokens
            batch_wordpiece_tokens_index {tensor} -- batch wordpiece tokens index
            batch_entity_mentions {list} -- batch entity mentions, each a pair
                of spans indexed as `mention[i][0]` and `mention[i][1]`
            batch_seq_lens {list} -- batch sequence length list

        Returns:
            tensor -- entity mention feature
        """

        batch_seq_reprs, _ = self.bert_encoder(batch_wordpiece_tokens)
        batch_seq_reprs = batched_index_select(batch_seq_reprs,
                                               batch_wordpiece_tokens_index)

        entity_spans = []
        context_spans = []
        for entity_mention, seq_len in zip(batch_entity_mentions,
                                           batch_seq_lens):
            first, second = entity_mention[0], entity_mention[1]
            entity_spans.append([[first[0], first[1]],
                                 [second[0], second[1]]])
            # Context before the first span, between the two spans, and
            # after the second span.
            context_spans.append([[0, first[0]],
                                  [first[1], second[0]],
                                  [second[1], seq_len]])

        entity_spans_tensor = self._to_device(torch.LongTensor(entity_spans))
        context_spans_tensor = self._to_device(torch.LongTensor(context_spans))

        entity_feature = self.entity_span_extractor(batch_seq_reprs,
                                                    entity_spans_tensor)
        context_feature = self.context_span_extractor(batch_seq_reprs,
                                                      context_spans_tensor)

        entity_feature = self.ent2hidden(entity_feature)
        context_feature = self.context2hidden(context_feature)

        # Interleave context and entity features in their left-to-right order.
        entity_mention_feature = torch.cat([
            context_feature[:, 0, :], entity_feature[:, 0, :],
            context_feature[:, 1, :], entity_feature[:, 1, :],
            context_feature[:, 2, :]
        ], dim=-1).view(len(batch_wordpiece_tokens), -1)

        return self.mlp(entity_mention_feature)