示例#1
0
    def create_and_check_xlnet_lm_head(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
    ):
        """Run ``TFXLNetLMHeadModel`` three times and check output shapes.

        Pass 1 feeds ``input_ids_1`` with ``segment_ids``; pass 2 feeds
        ``input_ids_2`` together with the memories returned by pass 1;
        pass 3 feeds ``input_ids_q`` with ``perm_mask`` and
        ``target_mapping``.  Only the outputs of passes 1 and 2 are
        asserted on; pass 3 is merely executed and unpacked.

        ``input_mask``, ``lm_labels``, ``sequence_labels`` and
        ``is_impossible_labels`` are accepted but not used here.
        """
        lm_model = TFXLNetLMHeadModel(config)

        # Pass 1: plain forward pass.
        first_logits, first_mems = lm_model({
            "input_ids": input_ids_1,
            "token_type_ids": segment_ids,
        })

        # Pass 2: reuse the memories produced by pass 1.
        second_logits, second_mems = lm_model({
            "input_ids": input_ids_2,
            "mems": first_mems,
            "token_type_ids": segment_ids,
        })

        # Pass 3: permutation inputs; the outputs are unpacked (so the model
        # must return exactly two values) but are not inspected further.
        _perm_logits, _ = lm_model({
            "input_ids": input_ids_q,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
        })

        # Materialise the tensors as numpy arrays before comparing shapes.
        first_mems_np = [mem.numpy() for mem in first_mems]
        first_logits_np = first_logits.numpy()
        second_mems_np = [mem.numpy() for mem in second_mems]
        second_logits_np = second_logits.numpy()

        check_list = self.parent.assertListEqual
        logits_shape = [self.batch_size, self.seq_length, self.vocab_size]

        check_list(list(first_logits_np.shape), logits_shape)
        # One memory per hidden layer, each seq_length long after pass 1.
        check_list(
            [list(mem.shape) for mem in first_mems_np],
            [[self.seq_length, self.batch_size, self.hidden_size]] *
            self.num_hidden_layers,
        )

        check_list(list(second_logits_np.shape), logits_shape)
        # After pass 2 each memory is expected to be mem_len long.
        check_list(
            [list(mem.shape) for mem in second_mems_np],
            [[self.mem_len, self.batch_size, self.hidden_size]] *
            self.num_hidden_layers,
        )
    def create_and_check_xlnet_lm_head(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
    ):
        """Run ``TFXLNetLMHeadModel`` three times and check output shapes.

        Each model output is converted with ``.to_tuple()`` and unpacked
        into ``(logits, mems)``.  Pass 1 feeds ``input_ids_1`` with
        ``segment_ids``; pass 2 feeds ``input_ids_2`` plus the memories from
        pass 1; pass 3 feeds ``input_ids_q`` with ``perm_mask`` and
        ``target_mapping``.  Only passes 1 and 2 are asserted on.

        ``input_mask``, ``lm_labels``, ``sequence_labels`` and
        ``is_impossible_labels`` are accepted but not used here.
        """
        lm_model = TFXLNetLMHeadModel(config)

        # Pass 1: plain forward pass.
        first_output = lm_model({
            "input_ids": input_ids_1,
            "token_type_ids": segment_ids,
        })
        first_logits, first_mems = first_output.to_tuple()

        # Pass 2: reuse the memories produced by pass 1.
        second_output = lm_model({
            "input_ids": input_ids_2,
            "mems": first_mems,
            "token_type_ids": segment_ids,
        })
        second_logits, second_mems = second_output.to_tuple()

        # Pass 3: permutation inputs; the result is unpacked (so the tuple
        # must have exactly two entries) but is not inspected further.
        perm_output = lm_model({
            "input_ids": input_ids_q,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
        })
        _perm_logits, _ = perm_output.to_tuple()

        logits_shape = (self.batch_size, self.seq_length, self.vocab_size)

        self.parent.assertEqual(first_logits.shape, logits_shape)
        # One memory per hidden layer, each seq_length long after pass 1.
        self.parent.assertListEqual(
            [mem.shape for mem in first_mems],
            [(self.seq_length, self.batch_size, self.hidden_size)] *
            self.num_hidden_layers,
        )

        self.parent.assertEqual(second_logits.shape, logits_shape)
        # After pass 2 each memory is expected to be mem_len long.
        self.parent.assertListEqual(
            [mem.shape for mem in second_mems],
            [(self.mem_len, self.batch_size, self.hidden_size)] *
            self.num_hidden_layers,
        )
