def _model_mask(self, nctx):
    """Build the attention mask for a context of ``nctx`` positions.

    When ``self.mlm`` is truthy (masked LM) nothing is hidden: the result is an all-ones
    ``torch.long`` tensor of shape ``(1, 1, nctx, nctx)``. Otherwise (causal LM) the result
    is whatever ``subsequent_mask(nctx)`` returns.
    """
    if not self.mlm:
        return subsequent_mask(nctx)
    full_shape = (1, 1, nctx, nctx)
    return torch.ones(full_shape, dtype=torch.long)
Example #2
0
 def forward(self, encoder_output, dst):
     """Decode ``dst`` against the encoder result and return the output layer's result.

     ``dst`` is embedded with ``self.tgt_embeddings`` and projected with ``self.proj_to_hsz``.
     ``encoder_output`` must expose ``.output`` (passed to the decoder as context) and
     ``.src_mask`` (expanded with two singleton axes at position 1).
     """
     # The ``_bth`` suffix in the original names suggests (batch, time, hidden) -- TODO confirm.
     target = self.proj_to_hsz(self.tgt_embeddings(dst))
     memory = encoder_output.output
     # Subsequent mask sized by dimension 1 of the target, cast to the target's tensor type.
     causal_mask = subsequent_mask(target.shape[1]).type_as(target)
     source_mask = encoder_output.src_mask.unsqueeze(1).unsqueeze(1)
     decoded = self.transformer_decoder(target, memory, source_mask, causal_mask)
     return self.output(self.proj_to_dsz(decoded))
Example #3
0
 def forward(self, encoder_output, dst):
     """Decode ``dst`` against the encoder result and return the output layer's result.

     ``dst`` is embedded with ``self.tgt_embeddings`` and projected with ``self.proj_to_hsz``.
     ``encoder_output`` must expose ``.output`` (passed to the decoder as context) and
     ``.src_mask`` (expanded with two singleton axes at position 1).
     """
     # The ``_bth`` suffix in the original names suggests (batch, time, hidden) -- TODO confirm.
     target = self.proj_to_hsz(self.tgt_embeddings(dst))
     memory = encoder_output.output
     # Subsequent mask sized by dimension 1 of the target, cast to the target's tensor type.
     causal_mask = subsequent_mask(target.shape[1]).type_as(target)
     source_mask = encoder_output.src_mask.unsqueeze(1).unsqueeze(1)
     decoded = self.transformer_decoder(target, memory, source_mask, causal_mask)
     return self.output(self.proj_to_dsz(decoded))
Example #4
0
def test_subsequent_mask_valid_loc():
    """Spot-check 100 random cells of a randomly sized subsequent mask.

    A cell ``[row, col]`` must be 0 when ``row < col`` and 1 otherwise.
    """
    size = np.random.randint(4, 100)
    grid = subsequent_mask(size).numpy().squeeze()

    def check_random_cell(size, grid):
        # Draw one (row, col) pair uniformly from [0, size).
        row, col = np.random.randint(0, size, size=2)
        expected = 0 if row < col else 1
        assert grid[row, col] == expected

    for _ in range(100):
        check_random_cell(size, grid)
Example #5
0
def test_subsequent_mask_valid_loc():
    """Spot-check 100 random cells of a randomly sized subsequent mask.

    A cell ``[row, col]`` must be 0 when ``row < col`` and 1 otherwise.
    """
    size = np.random.randint(4, 100)
    grid = subsequent_mask(size).numpy().squeeze()

    def check_random_cell(size, grid):
        # Draw one (row, col) pair uniformly from [0, size).
        row, col = np.random.randint(0, size, size=2)
        expected = 0 if row < col else 1
        assert grid[row, col] == expected

    for _ in range(100):
        check_random_cell(size, grid)
