def _model_mask(self, nctx):
    """Return the attention mask matching this model's training objective.

    A causal LM gets a subsequent mask (each position may see itself and
    earlier positions only); a masked LM (``self.mlm`` truthy) gets an
    all-ones mask, i.e. nothing is masked out.

    Args:
        nctx: context length; the mask covers ``nctx`` x ``nctx`` positions.

    Returns:
        For a masked LM, a ``torch.long`` tensor of ones shaped
        ``(1, 1, nctx, nctx)``; otherwise whatever ``subsequent_mask(nctx)``
        returns.
    """
    if not self.mlm:
        return subsequent_mask(nctx)
    full_shape = (1, 1, nctx, nctx)
    return torch.ones(full_shape, dtype=torch.long)
def forward(self, encoder_output, dst):
    """Decode target input ``dst`` while attending to ``encoder_output``.

    Args:
        encoder_output: object exposing ``.output`` (used as the decoder's
            context) and ``.src_mask``.
        dst: target input fed to ``self.tgt_embeddings``.

    Returns:
        ``self.output`` applied to the decoder states after ``proj_to_dsz``.
    """
    tgt_hidden = self.proj_to_hsz(self.tgt_embeddings(dst))
    memory = encoder_output.output
    # Dim 1 of the embedded target is taken as the target length.
    tgt_len = tgt_hidden.shape[1]
    # Subsequent mask: target position i may only look at positions <= i.
    causal_mask = subsequent_mask(tgt_len).type_as(tgt_hidden)
    # Two singleton axes right after the batch axis; presumably so the source
    # mask broadcasts against the attention scores — TODO confirm.
    memory_mask = encoder_output.src_mask.unsqueeze(1).unsqueeze(1)
    decoded = self.transformer_decoder(tgt_hidden, memory, memory_mask, causal_mask)
    return self.output(self.proj_to_dsz(decoded))
def forward(self, encoder_output, dst):
    """Decode target input ``dst`` while attending to ``encoder_output``.

    Args:
        encoder_output: object exposing ``.output`` (used as the decoder's
            context) and ``.src_mask``.
        dst: target input fed to ``self.tgt_embeddings``.

    Returns:
        ``self.output`` applied to the decoder states after ``proj_to_dsz``.
    """
    tgt_hidden = self.proj_to_hsz(self.tgt_embeddings(dst))
    memory = encoder_output.output
    # Dim 1 of the embedded target is taken as the target length.
    tgt_len = tgt_hidden.shape[1]
    # Subsequent mask: target position i may only look at positions <= i.
    causal_mask = subsequent_mask(tgt_len).type_as(tgt_hidden)
    # Two singleton axes right after the batch axis; presumably so the source
    # mask broadcasts against the attention scores — TODO confirm.
    memory_mask = encoder_output.src_mask.unsqueeze(1).unsqueeze(1)
    decoded = self.transformer_decoder(tgt_hidden, memory, memory_mask, causal_mask)
    return self.output(self.proj_to_dsz(decoded))
def test_subsequent_mask_valid_loc():
    """Every mask cell is 0 strictly above the diagonal and 1 on or below it.

    The mask for each size is compared exhaustively against a lower-triangular
    matrix of ones, over a fixed set of sizes.
    """
    # Fixed sizes and a full comparison make the test deterministic. Randomly
    # probing a few (i, j) pairs for one random T could pass despite wrong cells.
    for T in (1, 2, 4, 17, 99):
        # reshape rather than squeeze so T == 1 still yields a 2-D array.
        mask = subsequent_mask(T).numpy().reshape(T, T)
        expected = np.tril(np.ones((T, T), dtype=mask.dtype))
        np.testing.assert_array_equal(mask, expected)
def test_subsequent_mask_valid_loc():
    """Every mask cell is 0 strictly above the diagonal and 1 on or below it.

    The mask for each size is compared exhaustively against a lower-triangular
    matrix of ones, over a fixed set of sizes.
    """
    # Fixed sizes and a full comparison make the test deterministic. Randomly
    # probing a few (i, j) pairs for one random T could pass despite wrong cells.
    for T in (1, 2, 4, 17, 99):
        # reshape rather than squeeze so T == 1 still yields a 2-D array.
        mask = subsequent_mask(T).numpy().reshape(T, T)
        expected = np.tril(np.ones((T, T), dtype=mask.dtype))
        np.testing.assert_array_equal(mask, expected)
