Exemple #1
0
    def create_and_check_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
        """Run the TF Transformer-XL LM-head model and verify output shapes.

        The model built from ``config`` is invoked four times, in order:

        1. on ``input_ids_1`` alone (logits kept),
        2. on ``input_ids_1`` together with ``lm_labels`` (memories kept),
        3. on ``[input_ids_2, mems]`` using the memories of call 2 (logits kept),
        4. on ``input_ids_1`` with those memories and ``lm_labels`` (memories kept).

        The kept logits and memories are then compared against the shapes
        derived from the tester's ``batch_size``, ``seq_length``, ``vocab_size``,
        ``mem_len``, ``hidden_size`` and ``num_hidden_layers`` attributes.
        """
        lm_model = TFTransfoXLLMHeadModel(config)

        # Pass 1: plain tensor input. Its memories are superseded by pass 2.
        first_logits, first_mems = lm_model(input_ids_1).to_tuple()

        # Pass 2: same ids, dict-style input with labels; only the memories are used.
        _, first_mems = lm_model({"input_ids": input_ids_1, "labels": lm_labels}).to_tuple()

        # Pass 3: list-style input feeding the memories of pass 2 back in.
        second_logits, second_mems = lm_model([input_ids_2, first_mems]).to_tuple()

        # Pass 4: dict-style input with memories and labels; only the memories are used.
        _, second_mems = lm_model(
            {"input_ids": input_ids_1, "mems": first_mems, "labels": lm_labels}
        ).to_tuple()

        # Both passes are expected to produce identically shaped outputs.
        expected_logits_shape = (self.batch_size, self.seq_length, self.vocab_size)
        expected_mem_shapes = [(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers

        self.parent.assertEqual(first_logits.shape, expected_logits_shape)
        self.parent.assertListEqual([layer_mem.shape for layer_mem in first_mems], expected_mem_shapes)

        self.parent.assertEqual(second_logits.shape, expected_logits_shape)
        self.parent.assertListEqual([layer_mem.shape for layer_mem in second_mems], expected_mem_shapes)
Exemple #2
0
        def create_and_check_transfo_xl_lm_head(self, config, input_ids_1,
                                                input_ids_2, lm_labels):
            """Run the TF Transformer-XL LM-head model and verify output shapes.

            The model built from ``config`` is called on ``input_ids_1`` alone,
            on ``input_ids_1`` with labels, on ``[input_ids_2, mems]`` and on
            ``input_ids_1`` with memories plus labels, in that order. Logits
            from the unlabeled calls and memories from the labeled calls are
            converted to NumPy arrays and their shapes compared with the
            values derived from the tester's size attributes.
            """
            lm_model = TFTransfoXLLMHeadModel(config)

            # Pass 1: plain tensor input. Its memories are superseded by pass 2.
            first_logits, first_mems = lm_model(input_ids_1)

            # Pass 2: same ids with labels; only the memories are used.
            _, first_mems = lm_model({"input_ids": input_ids_1, "labels": lm_labels})

            # Pass 3: second id batch, reusing the memories of pass 2.
            second_logits, second_mems = lm_model([input_ids_2, first_mems])

            # Pass 4: first id batch with memories and labels; only the memories are used.
            _, second_mems = lm_model({
                "input_ids": input_ids_1,
                "mems": first_mems,
                "labels": lm_labels,
            })

            # Materialize everything as NumPy arrays before inspecting shapes.
            first_mem_arrays = [mem.numpy() for mem in first_mems]
            first_logits_array = first_logits.numpy()
            second_mem_arrays = [mem.numpy() for mem in second_mems]
            second_logits_array = second_logits.numpy()

            # Both passes are expected to produce identically shaped outputs.
            expected_logits_shape = [self.batch_size, self.seq_length, self.vocab_size]
            expected_mem_shapes = (
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers
            )

            self.parent.assertListEqual(list(first_logits_array.shape), expected_logits_shape)
            self.parent.assertListEqual(
                [list(arr.shape) for arr in first_mem_arrays],
                expected_mem_shapes,
            )

            self.parent.assertListEqual(list(second_logits_array.shape), expected_logits_shape)
            self.parent.assertListEqual(
                [list(arr.shape) for arr in second_mem_arrays],
                expected_mem_shapes,
            )
Exemple #3
0
    def test_lm_generate_transfo_xl_wt103(self):
        """Check generation of the pretrained ``transfo-xl-wt103`` checkpoint.

        A fixed 141-token prompt is fed to ``generate`` with sampling disabled
        and ``max_length=200``. The first returned sequence must equal the
        prompt followed by a fixed 36-token continuation.
        """
        model = TFTransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")

        # Prompt token ids. Decoded text, as recorded alongside the ids:
        #  In 1991 , the remains of Russian Tsar Nicholas II and his family
        #  ( except for Alexei and Maria ) are discovered .
        #  The voice of Nicholas's young son , Tsarevich Alexei Nikolaevich , narrates the
        #  remainder of the story . 1883 Western Siberia ,
        #  a young Grigori Rasputin is asked by his father and a group of men to perform magic .
        #  Rasputin has a vision and denounces one of the men as a horse thief . Although his
        #  father initially slaps him for making such an accusation , Rasputin watches as the
        #  man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
        #  the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
        #  with people , even a bishop , begging for his blessing . <eod> </s> <eos>
        # fmt: off
        prompt_ids = [
            33, 1297, 2, 1, 1009, 4, 1109, 11739, 4762, 358,
            5, 25, 245, 22, 1706, 17, 20098, 5, 3215, 21,
            37, 1110, 3, 13, 1041, 4, 24, 603, 490, 2,
            71477, 20098, 104447, 2, 20961, 1, 2604, 4, 1, 329,
            3, 6224, 831, 16002, 2, 8, 603, 78967, 29546, 23,
            803, 20, 25, 416, 5, 8, 232, 4, 277, 6,
            1855, 4601, 3, 29546, 54, 8, 3609, 5, 57211, 49,
            4, 1, 277, 18, 8, 1755, 15691, 3, 341, 25,
            416, 693, 42573, 71, 17, 401, 94, 31, 17919, 2,
            29546, 7873, 18, 1, 435, 23, 11011, 755, 5, 5167,
            3, 7983, 98, 84, 2, 29546, 3267, 8, 3609, 4,
            1, 4865, 1075, 2, 6087, 71, 6, 346, 8, 5854,
            3, 29546, 824, 1400, 1868, 2, 19, 160, 2, 311,
            8, 5496, 2, 20920, 17, 25, 15097, 3, 24, 24,
            0,
        ]
        # fmt: on

        # Tokens expected after the prompt. Decoded text, as recorded alongside the ids:
        #  In the 1990s, the remains of Russian Tsar Nicholas II and his family
        #  were discovered. The voice of <unk> young son, Tsarevich Alexei
        #  Nikolaevich, narrates the remainder of the story.<eos>
        # fmt: off
        continuation_ids = [
            33, 1, 1857, 2, 1, 1009, 4, 1109, 11739, 4762,
            358, 5, 25, 245, 28, 1110, 3, 13, 1041, 4,
            24, 603, 490, 2, 71477, 20098, 104447, 2, 20961, 1,
            2604, 4, 1, 329, 3, 0,
        ]
        # fmt: on

        # The expected sequence echoes the whole prompt before the continuation.
        expected_output_ids = prompt_ids + continuation_ids

        # Batch of one sequence.
        input_ids = tf.convert_to_tensor([prompt_ids], dtype=tf.int32)

        output_ids = model.generate(input_ids, max_length=200, do_sample=False)
        generated = output_ids[0].numpy().tolist()
        self.assertListEqual(generated, expected_output_ids)
Exemple #4
0
    print("JS List", sorted(js_dict.items(), key=lambda x: x[1]))
    highest_js_word = sorted(js_dict.items(), key=lambda x: x[1],
                             reverse=True)[0][0]

    print("highest JS word", highest_js_word)
    curr_context = curr_context + highest_js_word

    p = get_distribution(model_info, 'GPT2', curr_context, num_return_seqs,
                         current_len + 1)
    q = get_distribution(model_info, 'TransformerXL', curr_context,
                         num_return_seqs, current_len + 1)

    total_js += js(p, q)
    print("CURR CONTEXT", curr_context, "JS", total_js)
    # then autoregressive again on this new current context
    return auto_regressive(model_info, curr_context, num_return_seqs,
                           current_len + 1, max_len, total_js)


# Pretrained (model, tokenizer) pairs, keyed by the model-family label that
# is passed to get_distribution elsewhere in this script.
model_info = {
    "GPT2": (
        TFGPT2LMHeadModel.from_pretrained("gpt2"),
        GPT2Tokenizer.from_pretrained("gpt2"),
    ),
    "TransformerXL": (
        TFTransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103'),
        TransfoXLTokenizer.from_pretrained('transfo-xl-wt103'),
    ),
}

# The starting context is every command-line argument joined by single spaces.
curr_context = ' '.join(sys.argv[1:])

# NOTE(review): judging by the recursive call inside auto_regressive, the
# positional arguments after the context appear to be num_return_seqs,
# current_len, max_len and total_js - confirm against its signature.
auto_regressive(model_info, curr_context, 50, 1, 5, 0)