def create_mask(self, bth, inputs):
    """Build the attention mask used with `bth`.

    The base mask is `subsequent_mask(n)` where `n = bth.shape[1]`, cast to
    the type of `bth`. When `self.mask_pad` is truthy the base mask is
    multiplied by `self._pad_mask(inputs)`; otherwise it is returned as is.

    :param bth: Tensor whose second dimension sets the mask size
        (presumably `[B, T, H]` judging by the name -- confirm with callers).
    :param inputs: Passed straight to `self._pad_mask`; only used when
        `self.mask_pad` is truthy.
    :return: The mask, optionally combined with the pad mask.
    """
    seq_len = bth.shape[1]
    causal = subsequent_mask(seq_len).type_as(bth)
    if self.mask_pad:
        return causal * self._pad_mask(inputs)
    return causal
def forward(self, encoder_output, dst):
    """Run the decoder over `dst` given the encoder result.

    `dst` is embedded with `self.tgt_embeddings` and projected with
    `self.proj_to_hsz`; the result is passed to `self.transformer_decoder`
    together with `encoder_output.output` and the two masks, then projected
    with `self.proj_to_dsz` and fed to `self.output`.

    :param encoder_output: Object exposing `.output` and `.src_mask`.
    :param dst: Target-side input handed to `self.tgt_embeddings`.
    :return: Whatever `self.output` produces for the decoded sequence.
    """
    tgt_bth = self.proj_to_hsz(self.tgt_embeddings(dst))
    memory_bth = encoder_output.output
    n_steps = tgt_bth.shape[1]
    # Mask over target positions; shape note carried over from the
    # original comment: [B, 1, T_q, T_q]
    tgt_mask = subsequent_mask(n_steps).type_as(tgt_bth)
    # Two singleton dims are inserted at position 1; shape note carried
    # over from the original comment: [B, 1, 1, T_k]
    memory_mask = encoder_output.src_mask.unsqueeze(1).unsqueeze(1)
    decoded = self.transformer_decoder((tgt_bth, memory_bth, memory_mask, tgt_mask))
    return self.output(self.proj_to_dsz(decoded))
def test_subsequent_mask_valid_loc():
    """Spot-check that `subsequent_mask` is 1 on/below the diagonal, else 0.

    A random size in [4, 100) is drawn, then 100 random (row, col) cells of
    the squeezed mask are compared against the expected value.
    """
    T = np.random.randint(4, 100)
    mask = subsequent_mask(T).numpy().squeeze()

    def check_random_cell(size, grid):
        # One draw yields both coordinates, each in [0, size).
        row, col = np.random.randint(0, size, size=2)
        # Cells strictly above the diagonal must be 0; all others 1.
        expected = 0 if row < col else 1
        assert grid[row, col] == expected

    for _ in range(100):
        check_random_cell(T, mask)
def attn_values_sub_mask(attn, qkv):
    """Check `attn` output against prefix means of `v` under a subsequent mask.

    The query is zeroed in place, `attn` is called with
    `(q, k, v, subsequent_mask(T))`, and for every position `t` the output
    slice `res[:, :, t, :]` must match the mean of `v[:, :, :t + 1, :]`
    along axis 2 within `atol=1e-5`. (Presumably a zero query gives uniform
    attention weights over the unmasked prefix -- that is what the expected
    value encodes.)

    :param attn: Callable taking a single `(q, k, v, mask)` tuple and
        returning a tensor convertible with `.numpy()`.
    :param qkv: Triple `(q, k, v)` of rank-4 tensors. NOTE: `q` is zeroed
        in place, so the caller's tensor is modified.
    :raises AssertionError: If any output slice differs from the expected
        prefix mean beyond the tolerance.
    """
    q, k, v = qkv
    # Four-way unpacking doubles as a rank-4 check; only dim 2 is needed.
    _, _, T, _ = q.shape
    # In-place zeroing, kept as in the original (mutates the input tensor).
    q = q.zero_()
    mask = subsequent_mask(T)
    res = attn((q, k, v, mask)).numpy()
    gold = v.numpy()
    for t in range(T):
        # The prefix mean does not depend on batch/head indices, so compute
        # it once per position and compare the whole slice at once instead
        # of recomputing it for every (b, h) pair. assert_allclose applies
        # the same tolerance test elementwise, so pass/fail is unchanged.
        expected = np.mean(gold[:, :, :t + 1, :], axis=2)
        np.testing.assert_allclose(res[:, :, t, :], expected, atol=1e-5)
def create_mask(self, bth):
    """Return `subsequent_mask(bth.shape[1])` cast to the type of `bth`.

    :param bth: Tensor whose second dimension sets the mask size
        (presumably `[B, T, H]` judging by the name -- confirm with callers).
    :return: The subsequent mask, converted with `type_as(bth)`.
    """
    seq_len = bth.shape[1]
    return subsequent_mask(seq_len).type_as(bth)
def test_subsequent_mask_valid_count():
    """Check that the mask entries sum to T * (T + 1) / 2.

    T is drawn at random from [4, 50).
    """
    seq_len = np.random.randint(4, 50)
    mask = subsequent_mask(seq_len).numpy()
    # Triangular number: 1 + 2 + ... + seq_len.
    expected_total = (seq_len * (seq_len + 1)) / 2
    assert np.sum(mask) == expected_total
def test_subsequent_mask_shape():
    """Check that `subsequent_mask(T)` has shape `(1, 1, T, T)`.

    T is drawn at random from [2, 50).
    """
    seq_len = np.random.randint(2, 50)
    assert subsequent_mask(seq_len).shape == (1, 1, seq_len, seq_len)