class BertCRF(BertPreTrainedModel):
    def __init__(self, cfig):
        super(BertCRF, self).__init__(cfig)
        # self.device = cfig.device
        self.num_labels = len(cfig.label2idx)
        self.bert = BertModel(cfig)
        self.dropout = nn.Dropout(cfig.hidden_dropout_prob)
        self.classifier = nn.Linear(cfig.hidden_size, len(cfig.label2idx))
        self.inferencer = LinearCRF(cfig)
        self.init_weights()

    def forward(self, input_ids, input_seq_lens=None, annotation_mask=None, labels=None,
                attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                add_crf=False):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
                            position_ids=position_ids, head_mask=head_mask)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)   # (batch_size, seq_length, hidden_size)
        logits = self.classifier(sequence_output)         # (batch_size, seq_length, num_labels)

        if labels is not None:
            batch_size = input_ids.size(0)
            sent_len = input_ids.size(1)   # max seq length of this batch
            maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(1, sent_len).expand(batch_size, sent_len).to(self.device)
            mask = torch.le(maskTemp, input_seq_lens.view(batch_size, 1).expand(batch_size, sent_len)).to(self.device)
            unlabed_score, labeled_score = self.inferencer(logits, input_seq_lens, labels, attention_mask)
            return unlabed_score - labeled_score
        else:
            bestScores, decodeIdx = self.inferencer.decode(logits, input_seq_lens, annotation_mask)
            return bestScores, decodeIdx

    # obsolete
    def decode(self, input_ids, input_seq_lens=None, annotation_mask=None,
               attention_mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input.
        :return: best path scores and decoded label indices
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=None,
                            position_ids=None, head_mask=None)
        features = self.dropout(outputs[0])    # (batch_size, seq_length, hidden_size)
        logits = self.classifier(features)     # (batch_size, seq_length, num_labels)
        bestScores, decodeIdx = self.inferencer.decode(logits, input_seq_lens, annotation_mask)
        return bestScores, decodeIdx
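# The torch.arange / torch.le pattern in BertCRF.forward above (and in the NNCRF variants
# below) turns per-sentence lengths into a boolean padding mask. A minimal, self-contained
# sketch of just that step, using made-up lengths rather than anything from the classes above:
import torch

input_seq_lens = torch.tensor([3, 5, 1])           # toy lengths of three padded sequences
batch_size, sent_len = input_seq_lens.size(0), 5   # sent_len = max length in the batch

positions = torch.arange(1, sent_len + 1, dtype=torch.long)           # (sent_len,)
positions = positions.view(1, sent_len).expand(batch_size, sent_len)  # (batch_size, sent_len)
mask = torch.le(positions, input_seq_lens.view(batch_size, 1))        # bool, True on real tokens

print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True, False, False, False, False]])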
class NNCRF(nn.Module):
    def __init__(self, config, print_info: bool = True):
        super(NNCRF, self).__init__()
        self.device = config.device
        self.encoder = BiLSTMEncoder(config, print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor,
                word_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor,
                chars: torch.Tensor,
                char_seq_lens: torch.Tensor,
                annotation_mask: torch.Tensor,
                marginals: torch.Tensor,
                tags: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative log-likelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param batch_context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param tags: (batch_size x max_seq_len)
        :return: the loss with shape (batch_size)
        """
        lstm_scores = self.encoder(words, word_seq_lens, batch_context_emb, chars, char_seq_lens)
        batch_size = words.size(0)
        sent_len = words.size(1)
        maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(1, sent_len).expand(batch_size, sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, sent_len)).to(self.device)
        unlabed_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, tags, mask)
        return unlabed_score - labeled_score

    def decode(self, batchInput: Tuple) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input.
        :param batchInput: batched tensors produced by the data loader
        :return: best path scores and decoded label indices
        """
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, annotation_mask, marginals, tagSeqTensor = batchInput
        features = self.encoder(wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths)
        bestScores, decodeIdx = self.inferencer.decode(features, wordSeqLengths, annotation_mask)
        return bestScores, decodeIdx

    def get_marginal(self, batchInput: Tuple) -> torch.Tensor:
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, annotation_mask, marginals, tagSeqTensor = batchInput
        features = self.encoder(wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths)
        marginals = self.inferencer.compute_constrained_marginal(features, wordSeqLengths, annotation_mask)
        return marginals
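# The loss returned by NNCRF.forward is `unlabed_score - labeled_score`, i.e. the standard
# linear-chain CRF negative log-likelihood: log-partition over all label paths minus the score
# of the gold path. A generic, self-contained sketch of that objective on toy tensors; this is
# an illustration of the math only, not the project's LinearCRF class or its signature.
import torch

def crf_nll(emissions, transitions, tags):
    # emissions: (seq_len, num_labels) unary scores for one sentence
    # transitions: (num_labels, num_labels), score of moving from label i to label j
    # tags: (seq_len,) gold label indices
    seq_len, num_labels = emissions.size()

    # score of the gold path
    gold = emissions[0, tags[0]]
    for t in range(1, seq_len):
        gold = gold + transitions[tags[t - 1], tags[t]] + emissions[t, tags[t]]

    # log-partition over all paths (forward algorithm in log space)
    alpha = emissions[0]                                                    # (num_labels,)
    for t in range(1, seq_len):
        alpha = torch.logsumexp(alpha.unsqueeze(1) + transitions, dim=0) + emissions[t]
    log_partition = torch.logsumexp(alpha, dim=0)

    return log_partition - gold    # always >= 0

emissions = torch.randn(4, 3)      # toy scores: 4 tokens, 3 labels
transitions = torch.randn(3, 3)
tags = torch.tensor([0, 2, 1, 1])
print(crf_nll(emissions, transitions, tags))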
class NNCRF(nn.Module):
    def __init__(self, config, print_info: bool = True):
        super(NNCRF, self).__init__()
        self.device = config.device
        self.encoder = BiLSTMEncoder(config, print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, sent_emb_tensor: torch.Tensor,
                type_id_tensor: torch.Tensor,
                sent_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor,
                chars: torch.Tensor,
                char_seq_lens: torch.Tensor,
                tags: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative log-likelihood.
        :param sent_emb_tensor: (batch_size x max_seq_len x ...)
        :param type_id_tensor: type ids aligned with the sentences
        :param sent_seq_lens: (batch_size)
        :param batch_context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param tags: (batch_size x max_seq_len)
        :return: the total negative log-likelihood loss
        """
        lstm_scores = self.encoder(sent_emb_tensor, type_id_tensor, sent_seq_lens, batch_context_emb, chars, char_seq_lens)
        # lstm_scores = self.encoder(sent_emb_tensor, sent_seq_lens, chars, char_seq_lens)
        batch_size = sent_emb_tensor.size(0)
        sent_len = sent_emb_tensor.size(1)
        maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(1, sent_len).expand(batch_size, sent_len).to(self.device)
        mask = torch.le(maskTemp, sent_seq_lens.view(batch_size, 1).expand(batch_size, sent_len)).to(self.device)
        unlabed_score, labeled_score = self.inferencer(lstm_scores, sent_seq_lens, tags, mask)
        return unlabed_score - labeled_score

    def decode(self, batchInput: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
                                       torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input.
        :param batchInput: batched tensors produced by the data loader
        :return: best path scores and decoded label indices
        """
        wordSeqTensor, typeTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, tagSeqTensor = batchInput
        features = self.encoder(wordSeqTensor, typeTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths)
        bestScores, decodeIdx = self.inferencer.decode(features, wordSeqLengths)
        return bestScores, decodeIdx
class SoftSequenceNaive(nn.Module):
    def __init__(self, config, encoder=None, print_info=True):
        super(SoftSequenceNaive, self).__init__()
        self.config = config
        self.device = config.device
        self.encoder = SoftEncoder(self.config)
        if encoder is not None:
            self.encoder = encoder
        self.label_size = config.label_size
        self.inferencer = LinearCRF(config, print_info=print_info)
        self.hidden2tag = nn.Linear(config.hidden_dim, self.label_size).to(self.device)

    def forward(self, word_seq_tensor: torch.Tensor,
                word_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor,
                char_inputs: torch.Tensor,
                char_seq_lens: torch.Tensor,
                tags):
        batch_size = word_seq_tensor.size(0)
        max_sent_len = word_seq_tensor.size(1)
        output, sentence_mask, _, _ = \
            self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, None)
        lstm_scores = self.hidden2tag(output)
        maskTemp = torch.arange(1, max_sent_len + 1, dtype=torch.long).view(1, max_sent_len).expand(batch_size, max_sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, max_sent_len)).to(self.device)
        if self.inferencer is not None:
            unlabeled_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, tags, mask)
            sequence_loss = unlabeled_score - labeled_score
        else:
            sequence_loss = self.compute_nll_loss(lstm_scores, tags, mask, word_seq_lens)
        return sequence_loss

    def decode(self, word_seq_tensor: torch.Tensor,
               word_seq_lens: torch.Tensor,
               batch_context_emb: torch.Tensor,
               char_inputs: torch.Tensor,
               char_seq_lens: torch.Tensor):
        soft_output, soft_sentence_mask, _, _ = \
            self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, None)
        lstm_scores = self.hidden2tag(soft_output)
        if self.inferencer is not None:
            bestScores, decodeIdx = self.inferencer.decode(lstm_scores, word_seq_lens, None)
            return bestScores, decodeIdx
class NNCRF(nn.Module):
    def __init__(self, config, print_info: bool = True):
        super(NNCRF, self).__init__()
        self.device = config.device
        self.encoder = BiLSTMEncoder(config, print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor,
                word_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor,
                chars: torch.Tensor,
                char_seq_lens: torch.Tensor,
                label_mask_tensor: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative log-likelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param batch_context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param label_mask_tensor: (batch_size x max_seq_len x num_labels)
        :return: the loss with shape (batch_size)
        """
        lstm_scores = self.encoder(words, word_seq_lens, batch_context_emb, chars, char_seq_lens)
        unlabed_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, label_mask_tensor)
        return unlabed_score - labeled_score

    def decode(self, batchInput: Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                       torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input.
        :param batchInput: batched tensors produced by the data loader
        :return: best path scores and decoded label indices
        """
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, tagSeqTensor = batchInput
        features = self.encoder(wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths)
        bestScores, decodeIdx = self.inferencer.decode(features, wordSeqLengths)
        return bestScores, decodeIdx
class NNCRF(nn.Module):
    def __init__(self, config: Config, print_info: bool = True):
        super(NNCRF, self).__init__()
        self.device = config.device
        self.encoder = BiLSTMEncoder(config, print_info=print_info)
        self.inferencer = None
        if config.use_crf_layer:
            self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor,
                word_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor,
                chars: torch.Tensor,
                char_seq_lens: torch.Tensor,
                tags: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative log-likelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param batch_context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param tags: (batch_size x max_seq_len)
        :return: the loss with shape (batch_size)
        """
        batch_size = words.size(0)
        max_sent_len = words.size(1)
        # Shape: (batch_size, max_seq_len, num_labels)
        lstm_scores = self.encoder(words, word_seq_lens, batch_context_emb, chars, char_seq_lens)
        maskTemp = torch.arange(1, max_sent_len + 1, dtype=torch.long).view(1, max_sent_len).expand(batch_size, max_sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, max_sent_len)).to(self.device)
        if self.inferencer is not None:
            unlabed_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, tags, mask)
            loss = unlabed_score - labeled_score
        else:
            loss = self.compute_nll_loss(lstm_scores, tags, mask, word_seq_lens)
        return loss

    def decode(self, batchInput: Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                       torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input.
        :param batchInput: batched tensors produced by the data loader
        :return: best path scores and decoded label indices
        """
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, tagSeqTensor = batchInput
        lstm_scores = self.encoder(wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths)
        if self.inferencer is not None:
            bestScores, decodeIdx = self.inferencer.decode(lstm_scores, wordSeqLengths)
        else:
            # greedy fallback: take the argmax label at each position
            bestScores, decodeIdx = torch.max(lstm_scores, dim=2)
        return bestScores, decodeIdx

    def compute_nll_loss(self, candidate_scores, target, mask, word_seq_lens):
        """
        Directly compute the loss right after the linear layer instead of the CRF layer.
        Partially taken from `masked_cross_entropy.py`
        (https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1)
        :param candidate_scores:
        :param target:
        :param mask:
        :param word_seq_lens:
        :return:
        """
        # logits_flat: (batch * max_len, num_classes)
        logits_flat = candidate_scores.view(-1, candidate_scores.size(-1))
        # log_probs_flat: (batch * max_len, num_classes)
        log_probs_flat = torch.log_softmax(logits_flat, dim=1)
        # target_flat: (batch * max_len, 1)
        target_flat = target.view(-1, 1)
        # losses_flat: (batch * max_len, 1)
        losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
        # losses: (batch, max_len)
        losses = losses_flat.view(*target.size())
        # zero out losses on padding positions
        # mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
        losses = losses * mask.float()
        # loss = losses.sum() / word_seq_lens.float().sum()
        loss = losses.sum()
        return loss
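# compute_nll_loss above is a masked token-level cross entropy. A small self-contained check
# on invented tensors, verifying that the gather-based formulation used there matches an
# explicit per-token loop over the non-padding positions:
import torch

batch, max_len, num_classes = 2, 4, 3
candidate_scores = torch.randn(batch, max_len, num_classes)
target = torch.randint(0, num_classes, (batch, max_len))
word_seq_lens = torch.tensor([4, 2])
mask = torch.arange(1, max_len + 1).unsqueeze(0) <= word_seq_lens.unsqueeze(1)

log_probs = torch.log_softmax(candidate_scores.view(-1, num_classes), dim=1)
losses = -torch.gather(log_probs, dim=1, index=target.view(-1, 1)).view(batch, max_len)
loss = (losses * mask.float()).sum()

# reference: loop over real (non-padded) positions only
ref = sum(-torch.log_softmax(candidate_scores[b, t], dim=0)[target[b, t]]
          for b in range(batch) for t in range(int(word_seq_lens[b])))
assert torch.allclose(loss, ref)
print(loss)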
class SoftSequence(nn.Module):
    def __init__(self, config, softmatcher, encoder=None, print_info=True):
        super(SoftSequence, self).__init__()
        self.config = config
        self.device = config.device
        self.encoder = SoftEncoder(self.config)
        if encoder is not None:
            self.encoder = encoder
        self.softmatch_encoder = softmatcher.encoder
        self.softmatch_attention = softmatcher.attention
        self.label_size = config.label_size
        self.inferencer = LinearCRF(config, print_info=print_info)
        self.hidden2tag = nn.Linear(config.hidden_dim * 2, self.label_size).to(self.device)

        self.w1 = nn.Linear(config.hidden_dim, config.hidden_dim // 2).to(self.device)
        self.w2 = nn.Linear(config.hidden_dim // 2, config.hidden_dim // 2).to(self.device)
        self.attn1 = nn.Linear(config.hidden_dim // 2, 1).to(self.device)
        self.attn2 = nn.Linear(config.hidden_dim + config.hidden_dim // 2, 1).to(self.device)
        self.attn3 = nn.Linear(config.hidden_dim // 2, 1).to(self.device)
        # torch.autograd.Variable is deprecated; torch.randn(..., requires_grad=True) behaves identically
        self.applying = Variable(torch.randn(config.hidden_dim, config.hidden_dim // 2), requires_grad=True).to(self.device)
        self.tanh = nn.Tanh().to(self.device)
        self.perturb = nn.Dropout(config.dropout).to(self.device)

    def forward(self, word_seq_tensor: torch.Tensor,
                word_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor,
                char_inputs: torch.Tensor,
                char_seq_lens: torch.Tensor,
                trigger_position, tags):
        batch_size = word_seq_tensor.size(0)
        max_sent_len = word_seq_tensor.size(1)
        output, sentence_mask, trigger_vec, trigger_mask = \
            self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, trigger_position)

        if trigger_vec is not None:
            trig_rep, sentence_vec_cat, trigger_vec_cat = self.softmatch_attention(
                output, sentence_mask, trigger_vec, trigger_mask)
            # attention over the sentence, conditioned on the trigger representation
            weights = []
            for i in range(len(output)):
                trig_applied = self.tanh(self.w1(output[i].unsqueeze(0)) + self.w2(trig_rep[i].unsqueeze(0).unsqueeze(0)))
                x = self.attn1(trig_applied)                             # (1, max_sent_len, 1)
                x = torch.mul(x.squeeze(0), sentence_mask[i].unsqueeze(1))
                x[x == 0] = float('-inf')                                # padding gets zero weight after softmax
                weights.append(x)
            normalized_weights = F.softmax(torch.stack(weights), 1)
            attn_applied1 = torch.mul(normalized_weights.repeat(1, 1, output.size(2)), output)
        else:
            # no trigger available: attend using the sentence representation itself
            weights = []
            for i in range(len(output)):
                trig_applied = self.tanh(self.w1(output[i].unsqueeze(0)) + self.w1(output[i].unsqueeze(0)))
                x = self.attn1(trig_applied)
                x = torch.mul(x.squeeze(0), sentence_mask[i].unsqueeze(1))
                x[x == 0] = float('-inf')
                weights.append(x)
            normalized_weights = F.softmax(torch.stack(weights), 1)
            attn_applied1 = torch.mul(normalized_weights.repeat(1, 1, output.size(2)), output)

        output = torch.cat([output, attn_applied1], dim=2)
        lstm_scores = self.hidden2tag(output)

        maskTemp = torch.arange(1, max_sent_len + 1, dtype=torch.long).view(1, max_sent_len).expand(batch_size, max_sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, max_sent_len)).to(self.device)
        if self.inferencer is not None:
            unlabeled_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, tags, mask)
            sequence_loss = unlabeled_score - labeled_score
        else:
            sequence_loss = self.compute_nll_loss(lstm_scores, tags, mask, word_seq_lens)
        return sequence_loss

    def decode_top(self, word_seq_tensor: torch.Tensor,
                   word_seq_lens: torch.Tensor,
                   batch_context_emb: torch.Tensor,
                   char_inputs: torch.Tensor,
                   char_seq_lens: torch.Tensor,
                   trig_rep):
        output, sentence_mask, _, _ = \
            self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, None)
        soft_output, soft_sentence_mask, _, _ = \
            self.softmatch_encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, None)
        soft_sent_rep = self.softmatch_attention.attention(soft_output, soft_sentence_mask)

        trig_vec = trig_rep[0]
        trig_key = trig_rep[1]

        # nearest trigger per sentence by Euclidean distance
        n = soft_sent_rep.size(0)
        m = trig_vec.size(0)
        d = soft_sent_rep.size(1)
        soft_sent_rep_dist = soft_sent_rep.unsqueeze(1).expand(n, m, d)
        trig_vec_dist = trig_vec.unsqueeze(0).expand(n, m, d)
        dist = torch.pow(soft_sent_rep_dist - trig_vec_dist, 2).sum(2).sqrt()
        dvalue, dindices = torch.min(dist, dim=1)
        trigger_list = []
        for i in dindices.tolist():
            trigger_list.append(trig_vec[i])
        trig_rep = torch.stack(trigger_list)

        # attention over the sentence, conditioned on the matched trigger
        weights = []
        for i in range(len(output)):
            trig_applied = self.tanh(self.w1(output[i].unsqueeze(0)) + self.w2(trig_rep[i].unsqueeze(0).unsqueeze(0)))
            x = self.attn1(trig_applied)
            x = torch.mul(x.squeeze(0), sentence_mask[i].unsqueeze(1))
            x[x == 0] = float('-inf')
            weights.append(x)
        normalized_weights = F.softmax(torch.stack(weights), 1)
        attn_applied1 = torch.mul(normalized_weights.repeat(1, 1, output.size(2)), output)

        output = torch.cat([output, attn_applied1], dim=2)
        lstm_scores = self.hidden2tag(output)
        bestScores, decodeIdx = self.inferencer.decode(lstm_scores, word_seq_lens, None)
        return bestScores, decodeIdx
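# decode_top above retrieves, for every sentence, the closest trigger representation by
# Euclidean distance before re-running the attention. A toy illustration of that lookup with
# invented dimensions; torch.cdist computes the same distance matrix more directly.
import torch

n, m, d = 4, 6, 8                    # 4 sentences, 6 triggers, 8-dim representations (made up)
sent_rep = torch.randn(n, d)
trig_vec = torch.randn(m, d)

dist = torch.pow(sent_rep.unsqueeze(1).expand(n, m, d)
                 - trig_vec.unsqueeze(0).expand(n, m, d), 2).sum(2).sqrt()   # (n, m)
assert torch.allclose(dist, torch.cdist(sent_rep, trig_vec), atol=1e-5)

values, indices = torch.min(dist, dim=1)   # index of the nearest trigger per sentence
matched = trig_vec[indices]                # (n, d), plays the role of trig_rep downstream
print(indices)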
class SoftSequenceNaive(nn.Module):
    def __init__(self, config, encoder=None, print_info=True):
        super(SoftSequenceNaive, self).__init__()
        self.config = config
        self.device = config.device
        self.encoder = SoftEncoder(self.config)
        self.label_size = config.label_size
        self.inferencer = LinearCRF(config, print_info=print_info)
        self.hidden2tag = nn.Linear(config.hidden_dim, self.label_size).to(self.device)
        self.dsc_loss = DSCLoss(gamma=2)
        self.bert = AutoModel.from_pretrained(self.config.bert_path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.bert_path)

    def forward(self, word_seq_tensor: torch.Tensor,
                word_seq_lens: torch.Tensor,
                batch_context_emb: torch.Tensor,
                char_inputs: torch.Tensor,
                char_seq_lens: torch.Tensor,
                tags, one_batch_insts):
        word_seq_tensor, word_seq_lens = self.load_bert_embedding(one_batch_insts)
        batch_size = word_seq_tensor.size(0)
        max_sent_len = word_seq_tensor.size(1)
        output, sentence_mask = self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb,
                                             char_inputs, char_seq_lens, one_batch_insts)
        lstm_scores = self.hidden2tag(output)
        maskTemp = torch.arange(1, max_sent_len + 1, dtype=torch.long).view(1, max_sent_len)\
            .expand(batch_size, max_sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, max_sent_len)).to(self.device)
        unlabeled_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, tags, mask)
        sequence_loss = unlabeled_score - labeled_score
        return sequence_loss

    def decode(self, word_seq_tensor: torch.Tensor,
               word_seq_lens: torch.Tensor,
               batch_context_emb: torch.Tensor,
               char_inputs: torch.Tensor,
               char_seq_lens: torch.Tensor,
               one_batch_insts):
        soft_output, soft_sentence_mask = \
            self.encoder(word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens, one_batch_insts)
        lstm_scores = self.hidden2tag(soft_output)
        bestScores, decodeIdx = self.inferencer.decode(lstm_scores, word_seq_lens, None)
        return bestScores, decodeIdx

    def load_bert_embedding(self, insts):
        # convert each instance's words into wordpiece ids
        # sentence_list = []
        for sent in insts:
            # sentence = " ".join(str(w) for w in sent.input.words)
            # sentence_list.append(sentence)
            words = sent.input.words
            sent.word_ids = self.tokenizer.convert_tokens_to_ids(words)
        # sentence_list = tuple(sentence_list)
        # bert_embedding = self.get_bert_embedding(sentence_list)

        batch_size = len(insts)
        batch_data = insts
        # sequence length of each instance in this batch
        word_seq_len = torch.LongTensor(list(map(lambda inst: len(inst.input.words), batch_data)))
        max_seq_len = word_seq_len.max()
        word_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
        for idx in range(batch_size):
            word_seq_tensor[idx, :word_seq_len[idx]] = torch.LongTensor(batch_data[idx].word_ids)
        word_seq_tensor = word_seq_tensor.to(self.device)
        word_seq_len = word_seq_len.to(self.device)
        return word_seq_tensor, word_seq_len

    def get_bert_embedding(self, batch):
        final_dataset = []
        for sentence in batch:
            tokenized_sentence = ["[CLS]"] + self.tokenizer.tokenize(sentence) + ["[SEP]"]
            # pooling operation (BERT - first): keep the index of each token's first sub-word
            isSubword = False
            firstSubwordList = []
            for t_id, token in enumerate(tokenized_sentence):
                if not token.startswith("#"):
                    isSubword = False
                    firstSubwordList.append(t_id)
                if isSubword:
                    continue
                if token.startswith("#"):
                    isSubword = True
            input_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokenized_sentence))
            final_dataset.append(input_ids)

        word_seq_lens = torch.LongTensor(list(map(lambda inst: inst.size(), final_dataset))).reshape(-1)
        max_seq_len = word_seq_lens.max()
        word_seq_tensor = torch.zeros((len(final_dataset), max_seq_len), dtype=torch.long)
        for idx in range(len(final_dataset)):
            tmp = torch.LongTensor(final_dataset[idx])
            word_seq_tensor[idx, :word_seq_lens[idx]] = tmp
        # embeddings = embeddings[0][0]
        # size0 = len(final_dataset)
        # final_dataset = torch.cat(final_dataset, dim=0).view(size0, -1, 768)
        word_seq_tensor = word_seq_tensor.to(self.device)
        word_seq_lens = word_seq_lens.to(self.device)
        return word_seq_tensor, word_seq_lens
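# load_bert_embedding above right-pads variable-length token-id sequences into one LongTensor.
# A standalone version of just that padding step; the id lists below are invented, whereas in
# the class they come from tokenizer.convert_tokens_to_ids.
import torch

batch_word_ids = [[101, 7592, 2088, 102], [101, 2023, 102]]   # hypothetical wordpiece ids
word_seq_len = torch.LongTensor([len(ids) for ids in batch_word_ids])
max_seq_len = int(word_seq_len.max())

word_seq_tensor = torch.zeros((len(batch_word_ids), max_seq_len), dtype=torch.long)
for idx, ids in enumerate(batch_word_ids):
    word_seq_tensor[idx, :word_seq_len[idx]] = torch.LongTensor(ids)

print(word_seq_tensor)
# tensor([[ 101, 7592, 2088,  102],
#         [ 101, 2023,  102,    0]])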