def forward(self, lattice: torch.Tensor, bigrams: torch.Tensor, seq_len: torch.Tensor,
            lex_num: torch.Tensor, pos_s: torch.Tensor, pos_e: torch.Tensor,
            target: Optional[torch.Tensor]):
    batch_size = lattice.size(0)
    max_seq_len_and_lex_num = lattice.size(1)
    max_seq_len = bigrams.size(1)

    raw_embed = self.lattice_embed(lattice)
    bigrams_embed = self.bigram_embed(bigrams)
    # Pad the bigram embeddings so they align with the full lattice length (chars + lexicon words).
    bigrams_embed = torch.cat([
        bigrams_embed,
        torch.zeros(size=[batch_size, max_seq_len_and_lex_num - max_seq_len,
                          self.bigram_size]).to(bigrams_embed)
    ], dim=1)
    raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)

    raw_embed_char = self.embed_dropout(raw_embed_char)
    raw_embed = self.gaz_dropout(raw_embed)

    embed_char = self.char_proj(raw_embed_char)
    char_mask = seq_len_to_mask(seq_len, max_len=max_seq_len_and_lex_num)
    embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

    embed_lex = self.lex_proj(raw_embed)
    lex_mask = (seq_len_to_mask(seq_len + lex_num) ^ char_mask)
    embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

    embedding = embed_char + embed_lex

    encoded = self.encoder(embedding, seq_len, lex_num=lex_num, pos_s=pos_s, pos_e=pos_e)
    encoded = self.output_dropout(encoded)

    # Keep only the char part of the transformer output.
    encoded = encoded[:, :max_seq_len, :]
    pred = self.output(encoded)

    mask = seq_len_to_mask(seq_len)
    # For scripted (exported) inference:
    # pred, path = self.crf.viterbi_decode(pred, mask)
    # return pred

    if self.training:
        loss = self.crf(pred, target, mask).mean(dim=0)
        return {'loss': loss}
    else:
        pred, path = self.crf.viterbi_decode(pred, mask)
        result = {'pred': pred}
        return result
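# The char/lex split above calls seq_len_to_mask twice: once over the first seq_len positions
# (the characters) and once over seq_len + lex_num, with the XOR isolating the lexicon-word
# slots. A minimal, self-contained sketch of that partition; the inline seq_len_to_mask below
# is a hypothetical stand-in for the project's helper (True on valid positions, False on padding).
import torch

def seq_len_to_mask(seq_len: torch.Tensor, max_len: int) -> torch.Tensor:
    return torch.arange(max_len).unsqueeze(0) < seq_len.unsqueeze(-1)

# Toy batch: 2 lattice sequences, each padded to 6 slots (characters first, lexicon words after).
seq_len = torch.tensor([3, 4])   # number of characters per sample
lex_num = torch.tensor([2, 1])   # number of matched lexicon words per sample

char_mask = seq_len_to_mask(seq_len, max_len=6)                        # True on char slots
lex_mask = seq_len_to_mask(seq_len + lex_num, max_len=6) ^ char_mask   # True on lexicon-word slots

print(char_mask.int())  # tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]])
print(lex_mask.int())   # tensor([[0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 1, 0]])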
def forward(self, x):
    word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
    mask = seq_len_to_mask(lens)
    inputs = self.embedding(word, head_pos, tail_pos)
    out, out_pool = self.cnn(inputs, mask=mask)

    if self.use_pcnn:
        out = out.unsqueeze(-1)                                        # [B, L, Hs, 1]
        pcnn_mask = x['pcnn_mask']
        pcnn_mask = self.pcnn_mask_embedding(pcnn_mask).unsqueeze(-2)  # [B, L, 1, 3]
        out = out + pcnn_mask                                          # [B, L, Hs, 3]
        out = out.max(dim=1)[0] - 100                                  # [B, Hs, 3]
        out_pool = out.view(out.size(0), -1)                           # [B, 3 * Hs]
        out_pool = F.leaky_relu(self.fc_pcnn(out_pool))                # [B, Hs]
        out_pool = self.dropout(out_pool)

    output = self.fc1(out_pool)
    output = F.leaky_relu(output)
    output = self.dropout(output)
    output = self.fc2(output)

    return output
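# Piecewise max pooling, a minimal self-contained sketch of the trick used in the PCNN branch
# above. Names and the constant 100 follow the snippet; the fixed one-hot * 100 mask table is an
# assumption about how pcnn_mask_embedding is initialized. Adding a large offset to exactly one
# of the three channels per segment makes the max over the length dimension a per-segment max.
import torch

B, L, Hs = 1, 5, 2
out = torch.randn(B, L, Hs, 1)                   # [B, L, Hs, 1] conv features
segment_id = torch.tensor([[1, 1, 2, 3, 0]])     # 1/2/3 = before head / between / after tail, 0 = padding

mask_table = torch.zeros(4, 3)                   # segment k gets +100 in channel k-1, padding gets none
mask_table[1:] = torch.eye(3) * 100
pcnn_mask = mask_table[segment_id].unsqueeze(-2)  # [B, L, 1, 3]

pooled = (out + pcnn_mask).max(dim=1)[0] - 100    # [B, Hs, 3]: max over positions, one column per piece
out_pool = pooled.view(B, -1)                     # [B, 3 * Hs]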
def forward(self, x):
    word, lens = x['word'], x['lens']
    mask = seq_len_to_mask(lens, mask_pos_to_true=False)
    last_hidden_state, pooler_output = self.bert(word, attention_mask=mask)
    out, out_pool = self.bilstm(last_hidden_state, lens)
    out_pool = self.dropout(out_pool)
    output = self.fc(out_pool)

    return output
def forward(self, x):
    word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
    mask = seq_len_to_mask(lens)
    inputs = self.embedding(word, head_pos, tail_pos)
    last_layer_hidden_state, all_hidden_states, all_attentions = self.transformer(
        inputs, key_padding_mask=mask)
    out_pool = last_layer_hidden_state.max(dim=1)[0]
    output = self.fc(out_pool)

    return output
def test_CNN():
    x = torch.randn(4, 5, 100)
    seq = torch.arange(4, 0, -1)
    mask = seq_len_to_mask(seq, max_len=5)

    cnn = CNN(config)
    out, out_pooling = cnn(x, mask=mask)
    out_channels = config.out_channels * len(config.kernel_sizes)

    assert out.shape == torch.Size([4, 5, out_channels])
    assert out_pooling.shape == torch.Size([4, out_channels])
def forward(self, x):
    word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
    mask = seq_len_to_mask(lens)
    inputs = self.embedding(word, head_pos, tail_pos)
    # The CNN changes the sequence length, so a position-wise mask can no longer be applied.
    # Skipping the mask is acceptable here, since the primary capsules only carry coarse-grained information.
    primary, _ = self.cnn(inputs)
    output = self.capsule(primary)
    output = output.norm(p=2, dim=-1)  # return the L2 norm (length) of each output capsule
    return output  # [B, N]
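# The capsule head above scores each relation by the length of its output capsule: a longer
# vector means higher confidence in that class. A toy illustration with made-up shapes
# (4 sentences, 10 relation classes, 16-dimensional capsules):
import torch

capsules = torch.randn(4, 10, 16)      # [B, N, D] output capsules, one per relation class
scores = capsules.norm(p=2, dim=-1)    # [B, N] capsule length acts as class confidence
pred = scores.argmax(dim=-1)           # [B] predicted relation index per sentence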
def test_Transformer():
    m = Transformer(config)
    i = torch.randn(4, 5, 12)  # [B, L, H]
    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
    attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions with 1 are masked out
    head_mask = torch.tensor([0, 1, 0])             # heads with 1 are masked out

    out = m(i,
            key_padding_mask=key_padding_mask,
            attention_mask=attention_mask,
            head_mask=head_mask)
    hn, h_all, att_weights = out

    assert hn.shape == torch.Size([4, 5, 12])
    assert torch.equal(h_all[0], i) and torch.equal(h_all[-1], hn)
    assert len(h_all) == config.num_hidden_layers + 1
    assert len(att_weights) == config.num_hidden_layers
    assert att_weights[0].shape == torch.Size([4, 3, 5, 5])
    assert not att_weights[0].unbind(dim=1)[1].bool().any()
import pytest
import torch

from utils import seq_len_to_mask
from module import DotAttention, MultiHeadAttention

torch.manual_seed(1)

q = torch.randn(4, 6, 20)      # [B, L, H]
k = v = torch.randn(4, 5, 20)  # [B, S, H]
key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions with 1 are masked out
head_mask = torch.tensor([0, 1, 0, 0])          # heads with 1 are masked out

# m = DotAttention(dropout=0.0)
# ao, aw = m(q, k, v, key_padding_mask)
# print(ao.shape, aw.shape)
# print(aw)


def test_DotAttention():
    m = DotAttention(dropout=0.0)
    ao, aw = m(q, k, v, mask_out=key_padding_mask)

    assert ao.shape == torch.Size([4, 6, 20])
    assert aw.shape == torch.Size([4, 6, 5])
    # Padded key positions must receive zero attention weight.
    assert torch.all(aw[1, :, -1:].eq(0))
    assert torch.all(aw[2, :, -2:].eq(0))
    assert torch.all(aw[3, :, -3:].eq(0))


def test_MultiHeadAttention():
    m = MultiHeadAttention(embed_dim=20, num_heads=4, dropout=0.0)
    ao, aw = m(q,
def forward(self, key, query, value, seq_len, lex_num, rel_pos_embedding):
    key = self.w_k(key)
    query = self.w_q(query)
    value = self.w_v(value)
    rel_pos_embedding = self.w_r(rel_pos_embedding)

    batch = key.size(0)
    max_seq_len = key.size(1)

    # batch * seq_len * n_head * d_head
    key = torch.reshape(key, [batch, max_seq_len, self.num_heads, self.per_head_size])
    query = torch.reshape(query, [batch, max_seq_len, self.num_heads, self.per_head_size])
    value = torch.reshape(value, [batch, max_seq_len, self.num_heads, self.per_head_size])

    # batch * seq_len * seq_len * n_head * d_head
    rel_pos_embedding = torch.reshape(
        rel_pos_embedding,
        [batch, max_seq_len, max_seq_len, self.num_heads, self.per_head_size])

    # batch * n_head * seq_len * d_head
    key = key.transpose(1, 2)
    query = query.transpose(1, 2)
    value = value.transpose(1, 2)

    # batch * n_head * d_head * seq_len
    key = key.transpose(-1, -2)

    # u_for_c: 1(batch broadcast) * n_head * 1(seq_len) * d_head
    # u_for_c = self.u.unsqueeze(0).unsqueeze(-2)
    # query_and_u_for_c = query + u_for_c
    query_and_u_for_c = query
    # query_and_u_for_c: batch * n_head * seq_len * d_head
    # key: batch * n_head * d_head * seq_len
    A_C = torch.matmul(query_and_u_for_c, key)
    # after above, A_C: batch * n_head * seq_len * seq_len

    # rel_pos_embedding_for_b: batch * num_head * query_len * per_head_size * key_len
    rel_pos_embedding_for_b = rel_pos_embedding.permute(0, 3, 1, 4, 2)
    query_for_b = query.view([batch, self.num_heads, max_seq_len, 1, self.per_head_size])
    # query_for_b_and_v_for_d = query_for_b + self.v.view(1, self.num_heads, 1, 1, self.per_head_size)
    query_for_b_and_v_for_d = query_for_b
    # after above, query_for_b_and_v_for_d: batch * num_head * seq_len * 1 * d_head
    B_D = torch.matmul(query_for_b_and_v_for_d, rel_pos_embedding_for_b).squeeze(-2)
    # after above, B_D: batch * n_head * seq_len * key_len

    attn_score_raw = A_C + B_D

    # The transformer output is truncated later, keeping only the char part.
    mask = seq_len_to_mask(seq_len + lex_num).unsqueeze(1).unsqueeze(1)
    # mask = seq_len_to_mask(seq_len + lex_num).bool().unsqueeze(1).unsqueeze(1)
    attn_score_raw_masked = attn_score_raw.masked_fill(~mask, -1e15)

    attn_score = F.softmax(attn_score_raw_masked, dim=-1)
    attn_score = self.dropout(attn_score)
    # attn_score: batch * n_head * seq_len * key_len
    # value: batch * n_head * seq_len * d_head
    value_weighted_sum = torch.matmul(attn_score, value)
    # after above, value_weighted_sum: batch * n_head * seq_len * d_head

    result = value_weighted_sum.transpose(1, 2).contiguous().reshape(
        batch, max_seq_len, self.hidden_size)
    # after above, result: batch * seq_len * hidden_size (hidden_size = n_head * d_head)

    return result
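# Shape-level sketch of the attention score built above: a content term A_C = Q K^T plus a
# relative-position term B_D, where each query also attends to a per-pair relative position
# embedding R[i, j]. Toy sizes only, and the u/v bias terms stay dropped, as in the snippet.
import torch

batch, n_head, seq_len, d_head = 2, 4, 6, 8

query = torch.randn(batch, n_head, seq_len, d_head)
key = torch.randn(batch, n_head, d_head, seq_len)               # already transposed for matmul
rel_pos = torch.randn(batch, seq_len, seq_len, n_head, d_head)  # R[i, j] per head

A_C = torch.matmul(query, key)                                  # [batch, n_head, seq_len, seq_len]

rel_pos_for_b = rel_pos.permute(0, 3, 1, 4, 2)                  # [batch, n_head, q_len, d_head, k_len]
query_for_b = query.unsqueeze(-2)                               # [batch, n_head, q_len, 1, d_head]
B_D = torch.matmul(query_for_b, rel_pos_for_b).squeeze(-2)      # [batch, n_head, q_len, k_len]

attn_score_raw = A_C + B_D                                      # same shape as A_C
assert attn_score_raw.shape == (batch, n_head, seq_len, seq_len)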