def attn_values_sub_mask(attn, qkv):
    """Check that ``attn`` with a zero query and a subsequent mask averages values.

    With an all-zero ``q`` every unmasked score is identical. So, assuming
    ``attn`` normalizes (e.g. softmax) over the unmasked keys — TODO confirm —
    output position ``t`` must equal the mean of ``v[..., :t + 1, :]``.

    Args:
        attn: callable ``attn(q, k, v, mask=...)`` returning ``(output, _)``.
        qkv: ``(q, k, v)`` tensors; ``q`` is unpacked as 4-D ``(B, H, T, D)``.
    """
    q, k, v = qkv
    _, _, T, _ = q.shape
    # Build a fresh zero query. The previous ``q.zero_()`` overwrote the
    # caller's tensor in place.
    q = q.new_zeros(q.shape)
    mask = subsequent_mask(T)
    res, _ = attn(q, k, v, mask=mask)
    res = res.numpy()
    gold = v.numpy()
    # Running mean along the time axis: entry t is mean(gold[:, :, :t + 1, :]).
    # This is vectorized instead of recomputed for every (b, h, t).
    counts = np.arange(1, T + 1).reshape(1, 1, T, 1)
    expected = np.cumsum(gold, axis=2, dtype=np.float64) / counts
    np.testing.assert_allclose(res, expected, atol=1e-5)
def decode(self, bth, hidden):
    """Project ``bth`` and run ``self.transformer`` over it under a subsequent mask.

    Args:
        bth: input tensor; dim 1 of its projection is used as the sequence length.
        hidden: not used by this implementation.

    Returns:
        ``(transformer_output, None)``.
    """
    projected = self.proj_to_dsz(bth)
    causal = subsequent_mask(projected.shape[1]).type_as(projected)
    return self.transformer(projected, causal), None
def test_subsequent_mask_valid_count():
    """The subsequent mask for size T sums to T * (T + 1) / 2.

    That is the cell count of the lower triangle including the diagonal.
    """
    size = np.random.randint(4, 50)
    expected_total = size * (size + 1) // 2
    total = np.sum(subsequent_mask(size).numpy())
    assert total == expected_total
def test_subsequent_mask_shape():
    """subsequent_mask(T) carries two leading singleton axes: shape (1, 1, T, T)."""
    n = np.random.randint(2, 50)
    assert subsequent_mask(n).shape == (1, 1, n, n)
def test_subsequent_mask_valid_count():
    """The subsequent mask for size T sums to T * (T + 1) / 2.

    That is the cell count of the lower triangle including the diagonal.
    """
    size = np.random.randint(4, 50)
    expected_total = size * (size + 1) // 2
    total = np.sum(subsequent_mask(size).numpy())
    assert total == expected_total
def test_subsequent_mask_shape():
    """subsequent_mask(T) carries two leading singleton axes: shape (1, 1, T, T)."""
    n = np.random.randint(2, 50)
    assert subsequent_mask(n).shape == (1, 1, n, n)
def create_mask(self, bth):
    """Build a subsequent (causal) mask sized to dim 1 of ``bth``.

    Args:
        bth: input tensor whose dim 1 is the sequence length (presumably
            (batch, time, hidden) per the name — TODO confirm).

    Returns:
        ``subsequent_mask(T)`` converted to ``bth``'s tensor type via ``type_as``.
    """
    # Previously ``bth`` was first passed through ``self.proj_to_dsz``. Only the
    # length of dim 1 and the tensor type of that result were ever read, so the
    # projection was wasted compute plus an extra autograd node.
    # NOTE(review): assumes proj_to_dsz preserves dim 1 and dtype — confirm.
    T = bth.shape[1]
    return subsequent_mask(T).type_as(bth)
def create_mask(self, bth):
    """Return a subsequent mask whose size is dim 1 of ``bth``.

    The mask is converted to ``bth``'s tensor type via ``type_as``.
    """
    return subsequent_mask(bth.shape[1]).type_as(bth)
def decode(self, bth, hidden):
    """Project ``bth`` and run ``self.transformer`` over it under a subsequent mask.

    Args:
        bth: input tensor; dim 1 of its projection is used as the sequence length.
        hidden: not used by this implementation.

    Returns:
        ``(transformer_output, None)``.
    """
    projected = self.proj_to_dsz(bth)
    causal = subsequent_mask(projected.shape[1]).type_as(projected)
    return self.transformer(projected, causal), None