Example #1
class BertCRF(BertPreTrainedModel):
    def __init__(self, cfig):
        super(BertCRF, self).__init__(cfig)

        #self.device = cfig.device
        self.num_labels = len(cfig.label2idx)
        self.bert = BertModel(cfig)
        self.dropout = nn.Dropout(cfig.hidden_dropout_prob)
        self.classifier = nn.Linear(cfig.hidden_size, len(cfig.label2idx))
        self.inferencer = LinearCRF(cfig)

        self.init_weights()

    def forward(self, input_ids, input_seq_lens=None, annotation_mask=None, labels=None,
                attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, add_crf=False):

        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask)
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)   # (batch_size, seq_length, hidden_size)
        logits = self.classifier(sequence_output)  # (batch_size, seq_length, num_labels)

        if labels is not None:
            batch_size = input_ids.size(0)
            sent_len = input_ids.size(1)  # one batch max seq length
            maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(1, sent_len).expand(batch_size, sent_len).to(self.device)
            mask = torch.le(maskTemp, input_seq_lens.view(batch_size, 1).expand(batch_size, sent_len)).to(self.device)

            unlabeled_score, labeled_score = self.inferencer(logits, input_seq_lens, labels, mask)  # use the length mask built above
            return unlabeled_score - labeled_score

        else:
            bestScores, decodeIdx = self.inferencer.decode(logits, input_seq_lens, annotation_mask)

            return bestScores, decodeIdx

    # obsolete
    def decode(self, input_ids, input_seq_lens=None, annotation_mask=None, attention_mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input
        :param batchInput:
        :return:
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=None, position_ids=None, head_mask=None)
        features = self.dropout(outputs[0])   # (batch_size, seq_length, hidden_size)
        logits = self.classifier(features)  # (batch_size, seq_length, num_labels)

        bestScores, decodeIdx = self.inferencer.decode(logits, input_seq_lens, annotation_mask)
        return bestScores, decodeIdx
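
A minimal usage sketch for Example #1 (hedged: `cfig` is assumed to be a BertConfig subclass that also carries `label2idx`, and `LinearCRF` comes from the same project as above). Passing `labels` returns the CRF negative log-likelihood; omitting it runs Viterbi decoding.

import torch

# cfig: assumed BertConfig subclass with a `label2idx` dict (not defined here)
model = BertCRF(cfig)

input_ids = torch.randint(0, cfig.vocab_size, (2, 16))     # (batch_size, seq_length)
input_seq_lens = torch.tensor([16, 12])
attention_mask = (torch.arange(16).unsqueeze(0) < input_seq_lens.unsqueeze(1)).long()
labels = torch.zeros(2, 16, dtype=torch.long)              # placeholder gold label ids

# With labels: returns unlabeled_score - labeled_score, i.e. the CRF negative log-likelihood.
loss = model(input_ids, input_seq_lens=input_seq_lens, labels=labels,
             attention_mask=attention_mask)

# Without labels: returns (bestScores, decodeIdx) from Viterbi decoding.
best_scores, decode_idx = model(input_ids, input_seq_lens=input_seq_lens,
                                attention_mask=attention_mask)
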
Example #2
class NNCRF(nn.Module):
    def __init__(self, config, print_info: bool = True):
        super(NNCRF, self).__init__()
        self.device = config.device
        self.encoder = BiLSTMEncoder(config, print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, sent_emb_tensor: torch.Tensor,
                type_id_tensor: torch.Tensor, sent_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor, chars: torch.Tensor,
                char_seq_lens: torch.Tensor,
                tags: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative loglikelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param batch_context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param tags: (batch_size x max_seq_len)
        :return: the total negative log-likelihood loss
        """
        # print("sents: ",sents)
        lstm_scores = self.encoder(sent_emb_tensor, type_id_tensor,
                                   sent_seq_lens, batch_context_emb, chars,
                                   char_seq_lens)
        # lstm_scores = self.encoder(sent_emb_tensor, sent_seq_lens, chars, char_seq_lens)
        batch_size = sent_emb_tensor.size(0)
        sent_len = sent_emb_tensor.size(1)
        maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(
            1, sent_len).expand(batch_size, sent_len).to(self.device)
        mask = torch.le(
            maskTemp,
            sent_seq_lens.view(batch_size, 1).expand(batch_size,
                                                     sent_len)).to(self.device)
        unlabeled_score, labeled_score = self.inferencer(lstm_scores,
                                                         sent_seq_lens, tags,
                                                         mask)
        return unlabeled_score - labeled_score

    def decode(
        self, batchInput: Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                torch.Tensor, torch.Tensor, torch.Tensor,
                                torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input
        :param batchInput:
        :return:
        """
        wordSeqTensor, typeTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, tagSeqTensor = batchInput
        features = self.encoder(wordSeqTensor, typeTensor, wordSeqLengths,
                                batch_context_emb, charSeqTensor,
                                charSeqLengths)
        bestScores, decodeIdx = self.inferencer.decode(features,
                                                       wordSeqLengths)
        # print(bestScores, decodeIdx)
        return bestScores, decodeIdx
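
Every example here builds the same boolean length mask from the true sequence lengths with `torch.arange` and `torch.le`. A small standalone sketch of that pattern (toy values for illustration):

import torch

word_seq_lens = torch.tensor([5, 3, 4])        # true lengths of three padded sequences
batch_size, max_sent_len = 3, 5

# positions 1..max_sent_len, broadcast over the batch
positions = (torch.arange(1, max_sent_len + 1, dtype=torch.long)
             .view(1, max_sent_len)
             .expand(batch_size, max_sent_len))

# True inside the real sequence, False on padding positions
mask = torch.le(positions, word_seq_lens.view(batch_size, 1).expand(batch_size, max_sent_len))
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False]])
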
Example #3
class NNCRF(nn.Module):

    def __init__(self, config, print_info: bool = True):
        super(NNCRF, self).__init__()
        self.device = config.device
        self.encoder = BiLSTMEncoder(config, print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor,
                    word_seq_lens: torch.Tensor,
                    batch_context_emb: torch.Tensor,
                    chars: torch.Tensor,
                    char_seq_lens: torch.Tensor,
                    annotation_mask : torch.Tensor,
                    marginals: torch.Tensor,
                    tags: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative loglikelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param batch_context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param tags: (batch_size x max_seq_len)
        :return: the loss with shape (batch_size)
        """
        lstm_scores = self.encoder(words, word_seq_lens, batch_context_emb, chars, char_seq_lens)
        batch_size = words.size(0)
        sent_len = words.size(1)
        maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(1, sent_len).expand(batch_size, sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, sent_len)).to(self.device)
        unlabeled_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, tags, mask)
        return unlabeled_score - labeled_score

    def decode(self, batchInput: Tuple) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input
        :param batchInput:
        :return:
        """
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, annotation_mask, marginals, tagSeqTensor = batchInput
        features = self.encoder(wordSeqTensor, wordSeqLengths, batch_context_emb,charSeqTensor,charSeqLengths)
        bestScores, decodeIdx = self.inferencer.decode(features, wordSeqLengths, annotation_mask)
        return bestScores, decodeIdx

    def get_marginal(self, batchInput: Tuple) -> torch.Tensor:
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, annotation_mask, marginals, tagSeqTensor = batchInput
        features = self.encoder(wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths)
        marginals = self.inferencer.compute_constrained_marginal(features, wordSeqLengths, annotation_mask)
        return marginals
Example #4
class SoftSequenceNaive(nn.Module):
    def __init__(self, config,  encoder=None, print_info=True):
        super(SoftSequenceNaive, self).__init__()
        self.config = config
        self.device = config.device

        self.encoder = SoftEncoder(self.config)
        if encoder is not None:
            self.encoder = encoder

        self.label_size = config.label_size
        self.inferencer = LinearCRF(config, print_info=print_info)
        self.hidden2tag = nn.Linear(config.hidden_dim, self.label_size).to(self.device)

    def forward(self, word_seq_tensor: torch.Tensor,
                word_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor,
                char_inputs: torch.Tensor,
                char_seq_lens: torch.Tensor, tags):

        batch_size = word_seq_tensor.size(0)
        max_sent_len = word_seq_tensor.size(1)

        output, sentence_mask, _, _ = \
            self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, None)

        lstm_scores = self.hidden2tag(output)
        maskTemp = torch.arange(1, max_sent_len + 1, dtype=torch.long).view(1, max_sent_len).expand(batch_size, max_sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, max_sent_len)).to(self.device)

        if self.inferencer is not None:
            unlabeled_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, tags, mask)
            sequence_loss = unlabeled_score - labeled_score
        else:
            sequence_loss = self.compute_nll_loss(lstm_scores, tags, mask, word_seq_lens)

        return sequence_loss

    def decode(self, word_seq_tensor: torch.Tensor,
                word_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor,
                char_inputs: torch.Tensor,
                char_seq_lens: torch.Tensor):

        soft_output, soft_sentence_mask, _, _ = \
                    self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, None)
        lstm_scores = self.hidden2tag(soft_output)
        if self.inferencer is not None:
            bestScores, decodeIdx = self.inferencer.decode(lstm_scores, word_seq_lens, None)
        return bestScores, decodeIdx
Example #5
class NNCRF(nn.Module):
    def __init__(self, config, print_info: bool = True):
        super(NNCRF, self).__init__()
        self.device = config.device
        self.encoder = BiLSTMEncoder(config, print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor, word_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor, chars: torch.Tensor,
                char_seq_lens: torch.Tensor,
                label_mask_tensor: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative log-likelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param batch_context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param label_mask_tensor: (batch_size x max_seq_len x num_labels)
        :return: the loss with shape (batch_size)
        """
        lstm_scores = self.encoder(words, word_seq_lens, batch_context_emb,
                                   chars, char_seq_lens)
        unlabeled_score, labeled_score = self.inferencer(lstm_scores,
                                                         word_seq_lens,
                                                         label_mask_tensor)
        return unlabeled_score - labeled_score

    def decode(
        self, batchInput: Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input
        :param batchInput:
        :return:
        """
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, tagSeqTensor = batchInput
        features = self.encoder(wordSeqTensor, wordSeqLengths,
                                batch_context_emb, charSeqTensor,
                                charSeqLengths)
        bestScores, decodeIdx = self.inferencer.decode(features,
                                                       wordSeqLengths)
        return bestScores, decodeIdx
Example #6
class NNCRF(nn.Module):

    def __init__(self, config: Config,
                 print_info: bool = True):
        super(NNCRF, self).__init__()
        self.device = config.device
        self.encoder = BiLSTMEncoder(config, print_info=print_info)
        self.inferencer = None
        if config.use_crf_layer:
            self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor,
                    word_seq_lens: torch.Tensor,
                    batch_context_emb: torch.Tensor,
                    chars: torch.Tensor,
                    char_seq_lens: torch.Tensor,
                    tags: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative log-likelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param batch_context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param tags: (batch_size x max_seq_len)
        :return: the loss with shape (batch_size)
        """
        batch_size = words.size(0)
        max_sent_len = words.size(1)
        #Shape: (batch_size, max_seq_len, num_labels)
        lstm_scores = self.encoder(words, word_seq_lens, batch_context_emb, chars, char_seq_lens)
        maskTemp = torch.arange(1, max_sent_len + 1, dtype=torch.long).view(1, max_sent_len).expand(batch_size, max_sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, max_sent_len)).to(self.device)
        if self.inferencer is not None:
            unlabeled_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, tags, mask)
            loss = unlabeled_score - labeled_score
        else:
            loss = self.compute_nll_loss(lstm_scores, tags, mask, word_seq_lens)
        return loss

    def decode(self, batchInput: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input
        :param batchInput:
        :return:
        """
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, tagSeqTensor = batchInput
        lstm_scores = self.encoder(wordSeqTensor, wordSeqLengths, batch_context_emb,charSeqTensor,charSeqLengths)
        if self.inferencer is not None:
            bestScores, decodeIdx = self.inferencer.decode(lstm_scores, wordSeqLengths)
        else:
            bestScores, decodeIdx = torch.max(lstm_scores, dim=2)
        return bestScores, decodeIdx

    def compute_nll_loss(self, candidate_scores, target, mask, word_seq_lens):
        """
        Directly compute the loss right after the linear layer instead of CRF layer.
        Partially taken from `masked_cross_entropy.py` (https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1)
        :param candidate_scores:
        :param target:
        :param mask:
        :param word_seq_lens:
        :return:
        """
        # logits_flat: (batch * max_len, num_classes)
        logits_flat = candidate_scores.view(-1, candidate_scores.size(-1))
        # log_probs_flat: (batch * max_len, num_classes)
        log_probs_flat = torch.log_softmax(logits_flat, dim=1)
        # target_flat: (batch * max_len, 1)
        target_flat = target.view(-1, 1)
        # losses_flat: (batch * max_len, 1)
        losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
        # losses: (batch, max_len)
        losses = losses_flat.view(*target.size())
        # # mask: (batch, max_len)
        # mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
        losses = losses * mask.float()
        # loss = losses.sum() / word_seq_lens.float().sum()
        loss = losses.sum()
        return loss
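
For reference, the gather-based token loss in `compute_nll_loss` is equivalent to a masked cross-entropy; a hedged sketch of the same computation with `F.cross_entropy` (shapes as in the docstring above):

import torch
import torch.nn.functional as F

def masked_token_nll(candidate_scores, target, mask):
    # candidate_scores: (batch, max_len, num_classes); target, mask: (batch, max_len)
    losses = F.cross_entropy(candidate_scores.transpose(1, 2), target, reduction='none')
    return (losses * mask.float()).sum()
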
Example #7
class SoftSequence(nn.Module):
    def __init__(self, config, softmatcher, encoder=None, print_info=True):
        super(SoftSequence, self).__init__()
        self.config = config
        self.device = config.device
        self.encoder = SoftEncoder(self.config)
        if encoder is not None:
            self.encoder = encoder

        self.softmatch_encoder = softmatcher.encoder
        self.softmatch_attention = softmatcher.attention
        self.label_size = config.label_size
        self.inferencer = LinearCRF(config, print_info=print_info)
        self.hidden2tag = nn.Linear(config.hidden_dim * 2,
                                    self.label_size).to(self.device)

        self.w1 = nn.Linear(config.hidden_dim,
                            config.hidden_dim // 2).to(self.device)
        self.w2 = nn.Linear(config.hidden_dim // 2,
                            config.hidden_dim // 2).to(self.device)
        self.attn1 = nn.Linear(config.hidden_dim // 2, 1).to(self.device)
        self.attn2 = nn.Linear(config.hidden_dim + config.hidden_dim // 2,
                               1).to(self.device)
        self.attn3 = nn.Linear(config.hidden_dim // 2, 1).to(self.device)

        # nn.Parameter registers this tensor with the module (torch.autograd.Variable is deprecated)
        self.applying = nn.Parameter(torch.randn(config.hidden_dim,
                                                 config.hidden_dim // 2,
                                                 device=self.device))
        self.tanh = nn.Tanh().to(self.device)
        self.perturb = nn.Dropout(config.dropout).to(self.device)

    def forward(self, word_seq_tensor: torch.Tensor,
                word_seq_lens: torch.Tensor, batch_context_emb: torch.Tensor,
                char_inputs: torch.Tensor, char_seq_lens: torch.Tensor,
                trigger_position, tags):

        batch_size = word_seq_tensor.size(0)
        max_sent_len = word_seq_tensor.size(1)

        output, sentence_mask, trigger_vec, trigger_mask = \
            self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens,
                         trigger_position)

        if trigger_vec is not None:
            trig_rep, sentence_vec_cat, trigger_vec_cat = self.softmatch_attention(
                output, sentence_mask, trigger_vec, trigger_mask)

            # attention
            weights = []
            for i in range(len(output)):
                trig_applied = self.tanh(
                    self.w1(output[i].unsqueeze(0)) +
                    self.w2(trig_rep[i].unsqueeze(0).unsqueeze(0)))
                x = self.attn1(trig_applied)  # (1, seq_len, 1)
                x = torch.mul(x.squeeze(0), sentence_mask[i].unsqueeze(1))
                x[x == 0] = float('-inf')
                weights.append(x)
            normalized_weights = F.softmax(torch.stack(weights), 1)
            attn_applied1 = torch.mul(
                normalized_weights.repeat(1, 1, output.size(2)), output)
        else:
            weights = []
            for i in range(len(output)):
                trig_applied = self.tanh(
                    self.w1(output[i].unsqueeze(0)) +
                    self.w1(output[i].unsqueeze(0)))
                x = self.attn1(trig_applied)  # (1, seq_len, 1)
                x = torch.mul(x.squeeze(0), sentence_mask[i].unsqueeze(1))
                x[x == 0] = float('-inf')
                weights.append(x)
            normalized_weights = F.softmax(torch.stack(weights), 1)
            attn_applied1 = torch.mul(
                normalized_weights.repeat(1, 1, output.size(2)), output)

        output = torch.cat([output, attn_applied1], dim=2)
        lstm_scores = self.hidden2tag(output)
        maskTemp = torch.arange(1, max_sent_len + 1, dtype=torch.long).view(
            1, max_sent_len).expand(batch_size, max_sent_len).to(self.device)
        mask = torch.le(
            maskTemp,
            word_seq_lens.view(batch_size,
                               1).expand(batch_size,
                                         max_sent_len)).to(self.device)

        if self.inferencer is not None:
            unlabeled_score, labeled_score = self.inferencer(
                lstm_scores, word_seq_lens, tags, mask)
            sequence_loss = unlabeled_score - labeled_score
        else:
            sequence_loss = self.compute_nll_loss(lstm_scores, tags, mask,
                                                  word_seq_lens)

        return sequence_loss

    def decode_top(self, word_seq_tensor: torch.Tensor,
                   word_seq_lens: torch.Tensor,
                   batch_context_emb: torch.Tensor, char_inputs: torch.Tensor,
                   char_seq_lens: torch.Tensor, trig_rep):

        output, sentence_mask, _, _ = \
            self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, None)

        soft_output, soft_sentence_mask, _, _ = \
            self.softmatch_encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, None)
        soft_sent_rep = self.softmatch_attention.attention(
            soft_output, soft_sentence_mask)

        trig_vec = trig_rep[0]
        trig_key = trig_rep[1]

        n = soft_sent_rep.size(0)
        m = trig_vec.size(0)
        d = soft_sent_rep.size(1)

        soft_sent_rep_dist = soft_sent_rep.unsqueeze(1).expand(n, m, d)
        trig_vec_dist = trig_vec.unsqueeze(0).expand(n, m, d)

        dist = torch.pow(soft_sent_rep_dist - trig_vec_dist, 2).sum(2).sqrt()
        dvalue, dindices = torch.min(dist, dim=1)

        trigger_list = []
        for i in dindices.tolist():
            trigger_list.append(trig_vec[i])
        trig_rep = torch.stack(trigger_list)

        # attention
        weights = []
        for i in range(len(output)):
            trig_applied = self.tanh(
                self.w1(output[i].unsqueeze(0)) +
                self.w2(trig_rep[i].unsqueeze(0).unsqueeze(0)))
            x = self.attn1(trig_applied)
            x = torch.mul(x.squeeze(0), sentence_mask[i].unsqueeze(1))
            x[x == 0] = float('-inf')
            weights.append(x)
        normalized_weights = F.softmax(torch.stack(weights), 1)
        attn_applied1 = torch.mul(
            normalized_weights.repeat(1, 1, output.size(2)), output)

        output = torch.cat([output, attn_applied1], dim=2)

        lstm_scores = self.hidden2tag(output)
        bestScores, decodeIdx = self.inferencer.decode(lstm_scores,
                                                       word_seq_lens, None)

        return bestScores, decodeIdx
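
The nearest-trigger lookup in `decode_top` builds the pairwise Euclidean distance matrix by hand with `unsqueeze`/`expand`; the same matrix can be obtained with `torch.cdist` (a sketch of the equivalent computation, not what the class itself uses):

import torch

# soft_sent_rep: (n, d) sentence representations, trig_vec: (m, d) trigger representations
dist = torch.cdist(soft_sent_rep, trig_vec, p=2)   # (n, m) pairwise Euclidean distances
dvalue, dindices = torch.min(dist, dim=1)          # nearest trigger per sentence
trig_rep = trig_vec[dindices]                      # gather rows instead of a Python loop
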
Example #8
class SoftSequenceNaive(nn.Module):
    def __init__(self, config, encoder=None, print_info=True):
        super(SoftSequenceNaive, self).__init__()
        self.config = config
        self.device = config.device
        self.encoder = SoftEncoder(self.config)
        self.label_size = config.label_size
        self.inferencer = LinearCRF(config, print_info=print_info)
        self.hidden2tag = nn.Linear(config.hidden_dim,
                                    self.label_size).to(self.device)
        self.dsc_loss = DSCLoss(gamma=2)
        self.bert = AutoModel.from_pretrained(self.config.bert_path).to(
            self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.bert_path)

    def forward(self, word_seq_tensor: torch.Tensor,
                word_seq_lens: torch.Tensor, batch_context_emb: torch.Tensor,
                char_inputs: torch.Tensor, char_seq_lens: torch.Tensor, tags,
                one_batch_insts):
        word_seq_tensor, word_seq_lens = self.load_bert_embedding(
            one_batch_insts)
        batch_size = word_seq_tensor.size(0)
        max_sent_len = word_seq_tensor.size(1)

        output, sentence_mask = self.encoder(word_seq_tensor, word_seq_lens,
                                             batch_context_emb, char_inputs,
                                             char_seq_lens, one_batch_insts)

        lstm_scores = self.hidden2tag(output)
        maskTemp = torch.arange(1, max_sent_len + 1, dtype=torch.long).view(1, max_sent_len)\
            .expand(batch_size, max_sent_len).to(self.device)
        mask = torch.le(
            maskTemp,
            word_seq_lens.view(batch_size,
                               1).expand(batch_size,
                                         max_sent_len)).to(self.device)
        unlabeled_score, labeled_score = self.inferencer(
            lstm_scores, word_seq_lens, tags, mask)
        sequence_loss = unlabeled_score - labeled_score
        return sequence_loss

    def decode(self, word_seq_tensor: torch.Tensor,
               word_seq_lens: torch.Tensor, batch_context_emb: torch.Tensor,
               char_inputs: torch.Tensor, char_seq_lens: torch.Tensor,
               one_batch_insts):

        soft_output, soft_sentence_mask = \
                    self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, one_batch_insts)
        lstm_scores = self.hidden2tag(soft_output)
        bestScores, decodeIdx = self.inferencer.decode(lstm_scores,
                                                       word_seq_lens, None)
        return bestScores, decodeIdx

    def load_bert_embedding(self, insts):
        # sentence_list = []
        for sent in insts:
            # sentence = " ".join(str(w) for w in sent.input.words)
            # sentence_list.append(sentence)
            words = sent.input.words
            sent.word_ids = self.tokenizer.convert_tokens_to_ids(words)
        # sentence_list = tuple(sentence_list)
        # bert_embedding = self.get_bert_embedding(sentence_list)

        batch_size = len(insts)
        batch_data = insts

        # collect the sequence length of every instance in this batch
        word_seq_len = torch.LongTensor(
            list(map(lambda inst: len(inst.input.words), batch_data)))
        max_seq_len = word_seq_len.max()
        word_seq_tensor = torch.zeros((batch_size, max_seq_len),
                                      dtype=torch.long)
        for idx in range(batch_size):
            word_seq_tensor[idx, :word_seq_len[idx]] = torch.LongTensor(
                batch_data[idx].word_ids)

        word_seq_tensor = word_seq_tensor.to(self.device)
        word_seq_len = word_seq_len.to(self.device)
        return word_seq_tensor, word_seq_len

    def get_bert_embedding(self, batch):

        final_dataset = []
        for sentence in batch:
            tokenized_sentence = [
                "[CLS]"
            ] + self.tokenizer.tokenize(sentence) + ["[SEP]"]
            # pooling operation (BERT-first): record the index of the first subword of each word
            isSubword = False
            firstSubwordList = []
            for t_id, token in enumerate(tokenized_sentence):
                if not token.startswith("##"):   # WordPiece continuations are prefixed with "##"
                    isSubword = False
                    firstSubwordList.append(t_id)
                if isSubword:
                    continue
                if token.startswith("##"):
                    isSubword = True
            input_ids = torch.tensor(
                self.tokenizer.convert_tokens_to_ids(tokenized_sentence))
            final_dataset.append(input_ids)

        word_seq_lens = torch.LongTensor(
            list(map(lambda inst: inst.size(), final_dataset))).reshape(-1)
        # print(word_seq_lens)
        max_seq_len = word_seq_lens.max()
        word_seq_tensor = torch.zeros((len(final_dataset), max_seq_len),
                                      dtype=torch.long)  # size by the actual number of sentences in this batch
        for idx in range(len(final_dataset)):
            tmp = torch.LongTensor(final_dataset[idx])
            word_seq_tensor[idx, :word_seq_lens[idx]] = tmp
        # embeddings = embeddings[0][0]
        # size0 = len(final_dataset)
        # final_dataset = torch.cat(final_dataset, dim=0).view(size0, -1, 768)
        word_seq_tensor = word_seq_tensor.to(self.device)
        word_seq_lens = word_seq_lens.to(self.device)
        return word_seq_tensor, word_seq_lens
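
Note that `get_bert_embedding` collects `firstSubwordList` but never uses it; in the usual BERT-first pooling scheme those indices select one hidden vector per original word. A hedged sketch of that step (the `hidden` tensor and the helper name are hypothetical, not part of the class above):

import torch

def pool_first_subwords(hidden: torch.Tensor, first_subword_ids: list) -> torch.Tensor:
    # hidden: (seq_len, hidden_size) BERT output for one tokenized sentence (incl. [CLS]/[SEP])
    # first_subword_ids: positions of tokens that do not start with "##", i.e. word starts
    idx = torch.tensor(first_subword_ids, dtype=torch.long, device=hidden.device)
    return hidden.index_select(0, idx)             # (num_words, hidden_size), one vector per word
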