Example #1
    def test_create_look_ahead_mask(self):
        length = 5
        mask = helper.create_look_ahead_mask(length)
        # 1 marks future positions that should not be attended to.
        expected_mask = torch.Tensor(
            [[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 0, 1, 1],
             [0, 0, 0, 0, 1], [0, 0, 0, 0, 0]])
        expected_mask = expected_mask[None, None, :, :]
        self.assertEqual((mask != expected_mask).sum(), 0)
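Note: the expected tensor above suggests helper.create_look_ahead_mask returns a (1, 1, length, length) mask with 1 on the strictly upper triangle (future positions blocked, 0 allowed). A minimal sketch of such a helper under that assumption; the project's actual implementation may differ:

    import torch

    def create_look_ahead_mask(length):
        # 1 above the diagonal blocks attention to future positions; 0 keeps them visible.
        mask = torch.triu(torch.ones(length, length), diagonal=1)
        return mask[None, None, :, :]  # (1, 1, length, length), broadcastable over batch and heads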
Example #2
    def test_forward_with_mask(self):
        decoder_layer = DecoderLayerCustomized(d_model=512, num_heads=8)
        word_input = torch.randn(size=(5, 100, 512))
        look_ahead_mask = helper.create_look_ahead_mask(100)
        length = torch.ones(size=(5,))*10
        target_padding_mask = helper.create_padding_mask(length, max_len=100)

        output, _ = decoder_layer(word_input, look_ahead_mask=look_ahead_mask, target_padding_mask=target_padding_mask)
        self.assertListEqual(list(output.size()), [5, 100, 512])
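Note: create_padding_mask is called here with per-example lengths and a max_len. A minimal sketch assuming it marks padded positions (index >= length) with 1 and returns a (batch, 1, 1, max_len) mask; the real helper's shape convention may differ:

    import torch

    def create_padding_mask(source_seq_len, max_len):
        # 1 marks padded positions beyond each sequence's true length.
        positions = torch.arange(max_len, device=source_seq_len.device)    # (max_len,)
        mask = (positions[None, :] >= source_seq_len[:, None]).float()     # (batch, max_len)
        return mask[:, None, None, :]  # (batch, 1, 1, max_len), broadcastable over heads and query positions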
Example #3
        def core_get_logits(x, seq_len):
            assert x[0][0] == self.model.bos_id, self.model.bos_id

            x = self.model.word_embedding(x)
            dec_out = x
            look_ahead_mask = helper.create_look_ahead_mask(x.size(1)).to(
                x.device)
            padding_mask = helper.create_padding_mask(seq_len, x.size(1))
            for i in range(len(self.model.decoder)):
                dec_out, _ = self.model.decoder[i](dec_out, look_ahead_mask,
                                                   padding_mask)
            return dec_out
Example #4
    def forward(self, source_words, target_words, source_length, target_length, *args):
        """

        :param source_words: (batch, max_seq_len_1)
        :param target_words: (batch, max_seq_len_2)
        :param source_length:  (batch)
        :param target_length:  (batch)
        :param args:
        :return: logits (batch, max_seq_len_2, target_vocab_size)
        """
        source_padding = helper.create_padding_mask(source_seq_len=source_length, max_len=source_words.size(-1))
        encoding_output = self.encoder(source_words, source_padding)
        look_ahead_mask = helper.create_look_ahead_mask(target_words.size(-1))
        decoding_output = self.decoder(target_words, encoding_output, look_ahead_mask, source_padding)
        output = self.final_layer(decoding_output)
        return output
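Note: a hypothetical call matching the docstring shapes, with model standing in for an instance of the module above (construction not shown in this example):

    src = torch.randint(1000, size=(4, 60))     # (batch, max_seq_len_1)
    tgt = torch.randint(1000, size=(4, 70))     # (batch, max_seq_len_2)
    src_len = torch.randint(1, 61, size=(4,))   # (batch), true source lengths
    tgt_len = torch.randint(1, 71, size=(4,))   # (batch), true target lengths
    logits = model(src, tgt, src_len, tgt_len)  # (4, 70, target_vocab_size)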
Example #5
    def test_forward_with_masks(self):
        word_embedding = nn.Embedding(1000, 512)
        decoder = Decoder(num_layers=6,
                          d_model=512,
                          num_heads=8,
                          word_embedding=word_embedding)

        word_input = torch.randint(1000, size=(5, 80))
        source_seq_len = torch.randint(100, size=(5, ))
        padding_mask = helper.create_padding_mask(source_seq_len, max_len=100)
        encoder_output = torch.randn(size=(5, 100, 512))
        look_ahead_mask = helper.create_look_ahead_mask(80)
        output = decoder(word_input,
                         encoder_output,
                         look_ahead_mask=look_ahead_mask,
                         source_padding_mask=padding_mask,
                         target_padding_mask=None)
        self.assertListEqual(list(output.size()), [5, 80, 512])
Example #6
    def get_logits(self, x, seq_len):
        """

        :param x: (batch, seq_len)
        :param seq_len: (batch)
        :param args:
        :return:
        """
        assert x[0][0] == self.bos_id, self.bos_id

        x = self.word_embedding(x)
        dec_out = x
        look_ahead_mask = helper.create_look_ahead_mask(x.size(1)).to(x.device)
        padding_mask = helper.create_padding_mask(seq_len, x.size(1))
        for i in range(len(self.decoder)):
            dec_out, _ = self.decoder[i](dec_out, look_ahead_mask,
                                         padding_mask)

        logits = self.lm_layer.get_logits(dec_out)
        return logits