Example No. 1
    def forward(self, lattice: torch.Tensor, bigrams: torch.Tensor,
                seq_len: torch.Tensor, lex_num: torch.Tensor,
                pos_s: torch.Tensor, pos_e: torch.Tensor,
                target: Optional[torch.Tensor]):
        batch_size = lattice.size(0)
        max_seq_len_and_lex_num = lattice.size(1)
        max_seq_len = bigrams.size(1)

        raw_embed = self.lattice_embed(lattice)
        bigrams_embed = self.bigram_embed(bigrams)
        # zero-pad the bigram embeddings so they also cover the lexicon slots of the lattice
        bigram_pad = torch.zeros(batch_size,
                                 max_seq_len_and_lex_num - max_seq_len,
                                 self.bigram_size).to(bigrams_embed)
        bigrams_embed = torch.cat([bigrams_embed, bigram_pad], dim=1)
        raw_embed_char = torch.cat([raw_embed, bigrams_embed], dim=-1)

        raw_embed_char = self.embed_dropout(raw_embed_char)
        raw_embed = self.gaz_dropout(raw_embed)

        embed_char = self.char_proj(raw_embed_char)
        char_mask = seq_len_to_mask(seq_len, max_len=max_seq_len_and_lex_num)
        embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

        embed_lex = self.lex_proj(raw_embed)
        lex_mask = (seq_len_to_mask(seq_len + lex_num) ^ char_mask)
        embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

        embedding = embed_char + embed_lex
        encoded = self.encoder(embedding,
                               seq_len,
                               lex_num=lex_num,
                               pos_s=pos_s,
                               pos_e=pos_e)
        encoded = self.output_dropout(encoded)

        # keep only the char part of the transformer output
        encoded = encoded[:, :max_seq_len, :]
        pred = self.output(encoded)
        mask = seq_len_to_mask(seq_len)

        # used when scripting the model:
        # pred, path = self.crf.viterbi_decode(pred, mask)
        # return pred

        if self.training:
            loss = self.crf(pred, target, mask).mean(dim=0)
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            result = {'pred': pred}
            return result
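In Example No. 1 the lattice axis holds the characters followed by the matched lexicon words, and char_mask / lex_mask carve that combined axis back apart. Below is a minimal, self-contained sketch of that split; it assumes a fastNLP-style seq_len_to_mask that marks valid positions as True, and the helper defined here is a hypothetical stand-in rather than the repo's implementation.

import torch

def seq_len_to_mask(seq_len, max_len=None):
    # hypothetical stand-in: True for valid positions, False for padding
    max_len = int(max_len or seq_len.max())
    return torch.arange(max_len)[None, :] < seq_len[:, None]

# two sentences: 3 and 2 characters, followed by 2 and 1 matched lexicon words
seq_len = torch.tensor([3, 2])
lex_num = torch.tensor([2, 1])
max_total = int((seq_len + lex_num).max())  # lattice width = characters + lexicon slots

char_mask = seq_len_to_mask(seq_len, max_len=max_total)
lex_mask = seq_len_to_mask(seq_len + lex_num, max_len=max_total) ^ char_mask

print(char_mask.int())  # tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])
print(lex_mask.int())   # tensor([[0, 0, 0, 1, 1], [0, 0, 1, 0, 0]])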
Example No. 2
    def forward(self, x):
        word, lens = x['word'], x['lens']
        head_pos, tail_pos = x['head_pos'], x['tail_pos']
        mask = seq_len_to_mask(lens)

        inputs = self.embedding(word, head_pos, tail_pos)
        out, out_pool = self.cnn(inputs, mask=mask)

        if self.use_pcnn:
            out = out.unsqueeze(-1)  # [B, L, Hs, 1]
            pcnn_mask = x['pcnn_mask']
            pcnn_mask = self.pcnn_mask_embedding(pcnn_mask).unsqueeze(-2)  # [B, L, 1, 3]
            out = out + pcnn_mask  # [B, L, Hs, 3]
            out = out.max(dim=1)[0] - 100  # [B, Hs, 3]
            out_pool = out.view(out.size(0), -1)  # [B, 3 * Hs]
            out_pool = F.leaky_relu(self.fc_pcnn(out_pool))  # [B, Hs]
            out_pool = self.dropout(out_pool)

        output = self.fc1(out_pool)
        output = F.leaky_relu(output)
        output = self.dropout(output)
        output = self.fc2(output)

        return output
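With use_pcnn enabled, the block above performs piecewise (PCNN) max pooling: a fixed mask embedding pushes each position toward the column of the piece it belongs to, the max is taken over the sequence dimension, and the constant is subtracted again. The sketch below reproduces the trick with toy tensors; the +100 per-piece offsets are an assumption inferred from the "- 100" in the code, not the repo's actual pcnn_mask_embedding weights.

import torch

B, L, H = 1, 6, 4
out = torch.randn(B, L, H)                   # convolution output, [B, L, H]
piece = torch.tensor([[1, 1, 2, 2, 3, 0]])   # piece id per position; 0 = padding (assumed layout)

# hypothetical mask embedding: +100 in the column of the piece each position belongs to
piece_offsets = torch.tensor([[0., 0., 0.],      # padding contributes to no piece
                              [100., 0., 0.],
                              [0., 100., 0.],
                              [0., 0., 100.]])
offset = piece_offsets[piece].unsqueeze(-2)      # [B, L, 1, 3]

# max over the sequence picks, per column, the largest in-piece activation (+100); removing the
# offset afterwards leaves the per-piece max (assuming activations stay well below 100)
pooled = (out.unsqueeze(-1) + offset).max(dim=1)[0] - 100   # [B, H, 3]
pooled_flat = pooled.view(B, -1)                 # [B, 3 * H], what the model feeds to fc_pcnn
print(pooled_flat.shape)                         # torch.Size([1, 12])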
Example No. 3
    def forward(self, x):
        word, lens = x['word'], x['lens']
        mask = seq_len_to_mask(lens, mask_pos_to_true=False)
        last_hidden_state, pooler_output = self.bert(word, attention_mask=mask)
        out, out_pool = self.bilstm(last_hidden_state, lens)
        out_pool = self.dropout(out_pool)
        output = self.fc(out_pool)

        return output
Example No. 4
    def forward(self, x):
        word, lens = x['word'], x['lens']
        head_pos, tail_pos = x['head_pos'], x['tail_pos']
        mask = seq_len_to_mask(lens)
        inputs = self.embedding(word, head_pos, tail_pos)
        last_layer_hidden_state, all_hidden_states, all_attentions = self.transformer(
            inputs, key_padding_mask=mask)
        out_pool = last_layer_hidden_state.max(dim=1)[0]
        output = self.fc(out_pool)

        return output
Example No. 5
def test_CNN():

    x = torch.randn(4, 5, 100)
    seq = torch.arange(4, 0, -1)
    mask = seq_len_to_mask(seq, max_len=5)

    cnn = CNN(config)
    out, out_pooling = cnn(x, mask=mask)
    out_channels = config.out_channels * len(config.kernel_sizes)
    assert out.shape == torch.Size([4, 5, out_channels])
    assert out_pooling.shape == torch.Size([4, out_channels])
Example No. 6
    def forward(self, x):
        word, lens = x['word'], x['lens']
        head_pos, tail_pos = x['head_pos'], x['tail_pos']
        mask = seq_len_to_mask(lens)
        inputs = self.embedding(word, head_pos, tail_pos)

        # The CNN changes the sequence length, so a position-wise mask cannot be applied here;
        # leaving it unmasked is acceptable, since the primary capsules only carry coarse-grained information.
        primary, _ = self.cnn(inputs)
        output = self.capsule(primary)
        output = output.norm(p=2, dim=-1)  # take the L2 norm of each capsule vector and return it

        return output  # [B, N]
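The capsule head returns one vector per relation class and uses each vector's L2 norm directly as that class's score, so no extra classifier or softmax follows. A minimal sketch with hypothetical shapes:

import torch

B, N, D = 2, 5, 16                   # batch, number of relation classes, capsule dimension
capsules = torch.randn(B, N, D)      # stand-in for the capsule layer output

scores = capsules.norm(p=2, dim=-1)  # vector length per class, [B, N]
pred = scores.argmax(dim=-1)         # the longest capsule vector gives the predicted relation
print(scores.shape, pred.shape)      # torch.Size([2, 5]) torch.Size([2])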
Example No. 7
def test_Transformer():
    m = Transformer(config)
    i = torch.randn(4, 5, 12)  # [B, L, H]
    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
    attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions set to 1 are masked out
    head_mask = torch.tensor([0, 1, 0])  # heads set to 1 are masked out

    out = m(i,
            key_padding_mask=key_padding_mask,
            attention_mask=attention_mask,
            head_mask=head_mask)
    hn, h_all, att_weights = out
    assert hn.shape == torch.Size([4, 5, 12])
    assert torch.equal(h_all[0], i) and torch.equal(h_all[-1], hn)
    assert len(h_all) == config.num_hidden_layers + 1
    assert len(att_weights) == config.num_hidden_layers
    assert att_weights[0].shape == torch.Size([4, 3, 5, 5])
    # the masked head (index 1) produces all-zero attention weights
    assert not att_weights[0].unbind(dim=1)[1].bool().any()
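The test above mixes three kinds of masks. The sketch below spells out the convention it appears to rely on (1/True means the position is masked out); the seq_len_to_mask shown is a hypothetical stand-in, consistent with the mask_pos_to_true=False call in Example No. 3, not the repo's implementation.

import torch

def seq_len_to_mask(seq_len, max_len=None, mask_pos_to_true=True):
    # hypothetical stand-in for illustration only
    seq_len = torch.as_tensor(seq_len)
    max_len = int(max_len or seq_len.max())
    valid = torch.arange(max_len)[None, :] < seq_len[:, None]
    return ~valid if mask_pos_to_true else valid

print(seq_len_to_mask([5, 4, 3, 2], max_len=5).int())
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 1, 1, 1]])  -> padding positions are flagged and excluded from attention

attention_mask = torch.tensor([1, 0, 0, 1, 0])  # blocks key positions 0 and 3 for every sample
head_mask = torch.tensor([0, 1, 0])             # disables head 1, matching the zero-weight assert above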
Example No. 8
import pytest
import torch
from utils import seq_len_to_mask
from module import DotAttention, MultiHeadAttention

torch.manual_seed(1)
q = torch.randn(4, 6, 20)  # [B, L, H]
k = v = torch.randn(4, 5, 20)  # [B, S, H]
key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions set to 1 are masked out
head_mask = torch.tensor([0, 1, 0, 0])  # heads set to 1 are masked out

# m = DotAttention(dropout=0.0)
# ao,aw = m(q,k,v,key_padding_mask)
# print(ao.shape,aw.shape)
# print(aw)


def test_DotAttention():
    m = DotAttention(dropout=0.0)
    ao, aw = m(q, k, v, mask_out=key_padding_mask)

    assert ao.shape == torch.Size([4, 6, 20])
    assert aw.shape == torch.Size([4, 6, 5])
    # padded key positions receive zero attention weight
    assert torch.all(aw[1, :, -1:].eq(0))
    assert torch.all(aw[2, :, -2:].eq(0))
    assert torch.all(aw[3, :, -3:].eq(0))


def test_MultiHeadAttention():
    m = MultiHeadAttention(embed_dim=20, num_heads=4, dropout=0.0)
    ao, aw = m(q,
               k,
               v,
               key_padding_mask=key_padding_mask,
               attention_mask=attention_mask,
               head_mask=head_mask)
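For reference, the zero-weight asserts in test_DotAttention follow from ordinary masked scaled dot-product attention: padded key positions are filled with -inf before the softmax, so their weights come out exactly zero. A standalone sketch (illustrative only; the repo's DotAttention may differ in scaling or in how it consumes the mask):

import torch
import torch.nn.functional as F

q = torch.randn(4, 6, 20)                        # [B, L, H]
k = v = torch.randn(4, 5, 20)                    # [B, S, H]
key_padding_mask = torch.tensor([[0, 0, 0, 0, 0],
                                 [0, 0, 0, 0, 1],
                                 [0, 0, 0, 1, 1],
                                 [0, 0, 1, 1, 1]]).bool()  # True = padded key position

scores = q @ k.transpose(-1, -2) / 20 ** 0.5     # [B, L, S]
scores = scores.masked_fill(key_padding_mask[:, None, :], float('-inf'))
weights = F.softmax(scores, dim=-1)              # padded keys get exactly zero weight
out = weights @ v                                # [B, L, H]

assert torch.all(weights[1, :, -1:].eq(0)) and torch.all(weights[3, :, -3:].eq(0))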
Example No. 9
    def forward(self, key, query, value, seq_len, lex_num, rel_pos_embedding):
        batch = key.size(0)

        key = self.w_k(key)
        query = self.w_q(query)
        value = self.w_v(value)
        rel_pos_embedding = self.w_r(rel_pos_embedding)

        max_seq_len = key.size(1)

        # batch * seq_len * n_head * d_head
        key = torch.reshape(
            key, [batch, max_seq_len, self.num_heads, self.per_head_size])
        query = torch.reshape(
            query, [batch, max_seq_len, self.num_heads, self.per_head_size])
        value = torch.reshape(
            value, [batch, max_seq_len, self.num_heads, self.per_head_size])
        # batch * seq_len * seq_len * n_head * d_head
        rel_pos_embedding = torch.reshape(rel_pos_embedding, [
            batch, max_seq_len, max_seq_len, self.num_heads, self.per_head_size
        ])

        # batch * n_head * seq_len * d_head
        key = key.transpose(1, 2)
        query = query.transpose(1, 2)
        value = value.transpose(1, 2)

        # batch * n_head * d_head * seq_len
        key = key.transpose(-1, -2)

        # u_for_c: 1(batch broadcast) * n_head * 1(seq_len) * d_head
        # u_for_c = self.u.unsqueeze(0).unsqueeze(-2)

        # query_and_u_for_c = query + u_for_c
        query_and_u_for_c = query
        # query_and_u_for_c: batch * n_head * seq_len * d_head
        # key: batch * n_head * d_head * seq_len
        A_C = torch.matmul(query_and_u_for_c, key)
        # after above, A_C: batch * n_head * seq_len * seq_len

        # rel_pos_embedding_for_b: batch * num_head * query_len * per_head_size * key_len
        rel_pos_embedding_for_b = rel_pos_embedding.permute(0, 3, 1, 4, 2)

        query_for_b = query.view(
            [batch, self.num_heads, max_seq_len, 1, self.per_head_size])
        # query_for_b_and_v_for_d = query_for_b + self.v.view(1, self.num_heads, 1, 1, self.per_head_size)
        query_for_b_and_v_for_d = query_for_b
        # after above, query_for_b_and_v_for_d: batch * num_head * seq_len * 1 * d_head

        B_D = torch.matmul(query_for_b_and_v_for_d,
                           rel_pos_embedding_for_b).squeeze(-2)
        # after above, B_D: batch * n_head * seq_len * key_len
        attn_score_raw = A_C + B_D

        # the transformer output is truncated later, keeping only the char positions
        mask = seq_len_to_mask(seq_len + lex_num).unsqueeze(1).unsqueeze(1)
        # mask = seq_len_to_mask(seq_len + lex_num).bool().unsqueeze(1).unsqueeze(1)
        attn_score_raw_masked = attn_score_raw.masked_fill(~mask, -1e15)
        attn_score = F.softmax(attn_score_raw_masked, dim=-1)
        attn_score = self.dropout(attn_score)
        # attn_score: batch * n_head * seq_len * key_len
        # value: batch * n_head * seq_len * d_head
        value_weighted_sum = torch.matmul(attn_score, value)
        # after above, value_weighted_sum: batch * n_head * seq_len * d_head

        result = value_weighted_sum.transpose(1, 2).contiguous().reshape(
            batch, max_seq_len, self.hidden_size)
        # after above, result: batch * seq_len * hidden_size (hidden_size=n_head * d_head)

        return result
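The A_C and B_D terms above are the content-content and content-position parts of a Transformer-XL-style attention score; the u and v bias terms of that formulation are commented out here. The sketch below walks through the shapes with toy sizes and checks the B_D matmul against an equivalent einsum, using standalone tensors rather than the module's weights.

import torch

B, L, n_head, d_head = 2, 7, 4, 8

query = torch.randn(B, n_head, L, d_head)              # after the reshape/transpose above
key = torch.randn(B, n_head, d_head, L)                # already transposed, as in the code
rel_pos_embedding = torch.randn(B, L, L, n_head, d_head)

# A_C: plain content-content scores, batch * n_head * seq_len * seq_len
A_C = torch.matmul(query, key)

# B_D: each query dotted with a per-(query, key) relative position embedding
rel_for_b = rel_pos_embedding.permute(0, 3, 1, 4, 2)   # [B, n_head, L(query), d_head, L(key)]
query_for_b = query.view(B, n_head, L, 1, d_head)
B_D = torch.matmul(query_for_b, rel_for_b).squeeze(-2)

# the same contraction written as an einsum over query position i and key position j
B_D_einsum = torch.einsum('bhid,bijhd->bhij', query, rel_pos_embedding)
assert torch.allclose(B_D, B_D_einsum, atol=1e-5)

attn_score_raw = A_C + B_D                             # [B, n_head, L, L]
print(attn_score_raw.shape)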