def test_travel_form():
    """Fill the travel form slot by slot with ``ActionSearchTravel``.

    An entity-less ``inform`` must make the action request ``GPE_origin``.
    Answering with a ``GPE`` entity must fill ``GPE_origin`` and move on to
    request ``GPE_destination``.
    """
    domain = TemplateDomain.load("data/test_domains/travel_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-travel"
    dispatcher = Dispatcher(sender_id, out, domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance: no entities, so the first slot is requested
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "GPE_origin"
    tracker.update(events[0])

    # second user utterance: the GPE entity answers the requested slot
    entities = [{"entity": "GPE", "value": "Berlin"}]
    tracker.update(UserUttered("", intent={"name": "inform"},
                               entities=entities))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(events) == 2
    # both the filled slot and the next request must be SlotSet events
    assert all(isinstance(e, SlotSet) for e in events)
    assert events[0].key == "GPE_origin"
    assert events[0].value == "Berlin"
    assert events[1].key == "requested_slot"
    assert events[1].value == "GPE_destination"
def test_restaurant_form_unhappy_1():
    """An answer that ignores the requested slot re-requests the same slot.

    After ``ActionSearchRestaurants`` asks for ``cuisine``, an ``inform``
    without any entity must leave the form asking for ``cuisine`` again.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
    tracker.update(events[0])

    # second user utterance does not provide what's asked
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    # same slot requested again
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
def test_people_form():
    """The raw text of the answer is stored as ``person_name``.

    ``ActionSearchPeople`` first requests ``person_name``. The next
    ``inform`` (which carries no entities) sets that slot to the
    utterance text.
    """
    domain = TemplateDomain.load("data/test_domains/people_form.yml")
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-people"
    dispatcher = Dispatcher(conversation_id, channel, domain)
    tracker = store.get_or_create_tracker(conversation_id)

    # turn 1: the form asks for the person's name
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(events) == 1
    request = events[0]
    assert isinstance(request, SlotSet)
    assert (request.key, request.value) == ("requested_slot", "person_name")
    tracker.update(request)

    # turn 2: the utterance text itself becomes the slot value
    name = "Rasa Due"
    tracker.update(UserUttered(name, intent={"name": "inform"}))
    events = ActionSearchPeople().run(dispatcher, tracker, domain)
    assert len(events) == 1
    filled = events[0]
    assert isinstance(filled, SlotSet)
    assert (filled.key, filled.value) == ("person_name", name)
def test_reminder_scheduled(default_processor):
    """A due reminder runs its action, utters it and returns to listening."""
    channel = CollectingOutputChannel()
    conversation_id = uuid.uuid4().hex
    dispatcher = Dispatcher(conversation_id, channel, default_processor.nlg)
    reminder = ReminderScheduled("utter_greet", datetime.datetime.now())

    store = default_processor.tracker_store
    tracker = store.get_or_create_tracker(conversation_id)
    for event in (UserUttered("test"),
                  ActionExecuted("action_reminder_reminder"),
                  reminder):
        tracker.update(event)
    store.save(tracker)

    default_processor.handle_reminder(reminder, dispatcher)

    # reload from the store to see what handle_reminder persisted
    stored_events = store.retrieve(conversation_id).events
    expected_tail = [
        UserUttered(None),
        ActionExecuted("utter_greet"),
        BotUttered("hey there None!", {
            'elements': None,
            'buttons': None,
            'attachment': None
        }),
        ActionExecuted("action_listen"),
    ]
    for offset, expected in zip(range(-4, 0), expected_tail):
        assert stored_events[offset] == expected
def test_restaurant_form():
    """Answering the ``cuisine`` request fills it and moves on to ``people``."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-restaurant"
    dispatcher = Dispatcher(conversation_id, channel, domain)
    tracker = store.get_or_create_tracker(conversation_id)

    # turn 1: an entity-less inform makes the form ask for the cuisine
    tracker.update(UserUttered("", intent={"name": "inform"}))
    first_events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(first_events) == 1
    request = first_events[0]
    assert isinstance(request, SlotSet)
    assert (request.key, request.value) == ("requested_slot", "cuisine")
    tracker.update(request)

    # turn 2: the cuisine entity answers the pending request
    cuisine_entities = [{"entity": "cuisine", "value": "chinese"}]
    tracker.update(UserUttered("", intent={"name": "inform"},
                               entities=cuisine_entities))
    second_events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(second_events) == 2
    filled, next_request = second_events[0], second_events[1]
    assert isinstance(filled, SlotSet)
    assert isinstance(next_request, SlotSet)
    assert (filled.key, filled.value) == ("cuisine", "chinese")
    assert (next_request.key, next_request.value) == ("requested_slot",
                                                      "people")
def _predict_and_execute_next_action(self, message, tracker):
    """Predict and run actions for ``tracker`` until the policy stops.

    Actions are predicted and executed one at a time. The loop ends when:

    * an action reports that no further action should follow,
    * ``_should_handle_message`` returns False, or
    * ``max_number_of_predictions`` actions have run.

    Hitting the limit while the last action still wanted a follow-up trips
    the circuit breaker. A warning is logged and ``on_circuit_break`` is
    called if it is set.

    :param message: incoming user message; supplies ``sender_id`` and the
        ``output_channel`` responses are sent to.
    :param tracker: conversation tracker the actions operate on.
    """
    # this will actually send the response to the user
    dispatcher = Dispatcher(message.sender_id,
                            message.output_channel,
                            self.domain)

    # keep taking actions decided by the policy until it chooses to 'listen'
    should_predict_another_action = True
    num_predicted_actions = 0

    self._log_slots(tracker)

    # action loop. predicts actions until we hit action listen
    while (should_predict_another_action and
           self._should_handle_message(tracker) and
           num_predicted_actions < self.max_number_of_predictions):
        # this actually just calls the policy's method by the same name
        action = self._get_next_action(tracker)

        should_predict_another_action = self._run_action(
            action, tracker, dispatcher)
        num_predicted_actions += 1

    if (num_predicted_actions == self.max_number_of_predictions and
            should_predict_another_action):
        # circuit breaker was tripped
        # (Logger.warn is a deprecated alias of Logger.warning)
        logger.warning("Circuit breaker tripped. Stopped predicting "
                       "more actions for sender '%s'", tracker.sender_id)
        if self.on_circuit_break:
            # call a registered callback
            self.on_circuit_break(tracker, dispatcher)

    logger.debug("Current topic: %s", tracker.topic.name)
def test_restaurant_form_skipahead():
    """Entities given up front let the form skip ahead.

    The very first utterance already carries ``cuisine`` and ``number``
    entities. The third event returned must therefore request the
    ``vegetarian`` slot.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{
        "entity": "cuisine",
        "value": "chinese"
    }, {
        "entity": "number",
        "value": 8
    }]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)

    # length is checked first so a short result fails with an assertion
    # instead of an IndexError
    assert len(events) == 3
    assert isinstance(events[2], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
def _run_next_action(self, action_name, message):
    # type: (Text, UserMessage) -> Dict[Text, Any]
    """Run the next action communicating with the remote core server.

    Fetches the tracker for ``message.sender_id`` from the remote core
    server and runs ``action_name`` locally against it. It then sends back,
    via ``core_client.continue_core``, the bot messages the action
    dispatched followed by the events it returned.

    If the action raises, the traceback is logged and execution continues
    with no action events. The bot messages already dispatched are still
    sent.

    :param action_name: name of the action to look up in ``self.domain``.
    :param message: incoming message; supplies ``sender_id`` and
        ``output_channel``.
    :return: whatever ``core_client.continue_core`` returns.
    """
    tracker = self.core_client.tracker(message.sender_id, self.domain)
    dispatcher = Dispatcher(message.sender_id,
                            message.output_channel,
                            self.domain)

    action = self.domain.action_for_name(action_name)
    # events and return values are used to update
    # the tracker state after an action has been taken
    try:
        action_events = action.run(dispatcher, tracker, self.domain)
    except Exception:
        # deliberate best-effort: a failing custom action must not take the
        # bot down, so log the traceback and drop the action's events
        logger.exception("Encountered an exception while running action "
                         "'{}'. Bot will continue, but the action's "
                         "events are lost. Make sure to fix the "
                         "exception in your custom code."
                         "".format(action.name()))
        action_events = []

    # this is similar to what is done in the processor, but instead of
    # logging the events on the tracker we need to return them to the
    # remote core instance
    events = [BotUttered(text=m.text, data=m.data)
              for m in dispatcher.latest_bot_messages]
    events.extend(action_events)

    return self.core_client.continue_core(action_name,
                                          events,
                                          message.sender_id)
def __init__(self,
             interpreter,  # type: NaturalLanguageInterpreter
             policy_ensemble,  # type: PolicyEnsemble
             domain,  # type: Domain
             tracker_store,  # type: TrackerStore
             generator,  # type: NaturalLanguageGenerator
             max_number_of_predictions=10,  # type: int
             message_preprocessor=None,  # type: Optional[LambdaType]
             on_circuit_break=None,  # type: Optional[LambdaType]
             create_dispatcher=None,  # type: Optional[LambdaType]
             rules_file=None  # type: Optional[str]
             ):
    """Message processor with optional rules and a pluggable dispatcher factory.

    All arguments up to ``on_circuit_break`` are forwarded positionally to
    the parent ``__init__``.

    :param create_dispatcher: callable ``(sender_id, output_channel, nlg)``
        that builds a dispatcher. When None, a ``Dispatcher`` is
        constructed from those three arguments.
    :param rules_file: path handed to ``Rules``. When None, ``self.rules``
        is None.
    """
    # rules are set up before the parent initializer runs
    self.rules = None if rules_file is None else Rules(rules_file)

    super(SuperMessageProcessor, self).__init__(
        interpreter, policy_ensemble, domain, tracker_store, generator,
        max_number_of_predictions, message_preprocessor, on_circuit_break)

    self.create_dispatcher = (
        create_dispatcher if create_dispatcher is not None
        else (lambda sender_id, output_channel, nlg:
              Dispatcher(sender_id, output_channel, nlg)))
def test_query_form_set_username_in_form():
    """``ActionSearchQuery`` asks for a username, stores it, then asks for the query."""
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-form"
    dispatcher = Dispatcher(conversation_id, channel, domain)
    tracker = store.get_or_create_tracker(conversation_id)

    def run_turn(text):
        """Log one ``inform`` utterance, run the form, return (events, reply)."""
        tracker.update(UserUttered(text, intent={"name": "inform"}))
        produced = ActionSearchQuery().run(dispatcher, tracker, domain)
        return produced, dispatcher.latest_bot_messages[-1]

    # turn 1: nothing is known yet, so the username is requested
    events, reply = run_turn("")
    assert len(events) == 1
    request = events[0]
    assert isinstance(request, SlotSet)
    assert (request.key, request.value) == ("requested_slot", "username")
    assert reply.text == 'what is your name?'
    tracker.update(request)

    # turn 2: the raw utterance is taken as the username
    username = '******'
    events, reply = run_turn(username)
    assert len(events) == 2
    filled, next_request = events[0], events[1]
    assert isinstance(filled, SlotSet)
    assert (filled.key, filled.value) == ("username", username)
    assert (next_request.key, next_request.value) == ("requested_slot",
                                                      "query")
    assert username in reply.text
def _utter_error_and_roll_back(self, latest_bot_message, tracker, template):
    """Run ``ActionInvalidUtterance(template)`` for the latest bot message's sender.

    Responses go to ``latest_bot_message.output_channel``. The return value
    of ``_run_action`` is discarded.

    NOTE(review): the name suggests ``ActionInvalidUtterance`` both utters
    the error template and rolls the conversation back; confirm against
    that action's implementation.
    """
    responder = Dispatcher(latest_bot_message.sender_id,
                           latest_bot_message.output_channel,
                           self.domain)
    self._run_action(ActionInvalidUtterance(template), tracker, responder)
def test_dispatcher_utter_buttons_from_domain_templ(capsys):
    """``utter_ask_price`` is sent as its text followed by numbered buttons.

    ``capsys`` is requested but not used by the assertions.
    """
    domain = TemplateDomain.load("examples/restaurant_domain.yml")
    bot = CollectingOutputChannel()
    Dispatcher("my-sender", bot, domain).utter_template("utter_ask_price")

    expected_texts = ("in which price range?",
                      "1: cheap (cheap)",
                      "2: expensive (expensive)")
    for index, text in enumerate(expected_texts):
        assert bot.messages[index][1] == text
def execute_action(self, sender_id: Text, action: Text,
                   output_channel: OutputChannel, policy: Text,
                   confidence: float) -> DialogueStateTracker:
    """Execute the named ``action`` for ``sender_id``.

    A fresh processor is created for the call. Responses go to
    ``output_channel`` through a ``Dispatcher`` built on ``self.nlg``.
    ``policy`` and ``confidence`` are passed through unchanged.

    :return: whatever ``processor.execute_action`` returns (annotated as a
        ``DialogueStateTracker``).
    """
    processor = self.create_processor()
    responder = Dispatcher(sender_id, output_channel, self.nlg)
    return processor.execute_action(sender_id, action, responder,
                                    policy, confidence)
def test_action():
    """Smoke test: ``QuoraSearch`` runs against a fresh tracker without raising.

    NOTE(review): nothing is asserted about the action's output or returned
    events; add an assertion once the expected response is known.
    """
    domain = Domain.load('domain.yml')
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), nlg)
    # random id; uuid4 (unlike uuid1) does not embed the host MAC and time
    uid = str(uuid.uuid4())
    tracker = DialogueStateTracker(uid, domain.slots)
    action = QuoraSearch()
    action.run(dispatcher, tracker, domain)
def test_dispatcher_utter_buttons_from_domain_templ(capsys):
    """Moodbot's ``utter_greet`` is sent as text plus numbered button lines.

    ``capsys`` is requested but not used by the assertions.
    """
    domain = TemplateDomain.load("examples/moodbot/domain.yml")
    bot = CollectingOutputChannel()
    Dispatcher("my-sender", bot, domain).utter_template("utter_greet")

    expected_texts = ("Hey! How are you?",
                      "1: great (great)",
                      "2: super sad (super sad)")
    for index, text in enumerate(expected_texts):
        assert bot.messages[index][1] == text
def test_action():
    """``ActionJoke`` sends a message whose text mentions "norris"."""
    domain = Domain.load('domain.yml')
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), nlg)
    tracker = DialogueStateTracker(str(uuid.uuid1()), domain.slots)

    ActionJoke().run(dispatcher, tracker, domain)

    latest_text = dispatcher.output_channel.latest_output()['text']
    assert 'norris' in latest_text.lower()
def modified_predict_and_execute_next_action(self, message, tracker):
    """Predict and run actions, returning a summary of what happened.

    The predict/execute loop is the usual one, but a response dict is
    collected and returned:

    * ``"next_action"``: name of the last predicted action that was not
      ``action_listen``. The key is absent if every prediction was
      ``action_listen``.
    * ``"tracker"``: ``tracker.current_state()`` after the loop.

    When the circuit breaker trips (``max_number_of_predictions`` reached
    while the last action still wanted a follow-up):

    * a warning is logged and ``on_circuit_break`` is called if set;
    * if at least two intents are ranked in the reported ``latest_message``,
      the names of the top two are swapped and ``latest_message["intent"]``
      is set to the new top entry;
    * ``ActionRestart`` is run on the tracker.

    :param message: incoming message; supplies ``sender_id`` and
        ``output_channel``.
    :param tracker: conversation tracker the actions operate on.
    :return: the response dict described above.
    """
    response = {}
    # this will actually send the response to the user
    dispatcher = Dispatcher(message.sender_id,
                            message.output_channel,
                            self.domain)

    # keep taking actions decided by the policy until it chooses to 'listen'
    should_predict_another_action = True
    num_predicted_actions = 0

    self._log_slots(tracker)

    # action loop. predicts actions until we hit action listen
    while (should_predict_another_action and
           self._should_handle_message(tracker) and
           num_predicted_actions < self.max_number_of_predictions):
        # this actually just calls the policy's method by the same name
        action = self._get_next_action(tracker)
        if action.name() != "action_listen":
            response["next_action"] = action.name()

        should_predict_another_action = self._run_action(
            action, tracker, dispatcher)
        num_predicted_actions += 1

    if (num_predicted_actions == self.max_number_of_predictions and
            should_predict_another_action):
        # circuit breaker was tripped
        logger.warning("Circuit breaker tripped. Stopped predicting "
                       "more actions for sender '%s'", tracker.sender_id)
        if self.on_circuit_break:
            # call a registered callback
            self.on_circuit_break(tracker, dispatcher)

        response["tracker"] = tracker.current_state()
        # we used contexts and next action has changed so we need to change
        # intent accordingly: swap the names of the two best-ranked intents
        # (skipped when fewer than two are ranked, which used to raise)
        latest_message = response["tracker"]["latest_message"]
        ranking = latest_message.get("intent_ranking") or []
        if len(ranking) >= 2:
            ranking[0]["name"], ranking[1]["name"] = (
                ranking[1]["name"], ranking[0]["name"])
            latest_message["intent"] = ranking[0]

        # added to restart the bot when it becomes unstable when a followup
        # action is triggered due to contexts
        self._run_action(ActionRestart(), tracker, dispatcher)
        return response

    response["tracker"] = tracker.current_state()
    return response
def test_dispatcher_utter_buttons_from_domain_templ():
    """Buttons of ``utter_greet`` travel as structured data in a single message."""
    domain = TemplateDomain.load("examples/moodbot/domain.yml")
    bot = CollectingOutputChannel()
    Dispatcher("my-sender", bot, domain).utter_template("utter_greet")

    assert len(bot.messages) == 1
    greeting = bot.messages[0]
    assert greeting['text'] == "Hey! How are you?"
    expected_buttons = [
        {'payload': 'great', 'title': 'great'},
        {'payload': 'super sad', 'title': 'super sad'},
    ]
    assert greeting['data'] == expected_buttons
def test_dispatcher_template_invalid_vars():
    """A template with an unfillable ``{variable}`` is sent with the placeholder intact."""
    templates = {
        "my_made_up_template": [{
            "text": "a template referencing an invalid {variable}."}]}
    domain = TemplateDomain([], [], [], templates, [], [], None)
    bot = CollectingOutputChannel()
    dispatcher = Dispatcher("my-sender", bot, domain)

    dispatcher.utter_template("my_made_up_template")

    sent_text = dispatcher.output_channel.latest_output()['text']
    assert sent_text.startswith(
        "a template referencing an invalid {variable}.")
def execute_action(
        self,
        sender_id,  # type: Text
        action,  # type: Text
        output_channel  # type: OutputChannel
):
    # type: (...) -> DialogueStateTracker
    """Execute the named ``action`` for ``sender_id``.

    A fresh processor is created for the call. Responses go to
    ``output_channel`` through a ``Dispatcher`` built on ``self.nlg``.

    :return: whatever ``processor.execute_action`` returns (annotated as a
        ``DialogueStateTracker``).
    """
    processor = self.create_processor()
    return processor.execute_action(
        sender_id,
        action,
        Dispatcher(sender_id, output_channel, self.nlg))
def test_dispatcher_template_invalid_vars():
    """An unfillable ``{variable}`` is left in the sent text as-is."""
    templates = {
        "my_made_up_template": [{
            "text": "a template referencing an invalid {variable}."}]}
    bot = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(templates)
    dispatcher = Dispatcher("my-sender", bot, generator)
    empty_tracker = DialogueStateTracker("my-sender", slots=[])

    dispatcher.utter_template("my_made_up_template", empty_tracker)

    sent_text = dispatcher.output_channel.latest_output()['text']
    assert sent_text.startswith(
        "a template referencing an invalid {variable}.")
def setUp(self):
    """Load the trained agent and build the objects the action tests use."""
    self.not_undestood = ActionNotUnderstood()
    # correctly spelled alias; ``not_undestood`` is kept for existing tests
    self.not_understood = self.not_undestood
    # set Interpreter (NLU) to Rasa NLU
    self.interpreter = 'rasa-nlu/models/rasa-nlu/default/socialcompanionnlu'
    # load the trained agent model
    self.agent = Agent.load('./models/dialogue', self.interpreter)
    # NOTE: the agent is deliberately not attached to a ConsoleInputChannel
    # here. ``handle_channel(ConsoleInputChannel())`` starts an interactive
    # stdin loop, which blocks (or fails on captured stdin) inside a test
    # fixture.
    # TODO mock dispatcher, tracker and domain
    # NOTE(review): other call sites construct these as
    # Dispatcher(sender_id, output_channel, domain_or_nlg) and
    # DialogueStateTracker(sender_id, slots); the calls below omit those
    # arguments. Confirm they match the installed rasa_core version.
    self.dispatcher = Dispatcher(output_channel=ConsoleOutputChannel())
    self.tracker = DialogueStateTracker()
    self.domain = Domain()
def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """``utter_greet`` rendered through the NLG carries its buttons as data."""
    domain = Domain.load("examples/moodbot/domain.yml")
    bot = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", bot, generator)

    dispatcher.utter_template("utter_greet", default_tracker)

    assert len(bot.messages) == 1
    greeting = bot.messages[0]
    assert greeting['text'] == "Hey! How are you?"
    expected_buttons = [
        {'payload': 'great', 'title': 'great'},
        {'payload': 'super sad', 'title': 'super sad'},
    ]
    assert greeting['data'] == expected_buttons
def _set_and_execute_next_action(self, action_name, message, tracker):
    """Run ``action_name`` directly, then ``ActionListen``.

    No policy is consulted. The action object is taken from the second
    element of ``self.domain.action_map[action_name]``, executed, and
    followed by ``ActionListen``. Return values of ``_run_action`` are
    ignored.

    :param action_name: key into ``self.domain.action_map``.
    :param message: incoming message; supplies ``sender_id`` and
        ``output_channel``.
    :param tracker: DialogueStateTracker the actions operate on.
    :return: None
    """
    # this will actually send the response to the user
    responder = Dispatcher(message.sender_id,
                           message.output_channel,
                           self.domain)
    chosen_action = self.domain.action_map[action_name][1]

    self._log_slots(tracker)

    self._run_action(chosen_action, tracker, responder)
    self._run_action(ActionListen(), tracker, responder)
def test_restaurant_form_unhappy_2():
    """An off-topic answer to the ``vegetarian`` request re-requests it.

    ``cuisine`` and ``number`` entities in the first utterance fill the
    ``cuisine`` and ``people`` slots. The form then asks for
    ``vegetarian``. A ``random`` intent must leave it asking for
    ``vegetarian`` again.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{
        "entity": "cuisine",
        "value": "chinese"
    }, {
        "entity": "number",
        "value": 8
    }]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))

    # store all entities as slots
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    for e in events:
        tracker.update(e)

    cuisine = tracker.get_slot("cuisine")
    people = tracker.get_slot("people")
    assert cuisine == "chinese"
    assert people == 8

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 3
    assert isinstance(events[0], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
    tracker.update(events[2])

    # second user utterance does not provide what's asked
    tracker.update(UserUttered("", intent={"name": "random"}))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    # length first, so an empty result fails as an assertion, not IndexError
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "vegetarian"
def test_reminder_aborted(default_processor):
    """A ``kill_on_user_message`` reminder is not executed once the user spoke."""
    channel = CollectingOutputChannel()
    conversation_id = uuid.uuid4().hex
    dispatcher = Dispatcher(conversation_id, channel, default_processor.nlg)
    reminder = ReminderScheduled("utter_greet", datetime.datetime.now(),
                                 kill_on_user_message=True)

    store = default_processor.tracker_store
    tracker = store.get_or_create_tracker(conversation_id)
    tracker.update(reminder)
    tracker.update(UserUttered("test"))  # cancels the reminder
    store.save(tracker)

    default_processor.handle_reminder(reminder, dispatcher)

    # reload from the store: nothing should have been executed
    reloaded = store.retrieve(conversation_id)
    assert len(reloaded.events) == 3
def predict_and_execute_next_action(self, message, tracker):
    """Predict actions with the wrapped processor's policies and run them.

    Each step takes the arg-max action index from
    ``_get_next_action_probabilities`` and resolves that action name via
    ``ask_for_action`` against ``action_endpoint``. The action is then run
    through ``self.run_action`` with the predicting policy and its
    probability.

    The loop stops when:

    * an action reports that nothing should follow,
    * ``should_handle_message`` is False, or
    * ``max_number_of_predictions`` is reached.

    In the last case, with a follow-up still pending, ``on_circuit_break``
    is invoked if set.

    :raises Exception: if the arg-max index is outside the domain's actions.
    """
    processor = self.message_processor
    dispatcher = Dispatcher(message.sender_id, message.output_channel,
                            processor.nlg)

    # keep taking actions decided by the policy until it chooses to 'listen'
    wants_more = True
    steps_taken = 0

    self.log_slots(tracker)

    while (wants_more
           and self.should_handle_message(tracker)
           and steps_taken < processor.max_number_of_predictions):
        probabilities, policy = processor._get_next_action_probabilities(
            tracker)
        best_index = int(np.argmax(probabilities))
        domain = processor.domain
        if not 0 <= best_index < domain.num_actions:
            raise Exception("Can not access action at index {}. "
                            "Domain has {} actions.".format(
                                best_index, domain.num_actions))
        action = self.ask_for_action(domain.action_names[best_index],
                                     processor.action_endpoint)
        confidence = probabilities[best_index]
        wants_more = self.run_action(action, tracker, dispatcher, policy,
                                     confidence)
        steps_taken += 1

    circuit_tripped = (steps_taken == processor.max_number_of_predictions
                       and wants_more)
    if circuit_tripped and processor.on_circuit_break:
        # call a registered callback
        processor.on_circuit_break(tracker, dispatcher)
def test_query_form_set_username_directly():
    """A pre-filled ``username`` slot is skipped; the form asks for the query."""
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    conversation_id = "test-form"
    dispatcher = Dispatcher(conversation_id, channel, generator)
    tracker = store.get_or_create_tracker(conversation_id)

    # the username is known before the form starts
    username = "******"
    tracker.update(SlotSet('username', username))

    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchQuery().run(dispatcher, tracker, domain)
    reply = dispatcher.latest_bot_messages[-1]

    assert len(events) == 1
    request = events[0]
    assert isinstance(request, SlotSet)
    assert (request.key, request.value) == ("requested_slot", "query")
    assert username in reply.text
def _predict_and_execute_next_action(self, message, tracker):
    """Predict and run actions for ``tracker`` until the policy stops.

    Actions are predicted and executed one at a time. The loop ends when:

    * an action reports that no further action should follow,
    * ``_should_handle_message`` returns False, or
    * ``max_number_of_predictions`` actions have run.

    Hitting the limit while the last action still wanted a follow-up trips
    the circuit breaker: a warning is logged, ``on_circuit_break`` is called
    if set, and ``ActionRestart`` is run on the tracker.

    :param message: incoming user message; supplies ``sender_id`` and the
        ``output_channel`` responses are sent to.
    :param tracker: conversation tracker the actions operate on.
    """
    # this will actually send the response to the user
    dispatcher = Dispatcher(message.sender_id,
                            message.output_channel,
                            self.domain)

    # keep taking actions decided by the policy until it chooses to 'listen'
    should_predict_another_action = True
    num_predicted_actions = 0

    self._log_slots(tracker)

    # action loop. predicts actions until we hit action listen
    while (should_predict_another_action and
           self._should_handle_message(tracker) and
           num_predicted_actions < self.max_number_of_predictions):
        # this actually just calls the policy's method by the same name
        action = self._get_next_action(tracker)

        should_predict_another_action = self._run_action(
            action, tracker, dispatcher)
        num_predicted_actions += 1

    if (num_predicted_actions == self.max_number_of_predictions and
            should_predict_another_action):
        # circuit breaker was tripped
        # (Logger.warn is a deprecated alias of Logger.warning)
        logger.warning("Circuit breaker tripped. Stopped predicting "
                       "more actions for sender '%s'", tracker.sender_id)
        if self.on_circuit_break:
            # call a registered callback
            self.on_circuit_break(tracker, dispatcher)
        # added to restart the bot when it becomes unstable when a followup
        # action is triggered due to contexts
        self._run_action(ActionRestart(), tracker, dispatcher)
def default_dispatcher_collecting(default_nlg):
    """Return a ``Dispatcher`` for ``"my-sender"`` backed by a ``CollectingOutputChannel``.

    :param default_nlg: generator used to render templates.
    """
    collector = CollectingOutputChannel()
    dispatcher = Dispatcher("my-sender", collector, default_nlg)
    return dispatcher