示例#3
0
    def test_lm_generate_xlnet_base_cased(self):
        """Check ``generate`` on ``xlnet-base-cased`` against reference ids.

        The pretrained model is given a fixed 161-token prompt and asked to
        generate up to 200 tokens with sampling disabled.  The result must
        equal the prompt followed by a fixed 39-token continuation.
        """
        # Prompt text:
        #  In 1991, the remains of Russian Tsar Nicholas II and his family
        #  (except for Alexei and Maria) are discovered.
        #  The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
        #  narrates the remainder of the story. 1883 Western Siberia,
        #  a young Grigori Rasputin is asked by his father and a group of men
        #  to perform magic. Rasputin has a vision and denounces one of the
        #  men as a horse thief. Although his father initially slaps him for
        #  making such an accusation, Rasputin watches as the man is chased
        #  outside and beaten. Twenty years later, Rasputin sees a vision of
        #  the Virgin Mary, prompting him to become a priest. Rasputin quickly
        #  becomes famous, with people, even a bishop, begging for his
        #  blessing. <sep><cls>
        prompt_ids = [
            67, 2840, 19, 18, 1484, 20, 965, 29077, 8719, 1273,
            21, 45, 273, 17, 10, 15048, 28, 27511, 21, 4185,
            11, 41, 2444, 9, 32, 1025, 20, 8719, 26, 23,
            673, 966, 19, 29077, 20643, 27511, 20822, 20643, 19, 17,
            6616, 17511, 18, 8978, 20, 18, 777, 9, 19233, 1527,
            17669, 19, 24, 673, 17, 28756, 150, 12943, 4354, 153,
            27, 442, 37, 45, 668, 21, 24, 256, 20, 416,
            22, 2771, 4901, 9, 12943, 4354, 153, 51, 24, 3004,
            21, 28142, 23, 65, 20, 18, 416, 34, 24, 2958,
            22947, 9, 1177, 45, 668, 3097, 13768, 23, 103, 28,
            441, 148, 48, 20522, 19, 12943, 4354, 153, 12860, 34,
            18, 326, 27, 17492, 684, 21, 6709, 9, 8585, 123,
            266, 19, 12943, 4354, 153, 6872, 24, 3004, 20, 18,
            9225, 2198, 19, 12717, 103, 22, 401, 24, 6348, 9,
            12943, 4354, 153, 1068, 2768, 2286, 19, 33, 104, 19,
            176, 24, 9313, 19, 20086, 28, 45, 10292, 9, 4,
            3,
        ]

        # Expected continuation text:
        #  , Rasputin is asked to perform magic.
        #  He is not able to perform magic, and his father and
        #  the men are forced to leave the monastery. Rasputin is forced to
        #  return to
        continuation_ids = [
            19, 12943, 4354, 153, 27, 442, 22, 2771, 4901, 9,
            69, 27, 50, 551, 22, 2771, 4901, 19, 21, 45,
            668, 21, 18, 416, 41, 1499, 22, 755, 18, 14285,
            9, 12943, 4354, 153, 27, 1499, 22, 642, 22,
        ]

        # The reference output is the prompt itself followed by the
        # continuation, so it is built from the two pieces above.
        expected_output_ids = prompt_ids + continuation_ids

        model = TFXLNetLMHeadModel.from_pretrained("xlnet-base-cased")
        # A single-row batch holding the prompt.
        input_ids = tf.convert_to_tensor([prompt_ids], dtype=tf.int32)

        output_ids = model.generate(input_ids, max_length=200, do_sample=False)

        generated = output_ids[0].numpy().tolist()
        self.assertListEqual(generated, expected_output_ids)