예제 #1
0
    def create_mask(self, bth, inputs):
        """Build the causal (subsequent) mask for ``bth``.

        The sequence length is taken from dimension 1 of ``bth`` and the mask
        is cast to ``bth``'s type.  When ``self.mask_pad`` is truthy the causal
        mask is additionally multiplied by ``self._pad_mask(inputs)``;
        otherwise the causal mask is returned unchanged.

        :param bth: tensor whose dim 1 is the time axis (presumably
            ``[B, T, H]`` given the name — confirm against callers).
        :param inputs: passed to ``self._pad_mask`` when padding is masked.
        :return: the causal mask, optionally combined with the padding mask.
        """
        causal = subsequent_mask(bth.shape[1]).type_as(bth)
        if self.mask_pad:
            # Elementwise product keeps a position only if both masks allow it.
            return causal * self._pad_mask(inputs)
        return causal
예제 #2
0
 def forward(self, encoder_output, dst):
     """Run the decoder over ``dst`` against ``encoder_output``.

     :param encoder_output: object exposing ``.output`` (the encoder context
         passed to the decoder) and ``.src_mask`` (source mask; presumably
         ``[B, T_k]`` — confirm against the encoder).
     :param dst: target input fed to ``self.tgt_embeddings``.
     :return: ``self.output`` applied to the projected decoder states.
     """
     embed_out_bth = self.tgt_embeddings(dst)
     embed_out_bth = self.proj_to_hsz(embed_out_bth)
     context_bth = encoder_output.output
     T = embed_out_bth.shape[1]
     # subsequent_mask(T) is [1, 1, T_q, T_q] (not per-batch); it has to
     # broadcast across the batch and head dims inside the decoder.
     dst_mask = subsequent_mask(T).type_as(embed_out_bth)
     # Two singleton axes turn a [B, T_k] mask into [B, 1, 1, T_k] for broadcasting.
     src_mask = encoder_output.src_mask.unsqueeze(1).unsqueeze(1)
     output = self.transformer_decoder((embed_out_bth, context_bth, src_mask, dst_mask))
     output = self.proj_to_dsz(output)
     prob = self.output(output)
     return prob
예제 #3
0
def test_subsequent_mask_valid_loc():
    """Every mask cell is 0 strictly above the diagonal and 1 on/below it.

    The whole ``[T, T]`` matrix is compared against ``np.tril`` instead of
    sampling 100 random cells.  A single wrong entry therefore cannot slip
    through, and the outcome does not depend on which cells the RNG picked.
    """
    T = np.random.randint(4, 100)
    # subsequent_mask returns shape [1, 1, T, T]; squeeze down to [T, T].
    mask = subsequent_mask(T).numpy().squeeze()
    # np.tril keeps entries with column <= row: exactly the i >= j -> 1 rule.
    gold = np.tril(np.ones((T, T)))
    np.testing.assert_array_equal(mask, gold)
예제 #4
0
def attn_values_sub_mask(attn, qkv):
    """Check that causally masked attention with an all-zero query averages values.

    With a zero query every query-key score is identical, so position ``t``
    should weight positions ``0..t`` uniformly.  The output at ``t`` is then
    the running mean of ``v[:, :, :t + 1, :]``.  This assumes ``attn``
    normalizes scores with a softmax that excludes masked positions — confirm
    against the attention implementation.

    :param attn: attention callable taking a ``(q, k, v, mask)`` tuple.
    :param qkv: ``(q, k, v)`` tensors; ``q`` is unpacked as rank 4 with the
        time axis at dim 2.  The caller's tensors are not modified.
    :raises AssertionError: if any output differs from the expected running
        mean by more than ``atol=1e-5``.
    """
    q, k, v = qkv
    _, _, T, _ = q.shape
    # Use a fresh zero tensor rather than ``q.zero_()``.  Zeroing in place
    # clobbers the caller's query.  If q aliases k/v, it would also zero the
    # values and make this check pass trivially.
    q = q.new_zeros(q.shape)
    mask = subsequent_mask(T)
    res = attn((q, k, v, mask)).numpy()
    gold = v.numpy()
    # Running mean over the time axis: cumulative sum / number of steps so far.
    # Accumulate in float64 so the reference is not less precise than the output.
    counts = np.arange(1, T + 1).reshape(1, 1, T, 1)
    expected = np.cumsum(gold, axis=2, dtype=np.float64) / counts
    np.testing.assert_allclose(res, expected, atol=1e-5)
예제 #5
0
 def create_mask(self, bth):
     """Return a causal mask sized to dim 1 of ``bth`` and cast to its type.

     No padding mask is applied in this variant.
     """
     return subsequent_mask(bth.shape[1]).type_as(bth)
예제 #6
0
def test_subsequent_mask_valid_count():
    """The mask entries sum to the triangular number T * (T + 1) / 2."""
    seq_len = np.random.randint(4, 50)
    total = np.sum(subsequent_mask(seq_len).numpy())
    expected = (seq_len * (seq_len + 1)) / 2
    assert total == expected
예제 #7
0
def test_subsequent_mask_shape():
    """subsequent_mask(T) has two leading singleton dims: shape (1, 1, T, T)."""
    size = np.random.randint(2, 50)
    assert subsequent_mask(size).shape == (1, 1, size, size)