def test_lm_generate_transfo_xl_wt103(self):
    """Check non-sampled generation from the pretrained transfo-xl-wt103 model.

    Feeds a fixed prompt to ``model.generate`` with ``do_sample=False`` and
    ``max_length=200`` and asserts that the first returned sequence equals
    the prompt followed by a known continuation, token for token.
    """
    lm_model = TFTransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")

    # Prompt token ids. Decoded text, per the original annotations:
    #   In 1991 , the remains of Russian Tsar Nicholas II and his family
    #   ( except for Alexei and Maria ) are discovered .
    #   The voice of Nicholas's young son , Tsarevich Alexei Nikolaevich ,
    #   narrates the remainder of the story . 1883 Western Siberia ,
    #   a young Grigori Rasputin is asked by his father and a group of men
    #   to perform magic . Rasputin has a vision and denounces one of the
    #   men as a horse thief . Although his father initially slaps him for
    #   making such an accusation , Rasputin watches as the man is chased
    #   outside and beaten . Twenty years later , Rasputin sees a vision of
    #   the Virgin Mary , prompting him to become a priest . Rasputin
    #   quickly becomes famous , with people , even a bishop , begging for
    #   his blessing . <eod> </s> <eos>
    prompt_ids = [
        33, 1297, 2, 1, 1009, 4, 1109, 11739, 4762, 358, 5, 25, 245, 22, 1706, 17,
        20098, 5, 3215, 21, 37, 1110, 3, 13, 1041, 4, 24, 603, 490, 2, 71477, 20098,
        104447, 2, 20961, 1, 2604, 4, 1, 329, 3, 6224, 831, 16002, 2, 8, 603, 78967,
        29546, 23, 803, 20, 25, 416, 5, 8, 232, 4, 277, 6, 1855, 4601, 3, 29546, 54,
        8, 3609, 5, 57211, 49, 4, 1, 277, 18, 8, 1755, 15691, 3, 341, 25, 416, 693,
        42573, 71, 17, 401, 94, 31, 17919, 2, 29546, 7873, 18, 1, 435, 23, 11011,
        755, 5, 5167, 3, 7983, 98, 84, 2, 29546, 3267, 8, 3609, 4, 1, 4865, 1075, 2,
        6087, 71, 6, 346, 8, 5854, 3, 29546, 824, 1400, 1868, 2, 19, 160, 2, 311, 8,
        5496, 2, 20920, 17, 25, 15097, 3, 24, 24, 0,
    ]

    # Tokens expected after the prompt. Decoded text, per the original
    # annotations (the prompt itself decodes there as
    # "... begging for his blessing. <unk> <unk> <eos>"):
    #   In the 1990s, the remains of Russian Tsar Nicholas II and his family
    #   were discovered. The voice of <unk> young son, Tsarevich Alexei
    #   Nikolaevich, narrates the remainder of the story.<eos>
    continuation_ids = [
        33, 1, 1857, 2, 1, 1009, 4, 1109, 11739, 4762, 358, 5, 25, 245, 28, 1110, 3,
        13, 1041, 4, 24, 603, 490, 2, 71477, 20098, 104447, 2, 20961, 1, 2604, 4, 1,
        329, 3, 0,
    ]

    # The expected output is the echoed prompt followed by the continuation.
    expected_ids = prompt_ids + continuation_ids

    # Batch of one sequence, int32 as in the original tensor construction.
    prompt_batch = tf.convert_to_tensor([prompt_ids], dtype=tf.int32)
    generated = lm_model.generate(prompt_batch, max_length=200, do_sample=False)

    first_sequence = generated[0].numpy().tolist()
    self.assertListEqual(first_sequence, expected_ids)
# Tail of one recursive `auto_regressive` step, followed by the script's entry
# statements. The function's `def` line and the code that builds `js_dict` are
# outside this view.
#
# Step logic visible here:
#   1. Print js_dict's (key, value) items sorted ascending by value.
#   2. Select the key with the largest value as `highest_js_word`.
#   3. Append it to `curr_context` by plain string concatenation (no separator
#      is inserted by this code).
#   4. Call get_distribution for 'GPT2' and 'TransformerXL' on the extended
#      context at current_len + 1, and add js(p, q) to the running `total_js`.
#   5. Recurse with the new context and current_len + 1, returning the result.
# `js` is presumably a Jensen-Shannon divergence -- TODO confirm against its
# definition.
#
# Script logic visible here:
#   - `model_info` maps "GPT2" / "TransformerXL" to (model, tokenizer) pairs
#     loaded via from_pretrained.
#   - The command-line arguments after the script name are joined with single
#     spaces to form the initial context.
#   - auto_regressive(model_info, curr_context, 50, 1, 5, 0): going by the
#     positional order of the recursive call, these are num_return_seqs=50,
#     current_len=1, max_len=5, total_js=0 -- confirm against the `def` line.
#
# NOTE(review): the statements below appear fused onto one physical line; as
# shown, the inline `#` comment would swallow everything after it. Confirm the
# original line breaks before relying on this text as runnable code.
print("JS List", sorted(js_dict.items(), key=lambda x: x[1])) highest_js_word = sorted(js_dict.items(), key=lambda x: x[1], reverse=True)[0][0] print("highest JS word", highest_js_word) curr_context = curr_context + highest_js_word p = get_distribution(model_info, 'GPT2', curr_context, num_return_seqs, current_len + 1) q = get_distribution(model_info, 'TransformerXL', curr_context, num_return_seqs, current_len + 1) total_js += js(p, q) print("CURR CONTEXT", curr_context, "JS", total_js) # then autoregressive again on this new current context return auto_regressive(model_info, curr_context, num_return_seqs, current_len + 1, max_len, total_js) model_info = { "GPT2": (TFGPT2LMHeadModel.from_pretrained("gpt2"), GPT2Tokenizer.from_pretrained("gpt2")), "TransformerXL": (TFTransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103'), TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')) } curr_context = sys.argv[1:] curr_context = ' '.join(curr_context) auto_regressive(model_info, curr_context, 50, 1, 5, 0)