def test_pytorch_seq_len(self):
    # 1. Random test
    seq_len = torch.randint(1, 10, size=(10, ))
    max_len = seq_len.max()
    mask = seq_len_to_mask(seq_len)
    self.assertEqual(max_len, mask.shape[1])
    self.evaluate_mask_seq_len(seq_len.tolist(), mask)

    # 2. Exception check: a 2-D float input must raise
    seq_len = torch.randn(3, 4)
    with self.assertRaises(AssertionError):
        mask = seq_len_to_mask(seq_len)
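# Illustrative sketch (not part of the test suite above): the mask is padded to
# max(seq_len) and each row carries a True prefix of length seq_len[i], which is
# exactly what the assertions above exercise.
import torch
from fastNLP.core.utils import seq_len_to_mask

example_len = torch.LongTensor([1, 3, 2])
example_mask = seq_len_to_mask(example_len)
# example_mask:
# [[ True, False, False],
#  [ True,  True,  True],
#  [ True,  True, False]]
assert example_mask.shape == (3, 3)
assert (example_mask.sum(dim=-1) == example_len).all()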
def test_numpy_seq_len(self):
    # Check that a numpy seq_len can be converted as well
    # 1. Random test
    seq_len = np.random.randint(1, 10, size=(10, ))
    mask = seq_len_to_mask(seq_len)
    max_len = seq_len.max()
    self.assertEqual(max_len, mask.shape[1])
    self.evaluate_mask_seq_len(seq_len, mask)

    # 2. Exception check: a 2-D input must raise
    seq_len = np.random.randint(10, size=(10, 1))
    with self.assertRaises(AssertionError):
        mask = seq_len_to_mask(seq_len)
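# The helper evaluate_mask_seq_len used by both tests is not shown here; a minimal
# sketch of what it presumably verifies (hypothetical implementation, name reused
# from the tests, body assumed):
def evaluate_mask_seq_len(self, seq_len, mask):
    # Each row of the mask should be a True prefix whose length equals seq_len[i].
    for length, row in zip(seq_len, mask.tolist()):
        assert sum(row) == length
        assert all(row[:length]) and not any(row[length:])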
def test_case3(self):
    # Check that the CRF loss never becomes negative
    import torch
    from fastNLP.modules.decoder.crf import ConditionalRandomField
    from fastNLP.core.utils import seq_len_to_mask
    from torch import optim
    from torch import nn

    num_tags, include_start_end_trans = 4, True
    num_samples = 4
    lengths = torch.randint(3, 50, size=(num_samples, )).long()
    max_len = lengths.max()
    tags = torch.randint(num_tags, size=(num_samples, max_len))
    masks = seq_len_to_mask(lengths)
    feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
    crf = ConditionalRandomField(num_tags, include_start_end_trans)
    optimizer = optim.SGD(
        [param for param in crf.parameters() if param.requires_grad] + [feats],
        lr=0.1)
    for _ in range(10):
        loss = crf(feats, tags, masks).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if _ % 1000 == 0:
            print(loss)
        self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.")
def forward(self, words1, words2, seq_len1, seq_len2, target=None):
    """
    :param words1: [batch, seq_len]
    :param words2: [batch, seq_len]
    :param seq_len1: [batch]
    :param seq_len2: [batch]
    :param target:
    :return:
    """
    mask1 = seq_len_to_mask(seq_len1, words1.size(1))
    mask2 = seq_len_to_mask(seq_len2, words2.size(1))

    a0 = self.embedding(words1)  # B * len * emb_dim
    b0 = self.embedding(words2)
    a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0)

    a = self.rnn(a0, mask1.byte())  # a: [B, PL, 2 * H]
    b = self.rnn(b0, mask2.byte())
    # a = self.dropout_rnn(self.rnn(a0, seq_len1)[0])  # a: [B, PL, 2 * H]
    # b = self.dropout_rnn(self.rnn(b0, seq_len2)[0])

    ai, bi = self.bi_attention(a, mask1, b, mask2)

    a_ = torch.cat((a, ai, a - ai, a * ai), dim=2)  # ma: [B, PL, 8 * H]
    b_ = torch.cat((b, bi, b - bi, b * bi), dim=2)
    a_f = self.interfere(a_)
    b_f = self.interfere(b_)

    a_h = self.rnn_high(a_f, mask1.byte())  # ma: [B, PL, 2 * H]
    b_h = self.rnn_high(b_f, mask2.byte())
    # a_h = self.dropout_rnn(self.rnn_high(a_f, seq_len1)[0])  # ma: [B, PL, 2 * H]
    # b_h = self.dropout_rnn(self.rnn_high(b_f, seq_len2)[0])

    a_avg = self.mean_pooling(a_h, mask1, dim=1)
    a_max, _ = self.max_pooling(a_h, mask1, dim=1)
    b_avg = self.mean_pooling(b_h, mask2, dim=1)
    b_max, _ = self.max_pooling(b_h, mask2, dim=1)

    out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1)  # v: [B, 8 * H]
    logits = torch.tanh(self.classifier(out))
    # logits = self.classifier(out)

    if target is not None:
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits, target)
        return {Const.LOSS: loss, Const.OUTPUT: logits}
    else:
        return {Const.OUTPUT: logits}
def forward(self, task_id, x, y, seq_len):
    tid = task_id[0].item()
    x = self.word_embed(x)
    seq_mask = seq_len_to_mask(seq_len, x.shape[1])
    out, _ = self.lstm(x, seq_len)
    logit = self.out[tid](self.dropout(out))
    loss = ce_loss(logit, y, seq_mask)
    pred = torch.argmax(logit, dim=2)
    return {"pred": pred, "loss": loss}
def forward(self, words, seq_len, target=None, chars=None):
    if self.char_embeddings is None:
        x = self.word_embeddings(words)
    else:
        if chars is None:
            raise ValueError('must provide chars for model with char embedding')
        e1 = self.word_embeddings(words)
        e2 = self.char_embeddings(chars)
        x = torch.cat((e1, e2), dim=-1)  # b,l,h

    mask = seq_len_to_mask(seq_len)

    x = x.transpose(1, 2)  # b,h,l
    last_output = self.conv0(x)
    output = []
    for repeat in range(self.repeats):
        last_output = self.block(last_output)
        hidden = self.projection(last_output) if self.projection is not None else last_output
        output.append(self.out_fc(hidden))

    def compute_loss(y, t, mask):
        if self.crf is not None and target is not None:
            loss = self.crf(y.transpose(1, 2), t, mask)
        else:
            # mark padded positions with the ignore_index so they do not contribute to the loss
            t = t.masked_fill(mask.eq(False), -100)
            loss = F.cross_entropy(y, t, ignore_index=-100)
        return loss

    if target is not None:
        if self.block_loss:
            losses = [compute_loss(o, target, mask) for o in output]
            loss = sum(losses)
        else:
            loss = compute_loss(output[-1], target, mask)
    else:
        loss = None

    scores = output[-1]
    if self.crf is not None:
        pred, _ = self.crf.viterbi_decode(scores.transpose(1, 2), mask)
    else:
        pred = scores.max(1)[1] * mask.long()

    return {
        C.LOSS: loss,
        C.OUTPUT: pred,
    }
def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None):
    """Evaluate the performance of prediction."""
    if seq_lens is None:
        seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.uint8)
    else:
        seq_mask = seq_len_to_mask(seq_lens.long(), float=False)
    # mask out <root> tag
    seq_mask[:, 0] = 0
    head_pred_correct = (head_preds == heads).__and__(seq_mask)
    label_pred_correct = (label_preds == labels).__and__(head_pred_correct)
    self.num_arc += head_pred_correct.float().sum().item()
    self.num_label += label_pred_correct.float().sum().item()
    self.num_sample += seq_mask.sum().item()
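# The counters accumulated above presumably feed a get_metric-style method that
# reports attachment scores; a hedged sketch (method name, reset behaviour and
# rounding are assumptions, only the counters come from the code above):
def get_metric(self, reset=True):
    uas = self.num_arc / (self.num_sample + 1e-12)    # unlabeled attachment score
    las = self.num_label / (self.num_sample + 1e-12)  # labeled attachment score
    if reset:
        self.num_arc = self.num_label = self.num_sample = 0
    return {'UAS': round(uas, 6), 'LAS': round(las, 6)}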
def forward(self, task_id, x, y, seq_len):
    tid = task_id[0].item()
    word_embedding = self.word_embed(x)
    char_embedding = self.char_embed(x)
    x = torch.cat((word_embedding, char_embedding), dim=-1)
    seq_mask = seq_len_to_mask(seq_len, x.shape[1])
    out, _ = self.lstm1(x, seq_len)
    if tid != 0:
        out, _ = self.lstm2(out, seq_len)
    batch_size, sent_len, _ = x.shape
    logit = self.out[tid](self.dropout(out))
    loss = ce_loss(logit, y, seq_mask)
    pred = torch.argmax(logit, dim=2)
    return {"pred": pred, "loss": loss}
def forward(self, words, seq_len=None):
    r"""
    :param torch.LongTensor words: [batch_size, seq_len], word indices of each sentence
    :param torch.LongTensor seq_len: [batch,] length of each sentence
    :return output: dict of torch.LongTensor, [batch_size, num_classes]
    """
    x = self.embed(words)  # [N,L] -> [N,L,C]
    if seq_len is not None:
        mask = seq_len_to_mask(seq_len)
        x = self.conv_pool(x, mask)
    else:
        x = self.conv_pool(x)  # [N,L,C] -> [N,C]
    x = self.dropout(x)
    x = self.fc(x)  # [N,C] -> [N, N_class]
    return {C.OUTPUT: x}
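# How conv_pool consumes the mask is not shown here; a generic sketch of masked
# max-pooling over the length dimension (illustrative only, not the library code):
def masked_max_pool(features, mask):
    # features: [N, L, C], mask: [N, L] with True on real tokens
    features = features.masked_fill(~mask.unsqueeze(-1), float('-inf'))
    return features.max(dim=1)[0]  # [N, C]; padded positions never win the max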
def forward(self, task_id, x, y, seq_len):
    words_emb = self.embedding(x)
    char_emb = self.char(x)
    x = torch.cat([words_emb, char_emb], dim=-1)
    x, _ = self.lstm(x, seq_len)
    x = self.dropout(x)
    logit = self.out[task_id[0]](x)
    seq_mask = seq_len_to_mask(seq_len, x.size(1))
    if self.crf is not None:
        logit = torch.log_softmax(logit, dim=-1)
        loss = self.crf[task_id[0]](logit, y, seq_mask).mean()
        pred = self.crf[task_id[0]].viterbi_decode(logit, seq_mask)[0]
    else:
        loss = ce_loss(logit, y, seq_mask)
        pred = torch.argmax(logit, dim=2)
    return {"loss": loss, "pred": pred}
def _make_mask(self, x, seq_len):
    batch_size, max_len = x.size(0), x.size(1)
    mask = seq_len_to_mask(seq_len)
    mask = mask.view(batch_size, max_len)
    mask = mask.to(x).float()
    return mask
def test_get_seq_len(self):
    seq_len = torch.randint(1, 10, size=(10, ))
    mask = seq_len_to_mask(seq_len)
    new_seq_len = get_seq_len(mask)
    self.assertSequenceEqual(seq_len.tolist(), new_seq_len.tolist())
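# Sketch of the round-trip property this test relies on (imports assumed to match
# the other tests above): summing each True-prefix row of the mask recovers the
# original lengths, which is what get_seq_len is expected to return for such a mask.
import torch
from fastNLP.core.utils import seq_len_to_mask

lengths = torch.LongTensor([2, 5, 1])
assert (seq_len_to_mask(lengths).sum(dim=-1) == lengths).all()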
def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None):
    '''
    :param inp: batch * seq_len * embedding, chars
    :param seq_len: batch, length of chars
    :param skip_sources: batch * seq_len * X, start positions of the skip edges
    :param skip_words: batch * seq_len * X * embedding_size, words attached to the skip edges
    :param skip_count: batch * seq_len, skip_count[i, j] is the number of lexicon words matched that end at position j of example i
    :param init_state: the hx of rnn
    :return:
    '''
    if self.left2right:
        max_seq_len = max(seq_len)
        batch_size = inp.size(0)
        c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
        h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

        for i in range(max_seq_len):
            max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
            h_0, c_0 = h_[:, i, :], c_[:, i, :]

            # To let the rnn cell consume a B * lexicon_count * embedding_size tensor,
            # reshape it to 2-D; this also matches pytorch's []-style indexing below.
            skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
            skip_word_flat = skip_word_flat.view(batch_size * max_lexicon_count, self.word_input_size)
            skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)

            index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size, max_lexicon_count)
            index_1 = skip_source_flat

            if not self.skip_before_head:
                c_x = c_[[index_0, index_1 + 1]]
                h_x = h_[[index_0, index_1 + 1]]
            else:
                c_x = c_[[index_0, index_1]]
                h_x = h_[[index_0, index_1]]

            c_x_flat = c_x.view(batch_size * max_lexicon_count, self.hidden_size)
            h_x_flat = h_x.view(batch_size * max_lexicon_count, self.hidden_size)

            c_1_flat = self.word_cell(skip_word_flat, (h_x_flat, c_x_flat))
            c_1_skip = c_1_flat.view(batch_size, max_lexicon_count, self.hidden_size)

            h_1, c_1 = self.char_cell(inp[:, i, :], c_1_skip, skip_count[:, i], (h_0, c_0))

            h_ = torch.cat([h_, h_1.unsqueeze(1)], dim=1)
            c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1)

        return h_[:, 1:], c_[:, 1:]
    else:
        mask_for_seq_len = seq_len_to_mask(seq_len)
        max_seq_len = max(seq_len)
        batch_size = inp.size(0)
        c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
        h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

        for i in reversed(range(max_seq_len)):
            max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
            h_0, c_0 = h_[:, 0, :], c_[:, 0, :]

            skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
            skip_word_flat = skip_word_flat.view(batch_size * max_lexicon_count, self.word_input_size)
            skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)

            index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size, max_lexicon_count)
            index_1 = skip_source_flat - i

            if not self.skip_before_head:
                c_x = c_[[index_0, index_1 - 1]]
                h_x = h_[[index_0, index_1 - 1]]
            else:
                c_x = c_[[index_0, index_1]]
                h_x = h_[[index_0, index_1]]

            c_x_flat = c_x.view(batch_size * max_lexicon_count, self.hidden_size)
            h_x_flat = h_x.view(batch_size * max_lexicon_count, self.hidden_size)

            c_1_flat = self.word_cell(skip_word_flat, (h_x_flat, c_x_flat))
            c_1_skip = c_1_flat.view(batch_size, max_lexicon_count, self.hidden_size)

            h_1, c_1 = self.char_cell(inp[:, i, :], c_1_skip, skip_count[:, i], (h_0, c_0))

            h_1_mask = h_1.masked_fill(~mask_for_seq_len[:, i].unsqueeze(-1), 0)
            c_1_mask = c_1.masked_fill(~mask_for_seq_len[:, i].unsqueeze(-1), 0)

            h_ = torch.cat([h_1_mask.unsqueeze(1), h_], dim=1)
            c_ = torch.cat([c_1_mask.unsqueeze(1), c_], dim=1)

        return h_[:, :-1], c_[:, :-1]
def forward(self, inp, skip_c, skip_count, hx):
    '''
    :param inp: chars B * hidden
    :param skip_c: cell states obtained from the skip edges, B * X * hidden
    :param skip_count: number of skip edges at the current position for each example in the batch, used for masking
    :param hx:
    :return:
    '''
    max_skip_count = torch.max(skip_count).item()

    h_0, c_0 = hx
    batch_size = h_0.size(0)
    bias_batch = self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())
    wi = torch.matmul(inp, self.weight_ih)
    wh = torch.matmul(h_0, self.weight_hh)
    i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
    i = torch.sigmoid(i).unsqueeze(1)
    o = torch.sigmoid(o).unsqueeze(1)
    g = torch.tanh(g).unsqueeze(1)

    # basic lstm start
    f = 1 - i
    c_1_basic = f * c_0.unsqueeze(1) + i * g
    c_1_basic = c_1_basic.squeeze(1)

    alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
    alpha_wi.unsqueeze_(1)
    alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)
    alpha_bias_batch = self.alpha_bias.unsqueeze(0)
    alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)

    skip_mask = seq_len_to_mask(skip_count, max_len=skip_c.size()[1]).float()
    skip_mask = 1 - skip_mask
    skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)
    skip_mask = skip_mask.float() * 1e20
    alpha = alpha - skip_mask
    alpha = torch.exp(torch.cat([i, alpha], dim=1))
    alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
    alpha = torch.div(alpha, alpha_sum)

    merge_i_c = torch.cat([g, skip_c], dim=1)
    c_1 = merge_i_c * alpha
    c_1 = c_1.sum(1, keepdim=True)
    # h_1 = o * c_1
    c_1 = c_1.squeeze(1)

    count_select = (skip_count != 0).float().unsqueeze(-1)
    c_1 = c_1 * count_select + c_1_basic * (1 - count_select)

    o = o.squeeze(1)
    h_1 = o * torch.tanh(c_1)
    return h_1, c_1
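# The "subtract 1e20 before exp" step above is the standard way to mask entries out
# of a softmax-style normalization; a tiny standalone illustration (names here are
# illustrative and not part of the module above):
import torch

scores = torch.tensor([[2.0, 1.0, 3.0]])
valid = torch.tensor([[1.0, 1.0, 0.0]])          # third candidate is padding
masked_scores = scores - (1 - valid) * 1e20      # pushes padded scores toward -inf
weights = torch.softmax(masked_scores, dim=-1)   # padded entry gets ~0 weight
assert weights[0, 2].item() < 1e-6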