def test_integration_torch_conversation_truncated_history(self):
    # Two-turn exchange on one conversation with a tight generation budget
    # (min_length_for_response=24, max_length=36).
    # NOTE(review): judging by the test name, this budget is meant to make the
    # pipeline truncate the dialogue history - confirm against the pipeline.
    # No model is passed to pipeline(), so the expected replies below are tied
    # to whichever default checkpoint it resolves.
    # (Kept as comments rather than a docstring so unittest's reported test
    # description stays unchanged.)
    opening_line = "Going to the movies tonight - any suggestions?"
    follow_up = "Is it an action movie?"
    generation_kwargs = {"do_sample": False, "max_length": 36}

    # When
    chatbot = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
    dialogue = Conversation(opening_line)
    # Then: nothing has been processed yet
    self.assertEqual(len(dialogue.past_user_inputs), 0)

    # When: first turn
    reply = chatbot(dialogue, **generation_kwargs)
    # Then: the returned object compares equal to the conversation passed in,
    # and it now holds one processed input and one generated response
    self.assertEqual(reply, dialogue)
    self.assertEqual(len(reply.past_user_inputs), 1)
    self.assertEqual(len(reply.generated_responses), 1)
    self.assertEqual(reply.past_user_inputs[0], opening_line)
    self.assertEqual(reply.generated_responses[0], "The Big Lebowski")

    # When: second turn on the same conversation
    dialogue.add_user_input(follow_up)
    reply = chatbot(dialogue, **generation_kwargs)
    # Then: both turns are recorded, the newest one last
    self.assertEqual(reply, dialogue)
    self.assertEqual(len(reply.past_user_inputs), 2)
    self.assertEqual(len(reply.generated_responses), 2)
    self.assertEqual(reply.past_user_inputs[1], follow_up)
    self.assertEqual(reply.generated_responses[1], "It's a comedy.")
def test_integration_torch_conversation(self):
    # Runs two conversations through the pipeline as a batch, then continues
    # the second one alone with a follow-up question.
    # No model is passed to pipeline(), so the expected replies below are tied
    # to whichever default checkpoint it resolves.
    # (Kept as comments rather than a docstring so unittest's reported test
    # description stays unchanged.)
    movie_question = "Going to the movies tonight - any suggestions?"
    book_question = "What's the last book you have read?"
    book_follow_up = "Why do you recommend it?"
    generation_kwargs = {"do_sample": False, "max_length": 1000}

    # When
    chatbot = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM)
    movie_chat = Conversation(movie_question)
    book_chat = Conversation(book_question)
    # Then: neither conversation has processed inputs yet
    self.assertEqual(len(movie_chat.past_user_inputs), 0)
    self.assertEqual(len(book_chat.past_user_inputs), 0)

    # When: both conversations are submitted together
    batch_reply = chatbot([movie_chat, book_chat], **generation_kwargs)
    # Then: the result compares equal to the submitted list, in the same order
    self.assertEqual(batch_reply, [movie_chat, book_chat])
    # Safe to unpack: the equality check above pins the result to two items.
    movie_reply, book_reply = batch_reply
    self.assertEqual(len(movie_reply.past_user_inputs), 1)
    self.assertEqual(len(book_reply.past_user_inputs), 1)
    self.assertEqual(len(movie_reply.generated_responses), 1)
    self.assertEqual(len(book_reply.generated_responses), 1)
    self.assertEqual(movie_reply.past_user_inputs[0], movie_question)
    self.assertEqual(movie_reply.generated_responses[0], "The Big Lebowski")
    self.assertEqual(book_reply.past_user_inputs[0], book_question)
    self.assertEqual(book_reply.generated_responses[0], "The Last Question")

    # When: the second conversation continues on its own
    book_chat.add_user_input(book_follow_up)
    single_reply = chatbot(book_chat, **generation_kwargs)
    # Then: a single conversation in yields a single conversation out,
    # now holding two turns
    self.assertEqual(single_reply, book_chat)
    self.assertEqual(len(single_reply.past_user_inputs), 2)
    self.assertEqual(len(single_reply.generated_responses), 2)
    self.assertEqual(single_reply.past_user_inputs[1], book_follow_up)
    self.assertEqual(single_reply.generated_responses[1], "It's a good book.")
def _test_conversation_pipeline(self, nlp):
    """Run shared input/output sanity checks against a conversational pipeline.

    Args:
        nlp: The pipeline callable under test. It is invoked with a single
            ``Conversation``, with a list of conversations, and with
            several inputs that are expected to be rejected.
    """
    single_conversation = Conversation("Hi there!")
    conversation_batch = [Conversation("Hi there!"), Conversation("How are you?")]
    invalid_inputs = ["Hi there!", Conversation()]
    self.assertIsNotNone(nlp)

    # A lone conversation must come back as a Conversation.
    self.assertIsInstance(nlp(single_conversation), Conversation)

    # A list of conversations must come back as a list of Conversation objects.
    batch_result = nlp(conversation_batch)
    self.assertIsInstance(batch_result, list)
    self.assertIsInstance(batch_result[0], Conversation)

    # Inactive conversations passed to the pipeline raise a ValueError
    # (the batch was already processed just above).
    with self.assertRaises(ValueError):
        nlp(conversation_batch)

    # Every invalid input is rejected on its own, and so is the list that
    # holds all of them (checked last, as a single call).
    for rejected in [*invalid_inputs, invalid_inputs]:
        with self.assertRaises(Exception):
            nlp(rejected)
cache_dir=cache_dir, ) model = AutoModelForCausalLM.from_pretrained( model_name_or_path, from_tf=False, config=config, cache_dir=cache_dir, ) config.min_length = 2 config.max_length = 1000 print(f"min_length: {config.min_length}") print(f"max_length: {config.max_length}") conversation = Conversation() conversation_manager = ConversationalPipeline(model=model, tokenizer=tokenizer) conversation.add_user_input("Is it an action movie?") conversation_manager([conversation]) print(f"Response: {conversation.generated_responses[-1]}") conversation.add_user_input("Is it a love movie?") conversation_manager([conversation]) print(f"Response: {conversation.generated_responses[-1]}") conversation.add_user_input("What is it about?") conversation_manager([conversation]) print(f"Response: {conversation.generated_responses[-1]}") conversation.add_user_input("Would you recommend it?")