Example No. 1
    def test_trainable_variables(self):
        r"""Tests the functionality of automatically collecting trainable
        variables.
        """
        def get_variable_num(n_layers: int) -> int:
            # One word-embedder weight, one position-embedder weight,
            # 26 parameter tensors per transformer block, and the final
            # layer norm (weight + bias).
            return 1 + 1 + n_layers * 26 + 2

        # case 1: GPT2 117M
        decoder = GPT2Decoder()
        self.assertEqual(len(decoder.trainable_variables),
                         get_variable_num(12))
        _ = decoder(self.inputs)

        # case 2: GPT2 345M
        hparams = {
            "pretrained_model_name": "345M",
        }
        decoder = GPT2Decoder(hparams=hparams)
        self.assertEqual(len(decoder.trainable_variables),
                         get_variable_num(24))
        _ = decoder(self.inputs)

        # case 3: self-designed GPT2
        hparams = {
            "pretrained_model_name": None,
            "num_blocks": 6,
        }
        decoder = GPT2Decoder(hparams=hparams)
        self.assertEqual(len(decoder.trainable_variables), get_variable_num(6))
        _ = decoder(self.inputs)
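
These methods reference fixtures such as self.inputs, self.batch_size, self.max_length, and self.beam_width that are defined outside the excerpt. A minimal sketch of the imports and setUp the tests appear to assume (module paths follow texar-pytorch; the concrete fixture values are illustrative guesses, not taken from the original):

import unittest

import torch

from texar.torch.modules.decoders import decoder_helpers
from texar.torch.modules.decoders.gpt2_decoder import GPT2Decoder
from texar.torch.modules.decoders.transformer_decoders import \
    TransformerDecoderOutput


class GPT2DecoderTest(unittest.TestCase):
    r"""Hypothetical fixture for the test methods in these examples."""

    def setUp(self):
        self.batch_size = 16   # illustrative value
        self.max_length = 8    # illustrative value
        self.beam_width = 5    # illustrative value
        self.inputs = torch.zeros(
            self.batch_size, self.max_length, dtype=torch.int64)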
Example No. 2
    def test_trainable_variables(self):
        r"""Tests the functionality of automatically collecting trainable
        variables.
        """
        inputs = torch.zeros(32, 16, dtype=torch.int64)

        # case 1: GPT2 117M
        decoder = GPT2Decoder()
        _ = decoder(inputs)
        self.assertEqual(len(decoder.trainable_variables), 1 + 1 + 12 * 26 + 2)

        # case 2: GPT2 345M
        hparams = {"pretrained_model_name": "345M"}
        decoder = GPT2Decoder(hparams=hparams)
        _ = decoder(inputs)
        self.assertEqual(len(decoder.trainable_variables), 1 + 1 + 24 * 26 + 2)

        # case 3: self-designed GPT2
        hparams = {
            "decoder": {
                "num_blocks": 6,
            },
            "pretrained_model_name": None
        }
        decoder = GPT2Decoder(hparams=hparams)
        _ = decoder(inputs)
        self.assertEqual(len(decoder.trainable_variables), 1 + 1 + 6 * 26 + 2)
Example No. 3
    def test_hparams(self):
        r"""Tests the priority of the decoder arch parameters.
        """
        # case 1: set "pretrained_model_name" by constructor argument
        hparams = {
            "pretrained_model_name": "345M",
        }
        decoder = GPT2Decoder(pretrained_model_name="117M", hparams=hparams)
        self.assertEqual(decoder.hparams.num_blocks, 12)
        _ = decoder(self.inputs)

        # case 2: set "pretrained_model_name" by hparams
        hparams = {
            "pretrained_model_name": "117M",
            "num_blocks": 6,
        }
        decoder = GPT2Decoder(hparams=hparams)
        self.assertEqual(decoder.hparams.num_blocks, 12)
        _ = decoder(self.inputs)

        # case 3: "pretrained_model_name" is None, so hparams take effect
        hparams = {
            "pretrained_model_name": None,
            "num_blocks": 6,
        }
        decoder = GPT2Decoder(hparams=hparams)
        self.assertEqual(decoder.hparams.num_blocks, 6)
        _ = decoder(self.inputs)

        # case 4: using default hparams
        decoder = GPT2Decoder()
        self.assertEqual(decoder.hparams.num_blocks, 12)
        _ = decoder(self.inputs)
Example No. 4
    def test_hparams(self):
        r"""Tests the priority of the decoer arch parameter.
        """
        inputs = torch.zeros(32, 16, dtype=torch.int64)

        # case 1: set "pretrained_model_name" by constructor argument
        hparams = {
            "pretrained_model_name": "345M",
        }
        decoder = GPT2Decoder(pretrained_model_name="117M", hparams=hparams)
        _ = decoder(inputs)
        self.assertEqual(decoder.hparams.decoder.num_blocks, 12)

        # case 2: set "pretrained_model_name" by hparams
        hparams = {
            "pretrained_model_name": "117M",
            "decoder": {
                "num_blocks": 6
            }
        }
        decoder = GPT2Decoder(hparams=hparams)
        _ = decoder(inputs)
        self.assertEqual(decoder.hparams.decoder.num_blocks, 12)

        # case 3: "pretrained_model_name" is None, so hparams take effect
        hparams = {
            "pretrained_model_name": None,
            "decoder": {
                "num_blocks": 6
            },
        }
        decoder = GPT2Decoder(hparams=hparams)
        _ = decoder(inputs)
        self.assertEqual(decoder.hparams.decoder.num_blocks, 6)

        # case 4: using default hparams
        decoder = GPT2Decoder()
        _ = decoder(inputs)
        self.assertEqual(decoder.hparams.decoder.num_blocks, 12)
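
The priority rule these cases encode: an explicit constructor argument wins over hparams, and a pretrained architecture wins over user-specified sizes unless "pretrained_model_name" is None. A hedged usage sketch outside the test harness, using the nested "decoder" hparams layout of Example No. 4:

    # The constructor argument wins over hparams, so this builds (and
    # would download weights for) the 117M architecture.
    decoder = GPT2Decoder(pretrained_model_name="117M",
                          hparams={"pretrained_model_name": "345M"})

    # Custom sizes only apply once "pretrained_model_name" is None.
    decoder = GPT2Decoder(hparams={"pretrained_model_name": None,
                                   "decoder": {"num_blocks": 6}})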
Example No. 5
    def test_decode_train(self):
        r"""Tests train_greedy.
        """
        hparams = {"pretrained_model_name": None}
        decoder = GPT2Decoder(hparams=hparams)
        decoder.train()

        inputs = torch.randint(50257, (self.batch_size, self.max_length))
        outputs = decoder(inputs)

        self.assertEqual(outputs.logits.shape,
                         torch.Size([self.batch_size, self.max_length, 50257]))
        self.assertEqual(outputs.sample_id.shape,
                         torch.Size([self.batch_size, self.max_length]))
Example No. 6
    def test_decode_train(self):
        r"""Tests train_greedy.
        """
        decoder = GPT2Decoder()
        decoder.train()

        max_time = 8
        batch_size = 16
        inputs = torch.randint(50257, (batch_size, max_time),
                               dtype=torch.int64)
        outputs = decoder(inputs)

        self.assertEqual(outputs.logits.shape,
                         torch.Size([batch_size, max_time, 50257]))
        self.assertEqual(outputs.sample_id.shape,
                         torch.Size([batch_size, max_time]))
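
The train-mode call returns per-step logits over the 50257-token GPT-2 vocabulary, which is what a language-modeling loss consumes. A minimal sketch in plain PyTorch (the shift-by-one target construction is an illustrative assumption, not part of the test):

    import torch.nn.functional as F

    # outputs.logits: [batch_size, max_time, vocab_size].
    # Position t predicts token t + 1, so drop the last step and shift
    # the targets left by one.
    logits = outputs.logits[:, :-1].reshape(-1, 50257)
    targets = inputs[:, 1:].reshape(-1)
    loss = F.cross_entropy(logits, targets)
    loss.backward()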
Example No. 7
    def test_greedy_embedding_helper(self):
        r"""Tests with tf.contrib.seq2seq.GreedyEmbeddingHelper
        """
        hparams = {
            "pretrained_model_name": None,
        }
        decoder = GPT2Decoder(hparams=hparams)
        decoder.eval()

        start_tokens = torch.full((self.batch_size, ), 1, dtype=torch.int64)
        end_token = 2

        helper = decoder_helpers.GreedyEmbeddingHelper(start_tokens, end_token)

        outputs, length = decoder(helper=helper,
                                  max_decoding_length=self.max_length)

        self.assertIsInstance(outputs, TransformerDecoderOutput)
Example No. 8
    def test_decode_infer_sample(self):
        r"""Tests infer_sample
        """
        hparams = {
            "pretrained_model_name": None,
        }
        decoder = GPT2Decoder(hparams=hparams)
        decoder.eval()

        start_tokens = torch.full((self.batch_size, ), 1, dtype=torch.int64)
        end_token = 2

        helper = decoder_helpers.SampleEmbeddingHelper(start_tokens, end_token)

        outputs, length = decoder(helper=helper,
                                  max_decoding_length=self.max_length)

        self.assertIsInstance(outputs, TransformerDecoderOutput)
Example No. 9
    def test_greedy_embedding_helper(self):
        r"""Tests with tf.contrib.seq2seq.GreedyEmbeddingHelper
        """
        decoder = GPT2Decoder()
        decoder.eval()

        start_tokens = torch.full((16, ), 1, dtype=torch.int64)
        end_token = 2
        max_decoding_length = 16

        embedding_fn = lambda x, y: (decoder.word_embedder(x) +
                                     decoder.position_embedder(y))

        helper = decoder_helpers.GreedyEmbeddingHelper(embedding_fn,
                                                       start_tokens, end_token)

        outputs, length = decoder(helper=helper,
                                  max_decoding_length=max_decoding_length)

        self.assertIsInstance(outputs, TransformerDecoderOutput)
Example No. 10
    def test_decode_infer_sample(self):
        r"""Tests infer_sample
        """
        decoder = GPT2Decoder()
        decoder.eval()

        start_tokens = torch.full((16, ), 1, dtype=torch.int64)
        end_token = 2
        max_decoding_length = 16

        embedding_fn = lambda x, y: (decoder.word_embedder(x) +
                                     decoder.position_embedder(y))

        helper = decoder_helpers.SampleEmbeddingHelper(embedding_fn,
                                                       start_tokens, end_token)

        outputs, length = decoder(helper=helper,
                                  max_decoding_length=max_decoding_length)

        self.assertIsInstance(outputs, TransformerDecoderOutput)
Example No. 11
    def test_beam_search(self):
        r"""Tests beam_search
        """
        decoder = GPT2Decoder()
        decoder.eval()

        start_tokens = torch.full((16, ), 1, dtype=torch.int64)
        end_token = 2
        max_decoding_length = 16

        embedding_fn = lambda x, y: (decoder.word_embedder(x) +
                                     decoder.position_embedder(y))

        outputs = decoder(embedding=embedding_fn,
                          start_tokens=start_tokens,
                          beam_width=5,
                          end_token=end_token,
                          max_decoding_length=max_decoding_length)

        self.assertEqual(outputs['log_prob'].shape, torch.Size([16, 5]))
        self.assertEqual(outputs['sample_id'].shape, torch.Size([16, 16, 5]))
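
Beam search returns a dict rather than a TransformerDecoderOutput, with all beam_width hypotheses kept per example. A hedged sketch of selecting the highest-scoring beam, with indexing inferred from the shapes asserted above:

    # log_prob:   [batch_size, beam_width]
    # sample_id:  [batch_size, max_decoding_length, beam_width]
    best_beam = outputs['log_prob'].argmax(dim=-1)            # [batch_size]
    batch_idx = torch.arange(best_beam.size(0))
    # Advanced indexing on dims 0 and 2 yields [batch_size, time].
    best_ids = outputs['sample_id'][batch_idx, :, best_beam]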
Example No. 12
    def test_topk_embedding_helper(self):
        r"""Tests TopKSampleEmbeddingHelper
        """
        hparams = {
            "pretrained_model_name": None,
        }
        decoder = GPT2Decoder(hparams=hparams)
        decoder.eval()

        start_tokens = torch.full((self.batch_size, ), 1, dtype=torch.int64)
        end_token = 2

        helper = decoder_helpers.TopKSampleEmbeddingHelper(
            start_tokens=start_tokens,
            end_token=end_token,
            top_k=40,
            softmax_temperature=0.7)

        outputs, length = decoder(helper=helper,
                                  max_decoding_length=self.max_length)

        self.assertIsInstance(outputs, TransformerDecoderOutput)
Example No. 13
    def test_beam_search(self):
        r"""Tests beam_search
        """
        hparams = {
            "pretrained_model_name": None,
        }
        decoder = GPT2Decoder(hparams=hparams)
        decoder.eval()

        start_tokens = torch.full((self.batch_size, ), 1, dtype=torch.int64)
        end_token = 2

        outputs = decoder(start_tokens=start_tokens,
                          beam_width=self.beam_width,
                          end_token=end_token,
                          max_decoding_length=self.max_length)

        self.assertEqual(outputs['log_prob'].shape,
                         torch.Size([self.batch_size, self.beam_width]))
        self.assertEqual(
            outputs['sample_id'].shape,
            torch.Size([self.batch_size, self.max_length, self.beam_width]))
Example No. 14
    def test_topk_embedding_helper(self):
        r"""Tests TopKSampleEmbeddingHelper
        """
        decoder = GPT2Decoder()
        decoder.eval()

        start_tokens = torch.full((16, ), 1, dtype=torch.int64)
        end_token = 2
        max_decoding_length = 16

        embedding_fn = lambda x, y: (decoder.word_embedder(x) +
                                     decoder.position_embedder(y))

        helper = decoder_helpers.TopKSampleEmbeddingHelper(
            embedding=embedding_fn,
            start_tokens=start_tokens,
            end_token=end_token,
            top_k=40,
            softmax_temperature=0.7)

        outputs, length = decoder(max_decoding_length=max_decoding_length,
                                  helper=helper)

        self.assertIsInstance(outputs, TransformerDecoderOutput)
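
Across the helper-based examples, the decoder returns an (outputs, length) pair, where length holds each example's decoded length. A small sketch, under that assumption, of trimming the padded sample ids back to per-example sequences:

    # outputs.sample_id: [batch_size, max_decoding_length]
    # length:            [batch_size]
    samples = [ids[:int(n)] for ids, n in zip(outputs.sample_id, length)]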