def attn_values_sub_mask(attn, qkv):
    """Assert that zero-query attention under a subsequent mask averages the visible values.

    The query is zeroed IN PLACE (the caller's tensor is modified), ``attn`` is run with
    ``subsequent_mask(T)``, and for every step ``t`` the output is required to match the mean
    of the value vectors at steps ``0..t`` (``atol=1e-5``). Presumably this holds because a
    zero query gives every unmasked key the same attention weight -- confirm against ``attn``.

    :param attn: Callable invoked as ``attn(q, k, v, mask=mask)``; must return a pair whose
        first element is a tensor supporting ``.numpy()``.
    :param qkv: A ``(q, k, v)`` triple; ``q`` must be 4-dimensional, unpacked as ``(B, H, T, _)``.
    :raises AssertionError: If any output step differs from the expected running mean.
    """
    q, k, v = qkv
    B, H, T, _ = q.shape
    q = q.zero_()
    mask = subsequent_mask(T)
    res, _ = attn(q, k, v, mask=mask)
    res = res.numpy()
    gold = v.numpy()
    for t in range(T):
        # The running mean depends only on t, so compute it once per step instead of once
        # per (batch, head) pair, and compare every batch/head in a single call.
        expected = np.mean(gold[:, :, : t + 1, :], axis=2)
        # Restrict to the query's batch/head extents, as the element-wise loops did.
        np.testing.assert_allclose(res[:B, :H, t, :], expected[:B, :H, :], atol=1e-5)
Example #7
0
 def decode(self, bth, hidden):
     """Project ``bth``, run it through ``self.transformer`` with a subsequent mask.

     ``hidden`` is accepted but not used. Returns a 2-tuple whose first element is the
     transformer result and whose second element is always ``None``.
     """
     projected = self.proj_to_dsz(bth)
     # Mask is sized by dimension 1 of the projected input and cast to its tensor type.
     causal_mask = subsequent_mask(projected.shape[1]).type_as(projected)
     decoded = self.transformer(projected, causal_mask)
     return decoded, None
Example #8
0
def test_subsequent_mask_valid_count():
    """The entries of ``subsequent_mask(T)`` must sum to ``T * (T + 1) / 2`` for a random T.

    That is the number of cells on or below the diagonal of a ``T x T`` grid.
    """
    size = np.random.randint(4, 50)
    expected_total = (size * (size + 1)) / 2
    observed_total = np.sum(subsequent_mask(size).numpy())
    assert observed_total == expected_total
Example #9
0
def test_subsequent_mask_shape():
    """``subsequent_mask(T)`` must have shape ``(1, 1, T, T)`` for a random T in [2, 50)."""
    size = np.random.randint(2, 50)
    expected_shape = (1, 1, size, size)
    assert subsequent_mask(size).shape == expected_shape
Example #10
0
def test_subsequent_mask_valid_count():
    """The entries of ``subsequent_mask(T)`` must sum to ``T * (T + 1) / 2`` for a random T.

    That is the number of cells on or below the diagonal of a ``T x T`` grid.
    """
    size = np.random.randint(4, 50)
    expected_total = (size * (size + 1)) / 2
    observed_total = np.sum(subsequent_mask(size).numpy())
    assert observed_total == expected_total
Example #11
0
def test_subsequent_mask_shape():
    """``subsequent_mask(T)`` must have shape ``(1, 1, T, T)`` for a random T in [2, 50)."""
    size = np.random.randint(2, 50)
    expected_shape = (1, 1, size, size)
    assert subsequent_mask(size).shape == expected_shape
Example #12
0
 def create_mask(self, bth):
     """Return a subsequent mask matching the projected form of ``bth``.

     ``bth`` is first passed through ``self.proj_to_dsz``; the mask is sized by dimension 1
     of that result and cast to its tensor type.
     """
     # NOTE(review): the projection output is used only for its shape[1] and tensor type --
     # confirm whether the projection is actually required here.
     projected = self.proj_to_dsz(bth)
     return subsequent_mask(projected.shape[1]).type_as(projected)
Example #13
0
 def create_mask(self, bth):
     """Return a subsequent mask sized by dimension 1 of ``bth`` and cast to its tensor type."""
     steps = bth.shape[1]
     return subsequent_mask(steps).type_as(bth)
Example #14
0
 def decode(self, bth, hidden):
     """Project ``bth``, run it through ``self.transformer`` with a subsequent mask.

     ``hidden`` is accepted but not used. Returns a 2-tuple whose first element is the
     transformer result and whose second element is always ``None``.
     """
     projected = self.proj_to_dsz(bth)
     # Mask is sized by dimension 1 of the projected input and cast to its tensor type.
     causal_mask = subsequent_mask(projected.shape[1]).type_as(projected)
     decoded = self.transformer(projected, causal_mask)
     return decoded, None