def test_case2(self):
    # Test that the CRF decodes correctly against precomputed references.
    import json
    import torch
    from fastNLP import seq_len_to_mask

    with open('tests/data_for_tests/modules/decoder/crf.json', 'r') as f:
        data = json.load(f)

    bio_logits = torch.FloatTensor(data['bio_logits'])
    bio_scores = data['bio_scores']
    bio_path = data['bio_path']
    bio_trans_m = torch.FloatTensor(data['bio_trans_m'])
    bio_seq_lens = torch.LongTensor(data['bio_seq_lens'])

    bmes_logits = torch.FloatTensor(data['bmes_logits'])
    bmes_scores = data['bmes_scores']
    bmes_path = data['bmes_path']
    bmes_trans_m = torch.FloatTensor(data['bmes_trans_m'])
    bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens'])

    labels = ['O']
    for label in ['X', 'Y']:
        for tag in 'BI':
            labels.append('{}-{}'.format(tag, label))
    id2label = {idx: label for idx, label in enumerate(labels)}
    num_tags = len(id2label)
    mask = seq_len_to_mask(bio_seq_lens)

    from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
    fast_CRF = ConditionalRandomField(
        num_tags=num_tags,
        allowed_transitions=allowed_transitions(id2label, include_start_end=True))
    fast_CRF.trans_m.data = bio_trans_m
    fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True)
    # scores equal
    self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()])
    # paths equal
    self.assertListEqual(bio_path, fast_res[0])

    labels = []
    for label in ['X', 'Y']:
        for tag in 'BMES':
            labels.append('{}-{}'.format(tag, label))
    id2label = {idx: label for idx, label in enumerate(labels)}
    num_tags = len(id2label)
    mask = seq_len_to_mask(bmes_seq_lens)

    fast_CRF = ConditionalRandomField(
        num_tags=num_tags,
        allowed_transitions=allowed_transitions(id2label, encoding_type='BMES',
                                                include_start_end=True))
    fast_CRF.trans_m.data = bmes_trans_m
    fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True)
    # scores equal
    self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()])
    # paths equal
    self.assertListEqual(bmes_path, fast_res[0])
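# For reference throughout this file (a minimal sketch, assuming fastNLP's
# seq_len_to_mask): it turns a length vector into a (batch, max_len) mask with
# True/1 at valid positions and False/0 at padding, e.g.
#
#     import torch
#     from fastNLP import seq_len_to_mask
#     seq_len_to_mask(torch.LongTensor([3, 2]))
#     # tensor([[True, True,  True],
#     #         [True, True, False]])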
def evaluate(self, pred, target, seq_len=None):
    """
    Accumulate the evaluation statistics over one batch of predictions.

    :param torch.Tensor pred: prediction tensor of shape torch.Size([B,]),
        torch.Size([B, n_classes]), torch.Size([B, max_len]), or
        torch.Size([B, max_len, n_classes])
    :param torch.Tensor target: ground-truth tensor of shape torch.Size([B,]),
        torch.Size([B,]), torch.Size([B, max_len]), or torch.Size([B, max_len]),
        matching the pred shapes above
    :param torch.Tensor seq_len: sequence lengths, of shape None, None,
        torch.Size([B]), or torch.Size([B]) respectively. If a mask is passed
        in as well, seq_len is ignored.
    """
    # TODO: this error message needs improving -- the user has no idea what
    #  `pred` refers to, so the actual offending value should be reported.
    if not isinstance(pred, torch.Tensor):
        raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                        f"got {type(pred)}.")
    if not isinstance(target, torch.Tensor):
        raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                        f"got {type(target)}.")
    if seq_len is not None and not isinstance(seq_len, torch.Tensor):
        raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                        f"got {type(seq_len)}.")

    if seq_len is not None and target.dim() > 1:
        max_len = target.size(1)
        masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
    else:
        masks = torch.ones_like(target).long().to(target.device)
    masks = masks.eq(False)  # True at padding positions

    if pred.dim() == target.dim():
        pass
    elif pred.dim() == target.dim() + 1:
        pred = pred.argmax(dim=-1)
        if seq_len is None and target.dim() > 1:
            logger.warning("You are not passing `seq_len` to exclude pad when calculate accuracy.")
    else:
        raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                           f"size:{pred.size()}, target should have size: {pred.size()} or "
                           f"{pred.size()[:-1]}, got {target.size()}.")

    target_idxes = set(target.reshape(-1).tolist())
    target = target.to(pred)
    for target_idx in target_idxes:
        self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(
            target != target_idx, 0).masked_fill(masks, 0)).item()
        # false positive: this class was predicted where the target disagrees
        self._fp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(
            target == target_idx, 0).masked_fill(masks, 0)).item()
        # false negative: this class was missed where the target has it
        self._fn[target_idx] += torch.sum((pred != target_idx).long().masked_fill(
            target != target_idx, 0).masked_fill(masks, 0)).item()
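# A minimal sketch (the helper name and eps are assumptions, not part of the
# class above) of how the accumulated _tp/_fp/_fn counters are typically
# turned into per-class precision, recall and F1:
def _f_pre_rec(tp, fp, fn, eps=1e-13):
    pre = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f = 2 * pre * rec / (pre + rec + eps)
    return f, pre, rec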
def forward(self, chars, bigrams, seq_len, target):
    embed_char = self.char_embed(chars)
    if self.use_bigram:
        embed_bigram = self.bigram_embed(bigrams)
        embedding = torch.cat([embed_char, embed_bigram], dim=-1)
    else:
        embedding = embed_char
    embedding = self.embed_dropout(embedding)
    encoded_h, encoded_c = self.encoder(embedding, seq_len)
    encoded_h = self.output_dropout(encoded_h)
    pred = self.output(encoded_h)
    mask = seq_len_to_mask(seq_len)
    # pred = self.crf(pred)
    # batch_size, sent_len = pred.shape[0], pred.shape[1]
    # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1),
    #                       target.reshape(batch_size * sent_len))
    if self.training:
        loss = self.crf(pred, target, mask)
        return {'loss': loss}
    else:
        pred, path = self.crf.viterbi_decode(pred, mask)
        return {'pred': pred}
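# The train/eval split above is a pattern that recurs throughout this file:
# fastNLP's ConditionalRandomField returns the per-sample negative
# log-likelihood when called as crf(pred, target, mask), while
# crf.viterbi_decode(pred, mask) returns the best-scoring tag sequence
# (and its score) for inference.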
def test_case(self):
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10)

    encoder_output = torch.randn(2, 3, 10)
    tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
    src_seq_len = torch.LongTensor([3, 2])
    encoder_mask = seq_len_to_mask(src_seq_len)

    for flag in [True, False]:
        for attention in [True, False]:
            with self.subTest(bind_decoder_input_output_embed=flag,
                              attention=attention):
                decoder = LSTMSeq2SeqDecoder(
                    embed=embed, num_layers=2, hidden_size=10, dropout=0.3,
                    bind_decoder_input_output_embed=flag, attention=attention)
                state = decoder.init_state(encoder_output, encoder_mask)
                output = decoder(tgt_words_idx, state)
                self.assertEqual(tuple(output.size()), (2, 4, len(vocab)))
def test_case(self):
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, embedding_dim=10)

    encoder_output = torch.randn(2, 3, 10)
    src_seq_len = torch.LongTensor([3, 2])
    encoder_mask = seq_len_to_mask(src_seq_len)

    for flag in [True, False]:
        with self.subTest(bind_decoder_input_output_embed=flag):
            decoder = TransformerSeq2SeqDecoder(
                embed=embed, pos_embed=None, d_model=10, num_layers=2,
                n_head=5, dim_ff=20, dropout=0.1,
                bind_decoder_input_output_embed=flag)
            state = decoder.init_state(encoder_output, encoder_mask)
            output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)),
                             state=state)
            self.assertEqual(output.size(), (2, 4, len(vocab)))
def _forward(self, words, seq_len, target=None):
    words = self.embedding(words)
    outputs, _ = self.lstm(words, seq_len)
    outputs = self.dropout(outputs)
    logits = F.log_softmax(self.fc(outputs), dim=-1)
    if target is not None:
        loss = self.crf(logits, target,
                        seq_len_to_mask(seq_len, max_len=logits.size(1))).mean()
        return {Const.LOSS: loss}
    else:
        pred, _ = self.crf.viterbi_decode(
            logits, seq_len_to_mask(seq_len, max_len=logits.size(1)))
        return {Const.OUTPUT: pred}
def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e, target,
            chars_target=None):
    batch_size = lattice.size(0)
    max_seq_len_and_lex_num = lattice.size(1)
    max_seq_len = bigrams.size(1)

    words = lattice[:, :max_seq_len]
    mask = seq_len_to_mask(seq_len).bool()
    words.masked_fill_((~mask), self.vocabs['lattice'].padding_idx)
    encoded = self.bert_embedding(words)
    if self.after_bert == 'lstm':
        encoded, _ = self.lstm(encoded, seq_len)
    encoded = self.dropout(encoded)
    pred = self.output(encoded)

    if self.training:
        loss = self.crf(pred, target, mask).mean(dim=0)
        return {'loss': loss}
    else:
        pred, path = self.crf.viterbi_decode(pred, mask)
        result = {'pred': pred}
        return result
def test_case3(self):
    # Test that the CRF loss never goes negative.
    import torch
    from fastNLP.modules.decoder.crf import ConditionalRandomField
    from fastNLP.core.utils import seq_len_to_mask
    from torch import optim
    from torch import nn

    num_tags, include_start_end_trans = 4, True
    num_samples = 4
    lengths = torch.randint(3, 50, size=(num_samples,)).long()
    max_len = lengths.max()
    tags = torch.randint(num_tags, size=(num_samples, max_len))
    masks = seq_len_to_mask(lengths)
    feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
    crf = ConditionalRandomField(num_tags, include_start_end_trans)
    optimizer = optim.SGD(
        [param for param in crf.parameters() if param.requires_grad] + [feats],
        lr=0.1)
    for i in range(10):
        loss = crf(feats, tags, masks).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 1000 == 0:
            print(loss)
    self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.")
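# Why the assertion above should hold: the CRF negative log-likelihood is
# logZ - score(gold_path), and the partition Z sums exp-scores over *all*
# tag sequences, including the gold one, so logZ >= score(gold_path) and
# the loss can never drop below zero.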
def forward(self, chars, bigrams, seq_len, target, chars_target=None):
    # print('**self.training: {} **'.format(self.training))
    batch_size = chars.size(0)
    max_seq_len = chars.size(1)

    chars_embed = self.char_embed(chars)
    if self.use_bigram:
        bigrams_embed = self.bigram_embed(bigrams)
        embedding = torch.cat([chars_embed, bigrams_embed], dim=-1)
    else:
        embedding = chars_embed

    if self.embed_dropout_pos == '0':
        embedding = self.embed_dropout(embedding)
    embedding = self.w_proj(embedding)
    if self.embed_dropout_pos == '1':
        embedding = self.embed_dropout(embedding)
    if self.use_abs_pos:
        embedding = self.pos_encode(embedding)
    if self.embed_dropout_pos == '2':
        embedding = self.embed_dropout(embedding)

    encoded = self.encoder(embedding, seq_len)
    if hasattr(self, 'output_dropout'):
        encoded = self.output_dropout(encoded)
    pred = self.output(encoded)

    mask = seq_len_to_mask(seq_len).bool()

    if self.mode['debug']:
        print('debug mode:finish!')
        exit(1208)

    if self.training:
        loss = self.crf(pred, target, mask).mean(dim=0)
        if self.self_supervised:
            # print('self supervised loss added!')
            chars_pred = self.output_self_supervised(encoded)
            chars_pred = chars_pred.view(size=[batch_size * max_seq_len, -1])
            chars_target = chars_target.view(size=[batch_size * max_seq_len])
            self_supervised_loss = self.loss_func(chars_pred, chars_target)
            # print('self_supervised_loss:{}'.format(self_supervised_loss))
            # print('supervised_loss:{}'.format(loss))
            loss += self_supervised_loss
        return {'loss': loss}
    else:
        pred, path = self.crf.viterbi_decode(pred, mask)
        result = {'pred': pred}
        if self.self_supervised:
            chars_pred = self.output_self_supervised(encoded)
            result['chars_pred'] = chars_pred
        return result
def forward(self, chars, bigrams, seq_len, target, skips_l2r_source,
            skips_l2r_word, lexicon_count, skips_r2l_source=None,
            skips_r2l_word=None, lexicon_count_back=None):
    # print('skips_l2r_word_id:{}'.format(skips_l2r_word.size()))
    batch = chars.size(0)
    max_seq_len = chars.size(1)
    # max_lexicon_count = skips_l2r_word.size(2)

    embed_char = self.char_embed(chars)
    if self.use_bigram:
        embed_bigram = self.bigram_embed(bigrams)
        embedding = torch.cat([embed_char, embed_bigram], dim=-1)
    else:
        embedding = embed_char
    embed_nonword = self.embed_dropout(embedding)

    # skips_l2r_word = torch.reshape(skips_l2r_word, shape=[batch, -1])
    embed_word = self.word_embed(skips_l2r_word)
    embed_word = self.embed_dropout(embed_word)
    # embed_word = torch.reshape(embed_word,
    #                            shape=[batch, max_seq_len, max_lexicon_count, -1])

    encoded_h, encoded_c = self.encoder(embed_nonword, seq_len,
                                        skips_l2r_source, embed_word,
                                        lexicon_count)
    if self.bidirectional:
        embed_word_back = self.word_embed(skips_r2l_word)
        embed_word_back = self.embed_dropout(embed_word_back)
        encoded_h_back, encoded_c_back = self.encoder_back(
            embed_nonword, seq_len, skips_r2l_source, embed_word_back,
            lexicon_count_back)
        encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1)

    encoded_h = self.output_dropout(encoded_h)
    pred = self.output(encoded_h)
    mask = seq_len_to_mask(seq_len)

    if self.training:
        loss = self.crf(pred, target, mask)
        return {'loss': loss}
    else:
        pred, path = self.crf.viterbi_decode(pred, mask)
        return {'pred': pred}
def forward(self, feats, seq_lens, gold_heads=None):
    """
    max_len includes the root token.

    :param feats: batch_size x max_len x feat_dim, encoded input features
    :param seq_lens: batch_size
    :param gold_heads: batch_size x max_len
    :return dict: parsing results
        arc_pred: [batch_size, seq_len, seq_len]
        label_pred: [batch_size, seq_len, num_label]
        mask: [batch_size, seq_len]
        head_pred: [batch_size, seq_len], the predicted heads,
            only returned when gold_heads is not provided
    """
    # prepare embeddings
    batch_size, seq_len, _ = feats.shape
    # print('forward {} {}'.format(batch_size, seq_len))

    # get sequence mask
    mask = seq_len_to_mask(seq_lens).long()

    # for arc biaffine
    # mlp, reduce dim
    feat = self.mlp(feats)
    arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
    arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
    label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], \
                            feat[:, :, 2 * arc_sz + label_sz:]

    # biaffine arc classifier
    arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]

    # use gold or predicted arcs to predict labels
    if gold_heads is None or not self.training:
        # use greedy decoding in training
        if self.training or self.use_greedy_infer:
            heads = self.greedy_decoder(arc_pred, mask)
        else:
            heads = self.mst_decoder(arc_pred, mask)
        head_pred = heads
    else:
        assert self.training  # must be in training mode
        if gold_heads is None:
            heads = self.greedy_decoder(arc_pred, mask)
            head_pred = heads
        else:
            head_pred = None
            heads = gold_heads
    # heads: batch_size x max_len

    batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long,
                               device=feats.device).unsqueeze(1)
    label_head = label_head[batch_range, heads].contiguous()
    label_pred = self.label_predictor(label_head, label_dep)  # [N, max_len, num_label]
    res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
    if head_pred is not None:
        res_dict['head_pred'] = head_pred
    return res_dict
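# For context (a sketch of the standard biaffine-parser formulation; an
# assumption about what arc_predictor computes, not code from this repo):
# each arc score is a bilinear form
#     s(dep_i, head_j) = h_j^T U d_i + b^T h_j
# evaluated for every (dependent, head) pair at once, which is what yields
# the [N, L, L] arc_pred above.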
def get_batch_generation(self, samples_list, try_cuda=True):
    if not samples_list:
        return None
    if try_cuda:
        self.try_cuda()

    tensor_list = []
    masked_indices_list = []
    max_len = 0
    output_tokens_list = []
    seq_len = []
    for sample in samples_list:
        masked_inputs_list = sample["masked_sentences"]

        tokens_list = [self.tokenizer.bos_token_id]
        for idx, masked_input in enumerate(masked_inputs_list):
            tokens_list.extend(
                self.tokenizer.encode(" " + masked_input.strip(),
                                      add_special_tokens=False))
            tokens_list.append(self.tokenizer.eos_token_id)

        # tokens = torch.cat(tokens_list)[: self.max_sentence_length]
        tokens = torch.tensor(tokens_list)[:self.max_sentence_length]
        output_tokens_list.append(tokens.long().cpu().numpy())
        seq_len.append(len(tokens))

        if len(tokens) > max_len:
            max_len = len(tokens)
        tensor_list.append(tokens)
        masked_index = (tokens == self.tokenizer.mask_token_id).nonzero().numpy()
        for x in masked_index:
            masked_indices_list.append([x[0]])

    tokens_list = []
    for tokens in tensor_list:
        pad_length = max_len - len(tokens)
        if pad_length > 0:
            pad_tensor = torch.full([pad_length], self.tokenizer.pad_token_id,
                                    dtype=torch.int)
            tokens = torch.cat((tokens, pad_tensor.long()))
        tokens_list.append(tokens)

    batch_tokens = torch.stack(tokens_list)
    seq_len = torch.LongTensor(seq_len)
    attn_mask = seq_len_to_mask(seq_len)

    with torch.no_grad():
        # with utils.eval(self.model.model):
        self.model.eval()
        outputs = self.model(
            batch_tokens.long().to(device=self._model_device),
            attention_mask=attn_mask.to(device=self._model_device))
        log_probs = outputs[0]
    return log_probs.cpu(), output_tokens_list, masked_indices_list
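# Note: seq_len_to_mask produces exactly the 1-for-token / 0-for-pad pattern
# that HuggingFace-style models expect as attention_mask, which is why the
# mask can be passed through unchanged above.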
def prepare_env():
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    encoder_output = torch.randn(2, 3, 10)
    src_seq_len = torch.LongTensor([3, 2])
    encoder_mask = seq_len_to_mask(src_seq_len)

    return embed, encoder_output, encoder_mask
def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len):
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    mask = seq_len_to_mask(tgt_seq_len).eq(0)
    target = tgt_words_idx.masked_fill(mask, -100)

    for i in range(100):
        optimizer.zero_grad()
        pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred']  # bsz x max_len x vocab_size
        loss = F.cross_entropy(pred.transpose(1, 2), target)
        loss.backward()
        optimizer.step()

    right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum()
    return right_count
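# The -100 fill above is not arbitrary: F.cross_entropy ignores targets equal
# to -100 by default (its ignore_index), so padded positions contribute
# nothing to the loss without an explicit weight mask. Equivalently:
#
#     loss = F.cross_entropy(pred.transpose(1, 2), target, ignore_index=-100)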
def forward(self, words, seq_len=None):
    """
    :param torch.LongTensor words: [batch_size, seq_len], word indices of the sentence
    :param torch.LongTensor seq_len: [batch,], length of each sentence
    :return output: dict of torch.LongTensor, [batch_size, num_classes]
    """
    x = self.embed(words)  # [N,L] -> [N,L,C]
    if seq_len is not None:
        mask = seq_len_to_mask(seq_len)
        x = self.conv_pool(x, mask)
    else:
        x = self.conv_pool(x)  # [N,L,C] -> [N,C]
    x = self.dropout(x)
    x = self.fc(x)  # [N,C] -> [N, N_class]
    return {C.OUTPUT: x}
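# A minimal sketch (an assumption about how a masked ConvMaxpool behaves, not
# fastNLP's actual code) of why the mask matters in conv_pool: padded
# positions are pushed to -inf before max-pooling over the length dimension,
# so padding can never win the max:
#
#     x = x.masked_fill(mask.eq(False).unsqueeze(-1), float('-inf'))
#     pooled, _ = x.max(dim=1)  # [N,L,C] -> [N,C]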
def _check_potentials(self, scores, lengths=None):
    semiring = self.semiring
    batch, N, N2, N3 = self._get_dimension_and_requires_grad(scores)
    assert N == N2 == N3, "Non-square potentials"
    if lengths is None:
        lengths = torch.LongTensor([N - 1] * batch).to(scores.device)
    else:
        assert max(lengths) <= N, "Length longer than N"
    scores = semiring.convert(scores)
    scores = scores.clone()  # avoid leaf error when backward
    mask = seq_len_to_mask(lengths + 1, N)
    mask3d = (mask.unsqueeze(-1) * mask.unsqueeze(-2)).unsqueeze(-1) \
             * mask.view(batch, 1, 1, N)
    semiring.zero_mask_(scores, ~mask3d)
    return scores, batch, N, lengths
def _forward(self, chars, bigrams, trigrams, seq_len, target=None):
    chars = self.char_embed(chars)
    if bigrams is not None:
        bigrams = self.bigram_embed(bigrams)
        chars = torch.cat([chars, bigrams], dim=-1)
    if trigrams is not None:
        trigrams = self.trigram_embed(trigrams)
        chars = torch.cat([chars, trigrams], dim=-1)

    output, _ = self.lstm(chars, seq_len)
    output = self.dropout(output)
    output = self.fc(output)
    output = F.log_softmax(output, dim=-1)
    mask = seq_len_to_mask(seq_len)
    if target is None:
        pred, _ = self.crf.viterbi_decode(output, mask)
        return {Const.OUTPUT: pred}
    else:
        loss = self.crf.forward(output, tags=target, mask=mask)
        return {Const.LOSS: loss}
def forward(self, chars, bigrams, seq_len, target):
    if self.debug:
        print_info('chars:{}'.format(chars.size()))
        print_info('bigrams:{}'.format(bigrams.size()))
        print_info('seq_len:{}'.format(seq_len.size()))
        print_info('target:{}'.format(target.size()))

    embed_char = self.char_embed(chars)
    if self.use_bigram:
        embed_bigram = self.bigram_embed(bigrams)
        embedding = torch.cat([embed_char, embed_bigram], dim=-1)
    else:
        embedding = embed_char
    embedding = self.embed_dropout(embedding)
    encoded_h, encoded_c = self.encoder(embedding, seq_len)
    encoded_h = self.output_dropout(encoded_h)
    pred = self.output(encoded_h)
    mask = seq_len_to_mask(seq_len)
    # pred = self.crf(pred)
    # batch_size, sent_len = pred.shape[0], pred.shape[1]
    # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1),
    #                       target.reshape(batch_size * sent_len))
    if self.debug:
        print('debug mode:finish')
        exit(1208)

    if self.training:
        loss = self.crf(pred, target, mask)
        return {'loss': loss}
    else:
        pred, path = self.crf.viterbi_decode(pred, mask)
        return {'pred': pred}
def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e, pos_tag,
            target=None, chars_target=None):
    batch_size = lattice.size(0)
    max_seq_len_and_lex_num = lattice.size(1)
    max_seq_len = bigrams.size(1)

    words = lattice[:, :max_seq_len]
    mask = seq_len_to_mask(seq_len).bool()
    words.masked_fill_((~mask), self.vocabs['lattice'].padding_idx)
    encoded = self.bert_embedding(words)

    if self.use_pos_tag:
        pos_embed = self.pos_embedding(pos_tag)
        encoded = torch.cat([encoded, pos_embed], dim=-1)

    if self.after_bert == 'lstm':
        encoded, _ = self.lstm(encoded, seq_len)
    encoded = self.dropout(encoded)
    pred = self.output(encoded)

    if self.training:
        # loss = self.crf(pred, target, mask).mean(dim=0)
        loss = self.crf(emissions=pred, tags=target, mask=mask).mean(dim=0)
        return {'loss': -loss}
    else:
        pred = self.crf.decode(emissions=pred, mask=mask).squeeze(0)
        # pred, path = self.crf.viterbi_decode(pred, mask)
        # print(pred.shape)
        result = {'pred': pred}
        return result
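# Unlike fastNLP's ConditionalRandomField used elsewhere in this file (whose
# forward already returns a loss), the CRF here appears to follow the
# pytorch-crf convention: forward(emissions, tags, mask) returns a
# log-likelihood, hence the negation into a loss above, and decode() rather
# than viterbi_decode() produces the predicted paths.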
def _forward(self, chars, bigrams=None, trigrams=None, seq_len=None, target=None):
    chars = self.char_embed(chars)
    if hasattr(self, 'bigram_embed'):
        bigrams = self.bigram_embed(bigrams)
        chars = torch.cat((chars, bigrams), dim=-1)
    if hasattr(self, 'trigram_embed'):
        trigrams = self.trigram_embed(trigrams)
        chars = torch.cat((chars, trigrams), dim=-1)
    feats, _ = self.lstm(chars, seq_len=seq_len)
    feats = self.fc(feats)
    feats = self.dropout(feats)
    logits = F.log_softmax(feats, dim=-1)
    mask = seq_len_to_mask(seq_len)
    if target is None:
        pred, _ = self.crf.viterbi_decode(logits, mask)
        return {C.OUTPUT: pred}
    else:
        loss = self.crf(logits, target, mask).mean()
        return {C.LOSS: loss}
def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e, target,
            chars_target=None):
    if self.mode['debug']:
        print('lattice:{}'.format(lattice))
        print('bigrams:{}'.format(bigrams))
        print('seq_len:{}'.format(seq_len))
        print('lex_num:{}'.format(lex_num))
        print('pos_s:{}'.format(pos_s))
        print('pos_e:{}'.format(pos_e))

    batch_size = lattice.size(0)
    max_seq_len_and_lex_num = lattice.size(1)
    max_seq_len = bigrams.size(1)

    raw_embed = self.lattice_embed(lattice)
    # raw_embed holds the pretrained embeddings of both characters and words,
    # but the two were trained separately, so they have to be treated separately.
    if self.use_bigram:
        bigrams_embed = self.bigram_embed(bigrams)
        bigrams_embed = torch.cat([
            bigrams_embed,
            torch.zeros(size=[batch_size,
                              max_seq_len_and_lex_num - max_seq_len,
                              self.bigram_size]).to(bigrams_embed)
        ], dim=1)
        raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)
    else:
        raw_embed_char = raw_embed
    # print('raw_embed_char_1:{}'.format(raw_embed_char[:1,:3,-5:]))

    if self.use_bert:
        bert_pad_length = lattice.size(1) - max_seq_len
        char_for_bert = lattice[:, :max_seq_len]
        mask = seq_len_to_mask(seq_len).bool()
        char_for_bert = char_for_bert.masked_fill(
            (~mask), self.vocabs['lattice'].padding_idx)
        bert_embed = self.bert_embedding(char_for_bert)
        bert_embed = torch.cat([
            bert_embed,
            torch.zeros(size=[batch_size, bert_pad_length, bert_embed.size(-1)],
                        device=bert_embed.device, requires_grad=False)
        ], dim=-2)
        # print('bert_embed:{}'.format(bert_embed[:1, :3, -5:]))
        raw_embed_char = torch.cat([raw_embed_char, bert_embed], dim=-1)
    # print('raw_embed_char:{}'.format(raw_embed_char[:1,:3,-5:]))

    if self.embed_dropout_pos == '0':
        raw_embed_char = self.embed_dropout(raw_embed_char)
        raw_embed = self.gaz_dropout(raw_embed)
    # print('raw_embed_char_dp:{}'.format(raw_embed_char[:1,:3,-5:]))

    embed_char = self.char_proj(raw_embed_char)
    # print('char_proj:',list(self.char_proj.parameters())[0].data[:2][:2])
    # print('embed_char_:{}'.format(embed_char[:1,:3,:4]))
    if self.mode['debug']:
        print('embed_char:{}'.format(embed_char[:2]))
    char_mask = seq_len_to_mask(seq_len, max_len=max_seq_len_and_lex_num).bool()
    # if self.embed_dropout_pos == '1':
    #     embed_char = self.embed_dropout(embed_char)
    embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

    embed_lex = self.lex_proj(raw_embed)
    if self.mode['debug']:
        print('embed_lex:{}'.format(embed_lex[:2]))
    # if self.embed_dropout_pos == '1':
    #     embed_lex = self.embed_dropout(embed_lex)
    lex_mask = (seq_len_to_mask(seq_len + lex_num).bool() ^ char_mask.bool())
    embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

    assert char_mask.size(1) == lex_mask.size(1)
    # print('embed_char:{}'.format(embed_char[:1,:3,:4]))
    # print('embed_lex:{}'.format(embed_lex[:1,:3,:4]))
    embedding = embed_char + embed_lex
    if self.mode['debug']:
        print('embedding:{}'.format(embedding[:2]))

    if self.embed_dropout_pos == '1':
        embedding = self.embed_dropout(embedding)

    if self.use_abs_pos:
        embedding = self.abs_pos_encode(embedding, pos_s, pos_e)

    if self.embed_dropout_pos == '2':
        embedding = self.embed_dropout(embedding)
    # embedding = self.embed_dropout(embedding)
    # print('*1*')
    # print(embedding.size())
    # print('merged_embedding:{}'.format(embedding[:1,:3,:4]))
    # exit()

    encoded = self.encoder(embedding, seq_len,
                           lex_num=lex_num, pos_s=pos_s, pos_e=pos_e)

    if hasattr(self, 'output_dropout'):
        encoded = self.output_dropout(encoded)

    encoded = encoded[:, :max_seq_len, :]
    pred = self.output(encoded)

    mask = seq_len_to_mask(seq_len).bool()

    if self.mode['debug']:
        print('debug mode:finish!')
        exit(1208)

    if self.training:
        loss = self.crf(pred, target, mask).mean(dim=0)
        if self.self_supervised:
            # print('self supervised loss added!')
            chars_pred = self.output_self_supervised(encoded)
            chars_pred = chars_pred.view(size=[batch_size * max_seq_len, -1])
            chars_target = chars_target.view(size=[batch_size * max_seq_len])
            self_supervised_loss = self.loss_func(chars_pred, chars_target)
            # print('self_supervised_loss:{}'.format(self_supervised_loss))
            # print('supervised_loss:{}'.format(loss))
            loss += self_supervised_loss
        return {'loss': loss}
    else:
        pred, path = self.crf.viterbi_decode(pred, mask)
        result = {'pred': pred}
        if self.self_supervised:
            chars_pred = self.output_self_supervised(encoded)
            result['chars_pred'] = chars_pred
        return result
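# How the char/lex masks above partition the sequence: char_mask is True on
# the first seq_len positions (the characters), while
# seq_len_to_mask(seq_len + lex_num) also covers the appended lexicon slots.
# XOR-ing the two leaves True exactly on the lexicon slots, so embed_char and
# embed_lex are zeroed on disjoint positions and `embed_char + embed_lex` is
# a concatenation-by-sum rather than a true addition.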
def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e, target,
            span_label, attr_start_label, attr_end_label, chars_target=None):
    self.steps += 1
    if self.mode['debug']:
        print('lattice:{} {}'.format(lattice.shape, lattice))
        print('bigrams:{} {}'.format(bigrams.shape, bigrams))
        print('seq_len:{} {}'.format(seq_len.shape, seq_len))
        print('lex_num:{} {}'.format(lex_num.shape, lex_num))
        print('pos_s:{} {}'.format(pos_s.shape, pos_s))
        print('pos_e:{} {}'.format(pos_e.shape, pos_e))
        print('span_label:{} {}'.format(span_label.shape, span_label))
        print('attr_start_label:{} {}'.format(attr_start_label.shape, attr_start_label))
        print('attr_end_label: {} {}'.format(attr_end_label.shape, attr_end_label))
        exit(1228)

    batch_size = lattice.size(0)
    max_seq_len_and_lex_num = lattice.size(1)
    max_seq_len = bigrams.size(1)

    raw_embed = self.lattice_embed(lattice)
    # raw_embed holds the pretrained embeddings of both characters and words,
    # but the two were trained separately, so they have to be treated separately.
    if self.use_bigram:
        bigrams_embed = self.bigram_embed(bigrams)
        bigrams_embed = torch.cat([
            bigrams_embed,
            torch.zeros(size=[batch_size,
                              max_seq_len_and_lex_num - max_seq_len,
                              self.bigram_size]).to(bigrams_embed)
        ], dim=1)
        raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)
    else:
        raw_embed_char = raw_embed
    # print('raw_embed_char_1:{}'.format(raw_embed_char[:1,:3,-5:]))

    if self.use_bert:
        bert_pad_length = lattice.size(1) - max_seq_len
        char_for_bert = lattice[:, :max_seq_len]
        mask = seq_len_to_mask(seq_len).bool()
        char_for_bert = char_for_bert.masked_fill(
            (~mask), self.vocabs['lattice'].padding_idx)
        bert_embed = self.bert_embedding(char_for_bert)
        bert_embed = torch.cat([
            bert_embed,
            torch.zeros(size=[batch_size, bert_pad_length, bert_embed.size(-1)],
                        device=bert_embed.device, requires_grad=False)
        ], dim=-2)
        # print('bert_embed:{}'.format(bert_embed[:1, :3, -5:]))
        raw_embed_char = torch.cat([raw_embed_char, bert_embed], dim=-1)
    # print('raw_embed_char:{}'.format(raw_embed_char[:1,:3,-5:]))

    if self.embed_dropout_pos == '0':
        raw_embed_char = self.embed_dropout(raw_embed_char)
        raw_embed = self.gaz_dropout(raw_embed)
    # print('raw_embed_char_dp:{}'.format(raw_embed_char[:1,:3,-5:]))

    embed_char = self.char_proj(raw_embed_char)
    # print('char_proj:',list(self.char_proj.parameters())[0].data[:2][:2])
    # print('embed_char_:{}'.format(embed_char[:1,:3,:4]))
    if self.mode['debug']:
        print('embed_char:{}'.format(embed_char[:2]))
    char_mask = seq_len_to_mask(seq_len, max_len=max_seq_len_and_lex_num).bool()
    # if self.embed_dropout_pos == '1':
    #     embed_char = self.embed_dropout(embed_char)
    embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

    embed_lex = self.lex_proj(raw_embed)
    if self.mode['debug']:
        print('embed_lex:{}'.format(embed_lex[:2]))
    # if self.embed_dropout_pos == '1':
    #     embed_lex = self.embed_dropout(embed_lex)
    lex_mask = (seq_len_to_mask(seq_len + lex_num).bool() ^ char_mask.bool())
    embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

    assert char_mask.size(1) == lex_mask.size(1)
    # print('embed_char:{}'.format(embed_char[:1,:3,:4]))
    # print('embed_lex:{}'.format(embed_lex[:1,:3,:4]))
    embedding = embed_char + embed_lex
    if self.mode['debug']:
        print('embedding:{}'.format(embedding[:2]))

    if self.embed_dropout_pos == '1':
        embedding = self.embed_dropout(embedding)

    if self.use_abs_pos:
        embedding = self.abs_pos_encode(embedding, pos_s, pos_e)

    if self.embed_dropout_pos == '2':
        embedding = self.embed_dropout(embedding)
    # embedding = self.embed_dropout(embedding)
    # print('*1*')
    # print(embedding.size())
    # print('merged_embedding:{}'.format(embedding[:1,:3,:4]))
    # exit()

    mask = seq_len_to_mask(seq_len).bool()

    # TODO: add our PLE
    if self.new_tag_scheme:
        encodeds = []
        for _i in range(self.ple_channel_num):
            encoded = self.encoder_list[_i](embedding, seq_len, lex_num=lex_num,
                                            pos_s=pos_s, pos_e=pos_e)
            if hasattr(self, 'output_dropout'):
                encoded = self.output_dropout(encoded)
            encoded = encoded[:, :max_seq_len, :]
            encodeds.append(encoded)
        if self.ple_channel_num == 1:
            span_logits, attr_start_logits, attr_end_logits = self.ple(
                encodeds[0], encodeds[0], encodeds[0])
        else:
            span_logits, attr_start_logits, attr_end_logits = self.ple(
                encodeds[0], encodeds[1], encodeds[2])

        if self.training:
            inputs_seq_len = mask.sum(dim=-1).float()
            span_loss = (self.crf(span_logits, span_label, mask)
                         / inputs_seq_len).mean(dim=0)
            attr_start_loss = self.attr_criterion(
                attr_start_logits.permute(0, 2, 1), attr_start_label)  # B * S
            attr_start_loss = (torch.sum(attr_start_loss * mask.float(), dim=-1).float()
                               / inputs_seq_len).mean()  # B
            attr_end_loss = self.attr_criterion(
                attr_end_logits.permute(0, 2, 1), attr_end_label)  # B * S
            attr_end_loss = (torch.sum(attr_end_loss * mask.float(), dim=-1).float()
                             / inputs_seq_len).mean()  # B
            loss = (self.span_loss_alpha * span_loss
                    + attr_start_loss + attr_end_loss) / 3
            # if torch.isnan(span_loss.mean()) or torch.abs(span_loss.mean()) > 50:
            if self.steps % 50 == 0:
                print(f"span_loss: {span_loss}; attr_start_loss: {attr_start_loss}; "
                      f"attr_end_loss: {attr_end_loss}")
            # loss = (attr_start_loss + attr_end_loss) / 3
            return {"loss": loss}
        else:
            # span_pred, path = self.crf.viterbi_decode(span_logits, mask)
            attr_start_pred = attr_start_logits.argmax(dim=-1)
            attr_end_pred = attr_end_logits.argmax(dim=-1)
            ner_pred = convert_attr_seq_to_ner_seq(attr_start_pred, attr_end_pred,
                                                   self.vocabs, tagscheme='BMOES')
            return {'pred': ner_pred}
    else:
        encoded = self.encoder(embedding, seq_len, lex_num=lex_num,
                               pos_s=pos_s, pos_e=pos_e)
        if hasattr(self, 'output_dropout'):
            encoded = self.output_dropout(encoded)
        encoded = encoded[:, :max_seq_len, :]
        pred = self.output(encoded)

        if self.mode['debug']:
            print('debug mode:finish!')
            exit(1208)

        if self.training:
            loss = self.crf(pred, target, mask).mean(dim=0)
            if self.self_supervised:
                # print('self supervised loss added!')
                chars_pred = self.output_self_supervised(encoded)
                chars_pred = chars_pred.view(size=[batch_size * max_seq_len, -1])
                chars_target = chars_target.view(size=[batch_size * max_seq_len])
                self_supervised_loss = self.loss_func(chars_pred, chars_target)
                # print('self_supervised_loss:{}'.format(self_supervised_loss))
                # print('supervised_loss:{}'.format(loss))
                loss += self_supervised_loss
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            result = {'pred': pred}
            if self.self_supervised:
                chars_pred = self.output_self_supervised(encoded)
                result['chars_pred'] = chars_pred
            return result
def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None,
            pre_chars=None, pre_bigrams=None, pre_trigrams=None):
    """
    max_len includes the root token.

    :param chars: batch_size x max_len
    :param bigrams: batch_size x max_len
    :param trigrams: batch_size x max_len
    :param seq_lens: batch_size
    :param gold_heads: batch_size x max_len
    :param pre_chars: batch_size x max_len
    :param pre_bigrams: batch_size x max_len
    :param pre_trigrams: batch_size x max_len
    :return dict: parsing results
        arc_pred: [batch_size, seq_len, seq_len]
        label_pred: [batch_size, seq_len, num_label]
        mask: [batch_size, seq_len]
        head_pred: [batch_size, seq_len], the predicted heads,
            only returned when gold_heads is not provided
    """
    # prepare embeddings
    batch_size, seq_len = chars.shape
    # print('forward {} {}'.format(batch_size, seq_len))

    # get sequence mask
    mask = seq_len_to_mask(seq_lens).long()

    chars = self.char_embed(chars)  # [N,L] -> [N,L,C_0]
    bigrams = self.bigram_embed(bigrams)  # [N,L] -> [N,L,C_1]
    trigrams = self.trigram_embed(trigrams)

    if pre_chars is not None:
        pre_chars = self.pre_char_embed(pre_chars)
        # pre_chars = self.pre_char_fc(pre_chars)
        chars = pre_chars + chars
    if pre_bigrams is not None:
        pre_bigrams = self.pre_bigram_embed(pre_bigrams)
        # pre_bigrams = self.pre_bigram_fc(pre_bigrams)
        bigrams = bigrams + pre_bigrams
    if pre_trigrams is not None:
        pre_trigrams = self.pre_trigram_embed(pre_trigrams)
        # pre_trigrams = self.pre_trigram_fc(pre_trigrams)
        trigrams = trigrams + pre_trigrams

    x = torch.cat([chars, bigrams, trigrams], dim=2)  # -> [N,L,C]

    # encoder, extract features
    if self.training:
        x = drop_input_independent(x, self.dropout)
    sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
    x = x[sort_idx]
    x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
    feat, _ = self.encoder(x)  # -> [N,L,C]
    feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
    _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
    feat = feat[unsort_idx]
    feat = self.timestep_drop(feat)

    # for arc biaffine
    # mlp, reduce dim
    feat = self.mlp(feat)
    arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
    arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
    label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], \
                            feat[:, :, 2 * arc_sz + label_sz:]

    # biaffine arc classifier
    arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]

    # use gold or predicted arcs to predict labels
    if gold_heads is None or not self.training:
        # use greedy decoding in training
        if self.training or self.use_greedy_infer:
            heads = self.greedy_decoder(arc_pred, mask)
        else:
            heads = self.mst_decoder(arc_pred, mask)
        head_pred = heads
    else:
        assert self.training  # must be in training mode
        if gold_heads is None:
            heads = self.greedy_decoder(arc_pred, mask)
            head_pred = heads
        else:
            head_pred = None
            heads = gold_heads
    # heads: batch_size x max_len

    batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long,
                               device=chars.device).unsqueeze(1)
    label_head = label_head[batch_range, heads].contiguous()
    label_pred = self.label_predictor(label_head, label_dep)  # [N, max_len, num_label]

    # Restriction: the 'app' label may only be predicted when the head is the
    # next token.
    arange_index = torch.arange(1, seq_len + 1, dtype=torch.long,
                                device=chars.device).unsqueeze(0) \
        .repeat(batch_size, 1)  # batch_size x max_len
    app_masks = heads.ne(arange_index)  # batch_size x max_len; positions with 1 may not predict app
    app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label)
    app_masks[:, :, 1:] = 0
    label_pred = label_pred.masked_fill(app_masks, -np.inf)

    res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
    if head_pred is not None:
        res_dict['head_pred'] = head_pred
    return res_dict
def forward(self, sentence, aspect, pos_class, dep_tags, text_len, aspect_len,
            dep_rels, dep_heads, aspect_position, dep_dirs):
    '''
    Forward takes:
        sentence: sentence_id of size (batch_size, text_length)
        aspect: aspect_id of size (batch_size, aspect_length)
        pos_class: pos_tag_id of size (batch_size, text_length)
        dep_tags: dep_tag_id of size (batch_size, text_length)
        text_len: (batch_size,) length of each sentence
        aspect_len: (batch_size,) aspect length of each sentence
        dep_rels: (batch_size, text_length) dependency relations
        dep_heads: (batch_size, text_length) the node adjacent to each node
        aspect_position: (batch_size, text_length) mask, 1 at aspect positions and 0 elsewhere
        dep_dirs: (batch_size, text_length) direction from each node to the aspect
    '''
    fmask = seq_len_to_mask(text_len).float()
    # fmask = (torch.zeros_like(sentence) != sentence).float()  # (N, L), pad is 0
    # dmask = (torch.zeros_like(dep_tags) != dep_tags).float()  # (N, L)
    if self.training:
        mask = torch.rand(sentence.size()).lt(0.02).to(sentence.device)
        sentence = sentence.masked_fill(mask, 0)
        # mask = torch.rand(aspect.size()).lt(0.01).to(sentence.device)
        # aspect = aspect.masked_fill(mask, 0)

    feature = self.embed(sentence)  # (N, L, D)
    aspect_feature = self.embed(aspect)  # (N, L', D)
    feature = self.dropout(feature)
    aspect_feature = self.dropout(aspect_feature)

    if self.args.highway:
        feature = self.highway(feature)
        aspect_feature = self.highway(aspect_feature)

    feature, _ = self.bilstm(feature, seq_len=text_len)  # (N, L, D)
    aspect_feature, _ = self.bilstm(aspect_feature, seq_len=aspect_len)  # (N, L', D)

    aspect_mask = seq_len_to_mask(aspect_len)
    # aspect_feature = aspect_feature.masked_fill(aspect_mask.eq(0).unsqueeze(-1), 0)
    # aspect_feature = aspect_feature.sum(dim=1) / aspect_len.unsqueeze(1).float()
    aspect_feature = aspect_feature.masked_fill(
        aspect_mask.eq(0).unsqueeze(-1), -10000)
    aspect_feature, _ = aspect_feature.max(dim=1)
    # aspect_feature = aspect_feature.mean(dim=1)

    ############################################################################
    # do the GAT part
    dep_feature = self.dep_embed(dep_tags)
    # dep_feature = self.dropout(dep_feature)
    dep_feature = F.dropout(dep_feature, p=0.7, training=self.training)
    if self.args.highway:
        dep_feature = self.highway_dep(dep_feature)

    dep_out = [g(feature, dep_feature, fmask).unsqueeze(1)
               for g in self.gat_dep]  # (N, 1, D) * num_heads
    dep_out = torch.cat(dep_out, dim=1)  # (N, H, D)
    # dep_out = dep_out.mean(dim=1)  # (N, D)
    dep_out, _ = dep_out.max(dim=1)  # (N, D)

    if self.args.gat_attention_type == 'gcn':
        gat_out = self.gat(feature)  # (N, L, D)
        fmask = fmask.unsqueeze(2)
        gat_out = gat_out * fmask
        gat_out = F.relu(torch.sum(gat_out, dim=1))  # (N, D)
    else:
        gat_out = [g(feature, aspect_feature, fmask).unsqueeze(1)
                   for g in self.gat]
        gat_out = torch.cat(gat_out, dim=1)
        # gat_out = gat_out.mean(dim=1)
        gat_out, _ = gat_out.max(dim=1)

    feature_out = torch.cat([dep_out, gat_out], dim=1)  # (N, D')
    # feature_out = gat_out
    ############################################################################
    feature_out = self.dropout(feature_out)
    # feature_out = F.dropout(feature_out, p=0.3, training=self.training)
    x = self.fcs(feature_out)
    logit = self.fc_final(x)
    return logit
def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e, target,
            chars_target=None):
    if self.mode['debug']:
        # taking the first sample as an example:
        print('lattice:{}'.format(lattice))  # 21 + 12 indices, padded afterwards
        print('bigrams:{}'.format(bigrams))  # 21 indices, different from the start of lattice
        print('seq_len:{}'.format(seq_len))  # 21
        print('lex_num:{}'.format(lex_num))  # 12
        print('pos_s:{}'.format(pos_s))  # 0, 1, 2, ...
        print('pos_e:{}'.format(pos_e))  # 0, 1, 2, ...

    batch_size = lattice.size(0)
    max_seq_len_and_lex_num = lattice.size(1)
    max_seq_len = bigrams.size(1)

    raw_embed = self.lattice_embed(lattice)  # look up the lattice embeddings
    if self.use_bigram:
        bigrams_embed = self.bigram_embed(bigrams)
        bigrams_embed = torch.cat([
            bigrams_embed,
            torch.zeros(size=[batch_size,
                              max_seq_len_and_lex_num - max_seq_len,
                              self.bigram_size]).to(bigrams_embed)
        ], dim=1)
        raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)
    else:
        raw_embed_char = raw_embed

    # dim2 = 0
    # dim3 = 2
    if self.embed_dropout_pos == '0':
        raw_embed_char = self.embed_dropout(raw_embed_char)
        raw_embed = self.gaz_dropout(raw_embed)

    embed_char = self.char_proj(raw_embed_char)  # linear projection
    char_mask = seq_len_to_mask(seq_len, max_len=max_seq_len_and_lex_num).bool()
    embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

    embed_lex = self.lex_proj(raw_embed)
    lex_mask = (seq_len_to_mask(seq_len + lex_num).bool() ^ char_mask.bool())
    embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

    assert char_mask.size(1) == lex_mask.size(1)

    embedding = embed_char + embed_lex  # adding these looks odd, but the masks are disjoint
    if self.embed_dropout_pos == '1':
        embedding = self.embed_dropout(embedding)  # dropout

    if self.use_abs_pos:
        embedding = self.abs_pos_encode(embedding, pos_s, pos_e)

    if self.embed_dropout_pos == '2':
        embedding = self.embed_dropout(embedding)

    encoded = self.encoder(embedding, seq_len, lex_num=lex_num, pos_s=pos_s,
                           pos_e=pos_e, print_=(self.batch_num == 327))
    if hasattr(self, 'output_dropout'):
        encoded = self.output_dropout(encoded)

    encoded = encoded[:, :max_seq_len, :]
    pred = self.output(encoded)
    mask = seq_len_to_mask(seq_len).bool()

    if self.mode['debug']:
        print('debug mode:finish!')
        exit(1208)

    if self.training:
        loss = self.crf(pred, target, mask).mean(dim=0)
        if self.self_supervised:
            # print('self supervised loss added!')
            chars_pred = self.output_self_supervised(encoded)
            chars_pred = chars_pred.view(size=[batch_size * max_seq_len, -1])
            chars_target = chars_target.view(size=[batch_size * max_seq_len])
            self_supervised_loss = self.loss_func(chars_pred, chars_target)
            # print('self_supervised_loss:{}'.format(self_supervised_loss))
            # print('supervised_loss:{}'.format(loss))
            loss += self_supervised_loss
        if self.batch_num == 327:
            print('{} loss:{}'.format(self.batch_num, loss))
            exit()
        # exit()
        return {'loss': loss}
    else:
        pred, path = self.crf.viterbi_decode(pred, mask)
        result = {'pred': pred}
        if self.self_supervised:
            chars_pred = self.output_self_supervised(encoded)
            result['chars_pred'] = chars_pred
        return result
def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e, target,
            chars_target=None):
    # if self.training:
    #     self.batch_num += 1
    #     if self.batch_num == 1000:
    #         exit()
    # print('lattice:')
    # print(lattice)
    if self.mode['debug']:
        print('lattice:{}'.format(lattice))
        print('bigrams:{}'.format(bigrams))
        print('seq_len:{}'.format(seq_len))
        print('lex_num:{}'.format(lex_num))
        print('pos_s:{}'.format(pos_s))
        print('pos_e:{}'.format(pos_e))

    batch_size = lattice.size(0)
    max_seq_len_and_lex_num = lattice.size(1)
    max_seq_len = bigrams.size(1)

    raw_embed = self.lattice_embed(lattice)
    # raw_embed holds the pretrained embeddings of both characters and words,
    # but the two were trained separately, so they have to be treated separately.
    if self.use_bigram:
        bigrams_embed = self.bigram_embed(bigrams)
        bigrams_embed = torch.cat([
            bigrams_embed,
            torch.zeros(size=[batch_size,
                              max_seq_len_and_lex_num - max_seq_len,
                              self.bigram_size]).to(bigrams_embed)
        ], dim=1)
        raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)
    else:
        raw_embed_char = raw_embed

    dim2 = 0
    dim3 = 2
    # print('raw_embed:{}'.format(raw_embed[:, dim2, :dim3]))
    # print('raw_embed_char:{}'.format(raw_embed_char[:, dim2, :dim3]))
    if self.embed_dropout_pos == '0':
        raw_embed_char = self.embed_dropout(raw_embed_char)
        raw_embed = self.gaz_dropout(raw_embed)
    # print('raw_embed_dropout:{}'.format(raw_embed[:, dim2, :dim3]))
    # print('raw_embed_char_dropout:{}'.format(raw_embed_char[:, dim2, :dim3]))

    embed_char = self.char_proj(raw_embed_char)
    if self.mode['debug']:
        print('embed_char:{}'.format(embed_char[:2]))
    char_mask = seq_len_to_mask(seq_len, max_len=max_seq_len_and_lex_num).bool()
    # if self.embed_dropout_pos == '1':
    #     embed_char = self.embed_dropout(embed_char)
    embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

    embed_lex = self.lex_proj(raw_embed)
    if self.mode['debug']:
        print('embed_lex:{}'.format(embed_lex[:2]))
    # if self.embed_dropout_pos == '1':
    #     embed_lex = self.embed_dropout(embed_lex)
    lex_mask = (seq_len_to_mask(seq_len + lex_num).bool() ^ char_mask.bool())
    embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

    assert char_mask.size(1) == lex_mask.size(1)

    embedding = embed_char + embed_lex
    if self.mode['debug']:
        print('embedding:{}'.format(embedding[:2]))

    if self.embed_dropout_pos == '1':
        embedding = self.embed_dropout(embedding)

    if self.use_abs_pos:
        embedding = self.abs_pos_encode(embedding, pos_s, pos_e)

    if self.embed_dropout_pos == '2':
        embedding = self.embed_dropout(embedding)
    # embedding = self.embed_dropout(embedding)
    # print('embedding:{}'.format(embedding[:, dim2, :dim3]))

    if self.batch_num == 327:
        print('{} embed:{}'.format(self.batch_num, embedding[:2, dim2, :dim3]))

    encoded = self.encoder(embedding, seq_len, lex_num=lex_num, pos_s=pos_s,
                           pos_e=pos_e, print_=(self.batch_num == 327))
    if self.batch_num == 327:
        print('{} encoded:{}'.format(self.batch_num, encoded[:2, dim2, :dim3]))

    if hasattr(self, 'output_dropout'):
        encoded = self.output_dropout(encoded)

    encoded = encoded[:, :max_seq_len, :]
    pred = self.output(encoded)
    if self.batch_num == 327:
        print('{} pred:{}'.format(self.batch_num, pred[:2, dim2, :dim3]))
    # print('pred:{}'.format(pred[:, dim2, :dim3]))
    # exit()

    mask = seq_len_to_mask(seq_len).bool()

    if self.mode['debug']:
        print('debug mode:finish!')
        exit(1208)

    if self.training:
        loss = self.crf(pred, target, mask).mean(dim=0)
        if self.self_supervised:
            # print('self supervised loss added!')
            chars_pred = self.output_self_supervised(encoded)
            chars_pred = chars_pred.view(size=[batch_size * max_seq_len, -1])
            chars_target = chars_target.view(size=[batch_size * max_seq_len])
            self_supervised_loss = self.loss_func(chars_pred, chars_target)
            # print('self_supervised_loss:{}'.format(self_supervised_loss))
            # print('supervised_loss:{}'.format(loss))
            loss += self_supervised_loss
        if self.batch_num == 327:
            print('{} loss:{}'.format(self.batch_num, loss))
            exit()
        # exit()
        return {'loss': loss}
    else:
        pred, path = self.crf.viterbi_decode(pred, mask)
        result = {'pred': pred}
        if self.self_supervised:
            chars_pred = self.output_self_supervised(encoded)
            result['chars_pred'] = chars_pred
        return result