def test_create_look_ahead_mask(self):
    length = 5
    mask = helper.create_look_ahead_mask(length)
    expected_mask = torch.Tensor(
        [[0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0]],
    )
    expected_mask = expected_mask[None, None, :, :]
    self.assertEqual((mask != expected_mask).sum(), 0)
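
The test above pins down the expected behavior: position i may attend to positions 0..i, and a 1 marks a blocked (future) position, broadcast over batch and heads. A minimal sketch of a helper with that behavior (the repo's actual create_look_ahead_mask implementation is not shown here, so this is an illustration, not the real code):

import torch

def create_look_ahead_mask(length):
    # 1 above the diagonal = positions a token must NOT attend to.
    # Shape (1, 1, length, length) so it broadcasts over (batch, heads).
    mask = torch.triu(torch.ones(length, length), diagonal=1)
    return mask[None, None, :, :]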
def test_forward_with_mask(self):
    decoder_layer = DecoderLayerCustomized(d_model=512, num_heads=8)
    word_input = torch.randn(size=(5, 100, 512))
    look_ahead_mask = helper.create_look_ahead_mask(100)
    length = torch.ones(size=(5,)) * 10
    target_padding_mask = helper.create_padding_mask(length, max_len=100)
    output, _ = decoder_layer(word_input,
                              look_ahead_mask=look_ahead_mask,
                              target_padding_mask=target_padding_mask)
    self.assertListEqual(list(output.size()), [5, 100, 512])
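
create_padding_mask is called with per-example sequence lengths and a max_len, and is expected to hide positions past each sequence's true length. A hedged sketch of such a helper, assuming the same "1 = blocked" convention as the look-ahead mask and a (batch, 1, 1, max_len) broadcast shape (both assumptions, since the real helper is not shown):

import torch

def create_padding_mask(source_seq_len, max_len):
    # Positions >= the true sequence length are padding and get a 1.
    positions = torch.arange(max_len, device=source_seq_len.device)
    mask = (positions[None, :] >= source_seq_len[:, None].long()).float()
    return mask[:, None, None, :]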
def core_get_logits(x, seq_len):
    # x: (batch, seq_len) token ids starting with BOS; seq_len: (batch) true lengths.
    assert x[0][0] == self.model.bos_id, self.model.bos_id
    x = self.model.word_embedding(x)
    dec_out = x
    # Causal mask blocks attention to future positions; padding mask blocks
    # positions beyond each sequence's true length.
    look_ahead_mask = helper.create_look_ahead_mask(x.size(1)).to(x.device)
    padding_mask = helper.create_padding_mask(seq_len, x.size(1))
    for i in range(len(self.model.decoder)):
        dec_out, _ = self.model.decoder[i](dec_out, look_ahead_mask, padding_mask)
    return dec_out
def forward(self, source_words, target_words, source_length, target_length, *args):
    """
    :param source_words: (batch, max_seq_len_1)
    :param target_words: (batch, max_seq_len_2)
    :param source_length: (batch)
    :param target_length: (batch)
    :param args:
    :return: logits (batch, max_seq_len_2, target_vocab_size)
    """
    source_padding = helper.create_padding_mask(source_seq_len=source_length,
                                                max_len=source_words.size(-1))
    encoding_output = self.encoder(source_words, source_padding)
    look_ahead_mask = helper.create_look_ahead_mask(target_words.size(-1))
    decoding_output = self.decoder(target_words, encoding_output,
                                   look_ahead_mask, source_padding)
    output = self.final_layer(decoding_output)
    return output
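
For orientation, a smoke-test style call of this forward pass might look like the sketch below. The model class name and constructor arguments here are assumptions for illustration only; they are not taken from the repo:

import torch

# Hypothetical instantiation; the real constructor may differ.
model = Transformer(num_layers=6, d_model=512, num_heads=8,
                    source_vocab_size=1000, target_vocab_size=1000)
source_words = torch.randint(1000, size=(5, 80))
target_words = torch.randint(1000, size=(5, 60))
source_length = torch.randint(1, 81, size=(5,))
target_length = torch.randint(1, 61, size=(5,))
logits = model(source_words, target_words, source_length, target_length)
# expected shape: (5, 60, target_vocab_size)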
def test_forward_with_masks(self):
    word_embedding = nn.Embedding(1000, 512)
    decoder = Decoder(num_layers=6, d_model=512, num_heads=8,
                      word_embedding=word_embedding)
    word_input = torch.randint(1000, size=(5, 80))
    source_seq_len = torch.randint(100, size=(5,))
    padding_mask = helper.create_padding_mask(source_seq_len, max_len=100)
    encoder_output = torch.randn(size=(5, 100, 512))
    look_ahead_mask = helper.create_look_ahead_mask(80)
    output = decoder(word_input, encoder_output,
                     look_ahead_mask=look_ahead_mask,
                     source_padding_mask=padding_mask,
                     target_padding_mask=None)
    self.assertListEqual(list(output.size()), [5, 80, 512])
def get_logits(self, x, seq_len):
    """
    :param x: (batch, seq_len) token ids, starting with the BOS id
    :param seq_len: (batch) true sequence lengths
    :return: logits (batch, seq_len, vocab_size)
    """
    assert x[0][0] == self.bos_id, self.bos_id
    x = self.word_embedding(x)
    dec_out = x
    # Causal mask blocks future positions; padding mask blocks positions
    # beyond each sequence's true length.
    look_ahead_mask = helper.create_look_ahead_mask(x.size(1)).to(x.device)
    padding_mask = helper.create_padding_mask(seq_len, x.size(1))
    for i in range(len(self.decoder)):
        dec_out, _ = self.decoder[i](dec_out, look_ahead_mask, padding_mask)
    logits = self.lm_layer.get_logits(dec_out)
    return logits
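
Downstream, these logits are typically reduced to per-position token predictions. A minimal, self-contained sketch assuming the usual (batch, seq_len, vocab_size) layout (the tensor here is a stand-in for the output of get_logits):

import torch

logits = torch.randn(2, 12, 1000)    # stand-in for get_logits(x, seq_len)
next_tokens = logits.argmax(dim=-1)  # (batch, seq_len) greedy predictions
last_token = next_tokens[:, -1]      # prediction following the last input token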