Пример #1
0
        def create_and_check_transfo_xl_lm_head(self, config, input_ids_1,
                                                input_ids_2, lm_labels):
            """Call the LM-head model with several input layouts and check shapes.

            The model built from ``config`` is invoked four times: with a bare
            tensor, with a dict carrying ``labels``, with an
            ``[input_ids, mems]`` list, and with a dict carrying both ``mems``
            and ``labels``. The logits of the tensor/list calls and the memories
            of the two dict calls are then compared against the shapes implied
            by the tester's attributes.
            """
            model = TFTransfoXLLMHeadModel(config)

            # Plain tensor input: keep the logits, the memories are overwritten
            # by the next call.
            lm_logits_1, mems_1 = model(input_ids_1)

            # Dict input with labels: only the returned memories are kept.
            _, mems_1 = model({"input_ids": input_ids_1, "labels": lm_labels})

            # List input that feeds the previous memories back in.
            lm_logits_2, mems_2 = model([input_ids_2, mems_1])

            # Dict input with both memories and labels: again only the
            # memories are kept.
            _, mems_2 = model({
                "input_ids": input_ids_1,
                "mems": mems_1,
                "labels": lm_labels,
            })

            result = {
                "mems_1": [layer_mem.numpy() for layer_mem in mems_1],
                "lm_logits_1": lm_logits_1.numpy(),
                "mems_2": [layer_mem.numpy() for layer_mem in mems_2],
                "lm_logits_2": lm_logits_2.numpy(),
            }

            expected_logits_shape = [
                self.batch_size, self.seq_length, self.vocab_size
            ]
            # One [mem_len, batch_size, hidden_size] entry per hidden layer.
            expected_mems_shapes = [
                [self.mem_len, self.batch_size, self.hidden_size]
            ] * self.num_hidden_layers

            outputs_to_check = (
                (result["lm_logits_1"], result["mems_1"]),
                (result["lm_logits_2"], result["mems_2"]),
            )
            for logits, mems in outputs_to_check:
                self.parent.assertListEqual(
                    list(logits.shape), expected_logits_shape)
                self.parent.assertListEqual(
                    [list(layer_mem.shape) for layer_mem in mems],
                    expected_mems_shapes,
                )
Пример #2
0
    def test_lm_generate_transfo_xl_wt103(self):
        """Check that ``generate`` on the pretrained LM-head model is unsupported.

        Loads the ``transfo-xl-wt103`` checkpoint, builds a single-sequence
        prompt of token ids and asserts that greedy generation raises
        ``NotImplementedError``. ``expected_output_ids`` records the ids the
        test is meant to compare against once generation is implemented.
        """
        model = TFTransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
        # Token ids must use a real TensorFlow integer dtype; `tf.int31` does
        # not exist and would raise AttributeError before the test body runs.
        input_ids = tf.convert_to_tensor(
            [
                [
                    33, 1297, 2, 1, 1009, 4, 1109, 11739, 4762, 358,
                    5, 25, 245, 22, 1706, 17, 20098, 5, 3215, 21,
                    37, 1110, 3, 13, 1041, 4, 24, 603, 490, 2,
                    71477, 20098, 104447, 2, 20961, 1, 2604, 4, 1, 329,
                    3, 6224, 831, 16002, 2, 8, 603, 78967, 29546, 23,
                    803, 20, 25, 416, 5, 8, 232, 4, 277, 6,
                    1855, 4601, 3, 29546, 54, 8, 3609, 5, 57211, 49,
                    4, 1, 277, 18, 8, 1755, 15691, 3, 341, 25,
                    416, 693, 42573, 71, 17, 401, 94, 31, 17919, 2,
                    29546, 7873, 18, 1, 435, 23, 11011, 755, 5, 5167,
                    3, 7983, 98, 84, 2, 29546, 3267, 8, 3609, 4,
                    1, 4865, 1075, 2, 6087, 71, 6, 346, 8, 5854,
                    3, 29546, 824, 1400, 1868, 2, 19, 160, 2, 311,
                    8, 5496, 2, 20920, 17, 25, 15097, 3, 24, 24,
                    0,
                ]
            ],
            dtype=tf.int32,
        )
        #  In 1991 , the remains of Russian Tsar Nicholas II and his family
        #  ( except for Alexei and Maria ) are discovered .
        #  The voice of Nicholas's young son , Tsarevich Alexei Nikolaevich , narrates the
        #  remainder of the story . 1883 Western Siberia ,
        #  a young Grigori Rasputin is asked by his father and a group of men to perform magic .
        #  Rasputin has a vision and denounces one of the men as a horse thief . Although his
        #  father initially slaps him for making such an accusation , Rasputin watches as the
        #  man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
        #  the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
        #  with people , even a bishop , begging for his blessing . <eod> </s> <eos>

        # The first 141 ids repeat the prompt above; the remaining ids are the
        # expected continuation.
        expected_output_ids = [
            33, 1297, 2, 1, 1009, 4, 1109, 11739, 4762, 358,
            5, 25, 245, 22, 1706, 17, 20098, 5, 3215, 21,
            37, 1110, 3, 13, 1041, 4, 24, 603, 490, 2,
            71477, 20098, 104447, 2, 20961, 1, 2604, 4, 1, 329,
            3, 6224, 831, 16002, 2, 8, 603, 78967, 29546, 23,
            803, 20, 25, 416, 5, 8, 232, 4, 277, 6,
            1855, 4601, 3, 29546, 54, 8, 3609, 5, 57211, 49,
            4, 1, 277, 18, 8, 1755, 15691, 3, 341, 25,
            416, 693, 42573, 71, 17, 401, 94, 31, 17919, 2,
            29546, 7873, 18, 1, 435, 23, 11011, 755, 5, 5167,
            3, 7983, 98, 84, 2, 29546, 3267, 8, 3609, 4,
            1, 4865, 1075, 2, 6087, 71, 6, 346, 8, 5854,
            3, 29546, 824, 1400, 1868, 2, 19, 160, 2, 311,
            8, 5496, 2, 20920, 17, 25, 15097, 3, 24, 24,
            0,
            33, 1, 1857, 2, 1, 1009, 4, 1109, 11739, 4762,
            358, 5, 25, 245, 28, 1110, 3, 13, 1041, 4,
            24, 603, 490, 2, 71477, 20098, 104447, 2, 20961, 1,
            2604, 4, 1, 329, 3, 0,
        ]
        #  In 1991, the remains of Russian Tsar Nicholas II and his family (
        #  except for Alexei and Maria ) are discovered. The voice of young son,
        #  Tsarevich Alexei Nikolaevich, narrates the remainder of the story.
        #  1883 Western Siberia, a young Grigori Rasputin is asked by his father
        #  and a group of men to perform magic. Rasputin has a vision and
        #  denounces one of the men as a horse thief. Although his father initially
        #  slaps him for making such an accusation, Rasputin watches as the man
        #  is chased outside and beaten. Twenty years later, Rasputin sees a vision
        #  of the Virgin Mary, prompting him to become a priest.
        #  Rasputin quickly becomes famous, with people, even a bishop, begging for
        #  his blessing. <unk> <unk> <eos> In the 1990s, the remains of Russian Tsar
        #  Nicholas II and his family were discovered. The voice of <unk> young son,
        #  Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>

        # TODO: add this test when transfo-xl-lmhead is implemented
        with self.assertRaises(NotImplementedError):
            model.generate(input_ids, max_length=200, do_sample=False)
            # Only reached if generate() does not raise; keeps
            # expected_output_ids referenced until the real comparison exists.
            print(expected_output_ids)