def test_integration_torch_conversation_dialogpt_input_ids(self):
    """Pin the ``input_ids`` ConversationalPipeline builds for DialoGPT-small.

    Three cases are checked: a lone user turn, a user turn with one
    exchange of history, and a batch containing both conversations.
    """
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
    conv_pipeline = ConversationalPipeline(model=model, tokenizer=tokenizer)

    def encoded_ids(conversations):
        # Run the pipeline's tokenization step and return plain nested lists.
        return conv_pipeline._parse_and_tokenize(conversations)["input_ids"].tolist()

    # 50256 follows every turn in the expected sequences below.
    separator = 50256
    hello_turn = [31373, separator]  # "hello"
    history_ids = (
        hello_turn
        + [17250, 612, 0, separator]  # "Hi there!"
        + [4919, 389, 345, 5633, separator]  # "how are you ?"
    )

    single_turn = Conversation("hello")
    self.assertEqual(encoded_ids([single_turn]), [hello_turn])

    with_history = Conversation(
        "how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"]
    )
    self.assertEqual(encoded_ids([with_history]), [history_ids])

    # Batched: the shorter row is filled out to the longer row's length
    # with 50256.
    filler = [separator] * (len(history_ids) - len(hello_turn))
    self.assertEqual(
        encoded_ids([single_turn, with_history]),
        [hello_turn + filler, history_ids],
    )
def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
    """Pin the ``input_ids`` ConversationalPipeline builds for BlenderBot-400M-distill.

    Checks a lone user turn, then a user turn preceded by one user/bot
    exchange, where turns are joined by a double space.
    """
    checkpoint = "facebook/blenderbot-400M-distill"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    conv_pipeline = ConversationalPipeline(model=model, tokenizer=tokenizer)

    def encoded_ids(conversations):
        # Run the pipeline's tokenization step and return plain nested lists.
        return conv_pipeline._parse_and_tokenize(conversations)["input_ids"].tolist()

    hello_ids = [1710, 86]  # "hello"
    double_space = [228, 228]
    eos = [2]

    # Case 1: a single user turn.
    greeting = Conversation("hello")
    self.assertEqual(encoded_ids([greeting]), [hello_ids + eos])

    # Case 2: history ("hello" -> bot reply) followed by a new user turn.
    bot_reply = " Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
    lasagne_chat = Conversation(
        "I like lasagne.",
        past_user_inputs=["hello"],
        generated_responses=[bot_reply],
    )
    reply_ids = [
        946, 304, 398, 6881, 558, 964, 38, 452, 315, 265,
        6252, 452, 322, 968, 6884, 3146, 278, 306, 265, 617,
        87, 388, 75, 341, 286, 521, 21,
    ]
    user_ids = [281, 398, 6881, 558, 964, 21]  # "I like lasagne."

    # This should be compared with the same conversation on ParlAI `safe_interactive` demo.
    expected = hello_ids + double_space + reply_ids + double_space + user_ids + eos
    self.assertEqual(encoded_ids([lasagne_chat]), [expected])