def _reset(self):
    # type: () -> None
    """Put the tracker back into its initial state.

    Slots are reset and the paused flag, the latest action name, the latest
    user message, the latest bot utterance and the follow-up action get their
    starting values. The event log itself is kept - doesn't delete events
    though!
    """
    self._reset_slots()
    self._paused = False
    self.latest_action_name = None
    # both "latest" utterances start out as empty placeholder events
    self.latest_message, self.latest_bot_utterance = (UserUttered.empty(),
                                                      BotUttered.empty())
    self.follow_up_action = ACTION_LISTEN_NAME
def test_tracker_entity_retrieval(default_domain):
    tracker = DialogueStateTracker("default", default_domain.slots)

    # a fresh tracker has no events and therefore no entity values
    assert len(tracker.events) == 0
    assert list(tracker.get_latest_entity_values("entity_name")) == []

    greet_intent = {"name": "greet", "confidence": 1.0}
    extracted = {
        "start": 1,
        "end": 5,
        "value": "greet",
        "entity": "entity_name",
        "extractor": "manual"
    }
    tracker.update(UserUttered("/greet", greet_intent, [extracted]))

    def latest_values(entity_type):
        return list(tracker.get_latest_entity_values(entity_type))

    assert latest_values("entity_name") == ["greet"]
    # an entity type that was never extracted yields nothing
    assert latest_values("unknown") == []
def _handle_message_with_tracker(self, message, tracker):
    # type: (UserMessage, DialogueStateTracker) -> None
    """Log a user message, and the slots its entities fill, on the tracker.

    The parse result already attached to ``message`` is reused when it is
    truthy; otherwise the message is parsed via ``self._parse_message``.
    """
    parse_data = (message.parse_data if message.parse_data
                  else self._parse_message(message))

    # the tracker is never mutated directly - every change is passed in as
    # an event so it ends up in the log
    user_event = UserUttered(message.text,
                             parse_data["intent"],
                             parse_data["entities"],
                             parse_data)
    tracker.update(user_event)

    # store all entities as slots
    for slot_event in self.domain.slots_for_entities(parse_data["entities"]):
        tracker.update(slot_event)

    logger.debug("Logged UserUtterance - "
                 "tracker now has {} events".format(len(tracker.events)))
def test_reminder_aborted(default_processor):
    channel = CollectingOutputChannel()
    sender = uuid.uuid4().hex
    dispatcher = Dispatcher(sender, channel, default_processor.nlg)
    store = default_processor.tracker_store

    reminder = ReminderScheduled("utter_greet",
                                 datetime.datetime.now(),
                                 kill_on_user_message=True)

    tracker = store.get_or_create_tracker(sender)
    tracker.update(reminder)
    # a user message after scheduling cancels the reminder
    tracker.update(UserUttered("test"))
    store.save(tracker)

    default_processor.handle_reminder(reminder, dispatcher)

    # reload to inspect what the reminder handling persisted
    tracker = store.retrieve(sender)
    # nothing should have been executed
    assert len(tracker.events) == 3
def __init__(self,
             nlu_threshold=0.1,  # type: float
             confirmation_action_name="action_default_confirmation",  # type: Text
             confirmation_rephrase_name="action_default_rephrase",  # type: Text
             affirm_intent_name="affirm",  # type: Text
             deny_intent_name="deny"  # type: Text
             ):
    # type: (...) -> None
    """Store the configuration of the confirmation policy.

    :param nlu_threshold: NLU confidence threshold (presumably the level
        below which a confirmation is requested - TODO confirm against the
        prediction code)
    :param confirmation_action_name: name of the confirmation action
    :param confirmation_rephrase_name: name of the rephrase action
    :param affirm_intent_name: name of the affirm intent
    :param deny_intent_name: name of the deny intent
    """
    super(ConfirmationPolicy, self).__init__()
    self.nlu_threshold = nlu_threshold

    # action names
    self.confirmation_action_name = confirmation_action_name
    self.confirmation_rephrase_name = confirmation_rephrase_name

    # intent names
    self.affirm_intent_name = affirm_intent_name
    self.deny_intent_name = deny_intent_name

    # starts as a UserUttered without text; presumably replaced by a message
    # awaiting confirmation - verify against the rest of the class
    self.cache_message = UserUttered(None)
def receive_nlu_message(self, message, parse_data):
    """Log an already-parsed user message and run the next predicted action.

    Returns the collected bot messages if the message's output channel is a
    ``CollectingOutputChannel``, and ``None`` otherwise. ``None`` is also
    returned when no tracker could be obtained for the sender.
    """
    tracker = self.message_processor._get_tracker(message.sender_id)
    if not tracker:
        return None

    tracker.update(UserUttered(message.text,
                               parse_data["intent"],
                               parse_data["entities"],
                               parse_data,
                               input_channel=message.input_channel))

    # store all entities as slots
    domain = self.agent.domain
    for slot_event in domain.slots_for_entities(parse_data["entities"]):
        tracker.update(slot_event)

    self.predict_and_execute_next_action(message, tracker)
    self.message_processor._save_tracker(tracker)

    channel = message.output_channel
    if not isinstance(channel, CollectingOutputChannel):
        return None
    return channel.messages
def handle_reminder(self, reminder_event, dispatcher):
    # type: (ReminderScheduled, Dispatcher) -> None
    """Handle a reminder that is triggered asynchronously.

    A reminder created with ``kill_on_user_message`` is dropped when the user
    sent a non-empty message after it was scheduled. Otherwise its action is
    run, the follow-up action is predicted and executed if the action allows
    it, and the tracker is saved.
    """

    def has_message_after_reminder(evts):
        """If the user sent a message after the reminder got scheduled -
        it might be better to cancel it."""
        for evt in reversed(evts):
            is_this_reminder = (isinstance(evt, ReminderScheduled) and
                                evt.name == reminder_event.name)
            if is_this_reminder:
                return False
            if isinstance(evt, UserUttered) and evt.text:
                return True
        # reminder not found: tracker has probably been restarted
        return True

    sender_id = dispatcher.sender_id
    tracker = self._get_tracker(sender_id)
    if not tracker:
        logger.warning("Failed to retrieve or create tracker for sender "
                       "'{}'.".format(sender_id))
        return None

    outdated = (reminder_event.kill_on_user_message and
                has_message_after_reminder(tracker.events))
    if outdated:
        logger.debug("Canceled reminder because it is outdated. "
                     "(event: {} id: {})".format(reminder_event.action_name,
                                                 reminder_event.name))
        return None

    # necessary for proper featurization, otherwise the previous
    # unrelated message would influence featurization
    tracker.update(UserUttered.empty())

    reminder_action = self._get_action(reminder_event.action_name)
    if self._run_action(reminder_action, tracker, dispatcher):
        follow_up_msg = UserMessage(None,
                                    dispatcher.output_channel,
                                    dispatcher.sender_id)
        self._predict_and_execute_next_action(follow_up_msg, tracker)

    # save tracker state to continue conversation from this state
    self._save_tracker(tracker)
def handle_reminder(self, reminder_event, dispatcher):
    # type: (ReminderScheduled, Dispatcher) -> None
    """Handle a reminder that is triggered asynchronously.

    A reminder created with ``kill_on_user_message`` is dropped when the user
    sent a message (with text) after it was scheduled. Otherwise its action is
    run, the follow-up action is predicted and executed if the action allows
    it, and the tracker is saved.
    """

    def has_message_after_reminder(evts):
        """If the user sent a message after the reminder got scheduled -
        it might be better to cancel it.

        Only ``UserUttered`` events that carry text count. This handler logs
        ``UserUttered.empty()`` before running a reminder's action. Without
        the text check, that placeholder would make every later
        ``kill_on_user_message`` reminder look outdated.
        """
        for e in reversed(evts):
            if (isinstance(e, ReminderScheduled) and
                    e.name == reminder_event.name):
                return False
            elif isinstance(e, UserUttered) and e.text:
                return True
        return True  # tracker has probably been restarted

    tracker = self._get_tracker(dispatcher.sender_id)

    if not tracker:
        logger.warning("Failed to retrieve or create tracker for sender "
                       "'{}'.".format(dispatcher.sender_id))
        return None

    if (reminder_event.kill_on_user_message and
            has_message_after_reminder(tracker.events)):
        logger.debug("Canceled reminder because it is outdated. "
                     "(event: {} id: {})".format(reminder_event.action_name,
                                                 reminder_event.name))
    else:
        # necessary for proper featurization, otherwise the previous
        # unrelated message would influence featurization
        tracker.update(UserUttered.empty())
        action = self._get_action(reminder_event.action_name)
        should_continue = self._run_action(action, tracker, dispatcher)
        if should_continue:
            user_msg = UserMessage(None,
                                   dispatcher.output_channel,
                                   dispatcher.sender_id)
            self._predict_and_execute_next_action(user_msg, tracker)
        # save tracker state to continue conversation from this state
        self._save_tracker(tracker)
def add_user_messages(self, messages, line_num):
    """Parse user messages of a story step and add them to the current step.

    Each message is parsed with ``self.interpreter`` and converted into a
    ``UserUttered`` event. Two problems are reported as warnings and do not
    abort parsing: messages using the deprecated leading ``_`` intent syntax,
    and intents missing from ``self.domain``.

    :param messages: user message strings belonging to the current step
    :param line_num: line number reported in the unknown-intent warning
    :raises StoryParseError: if no step is being built (no story start yet)
    """
    if not self.current_step_builder:
        raise StoryParseError("User message '{}' at invalid location. "
                              "Expected story start.".format(messages))
    parsed_messages = []
    for m in messages:
        parse_data = self.interpreter.parse(m)
        utterance = UserUttered.from_parse_data(m, parse_data)
        if m.startswith("_"):
            c = utterance.as_story_string()
            # Logger.warn is a deprecated alias - use Logger.warning
            logger.warning("Stating user intents with a leading '_' is "
                           "deprecated. The new format is "
                           "'* {}'. Please update "
                           "your example '{}' to the new format.".format(c, m))
        intent_name = utterance.intent.get("name")
        if intent_name not in self.domain.intents:
            logger.warning("Found unknown intent '{}' on line {}. Please, "
                           "make sure that all intents are listed in your "
                           "domain yaml.".format(intent_name, line_num))
        parsed_messages.append(utterance)
    self.current_step_builder.add_user_messages(parsed_messages)
def test_can_read_test_story(default_domain):
    trackers = extract_trackers_from_file("data/test_stories/stories.md",
                                          default_domain,
                                          featurizer=BinaryFeaturizer())
    assert len(trackers) == 7

    # The story we want is simple_story_with_only_end -> show_it_all. The
    # generated stories come out in an unstable order, so it is picked by
    # its number of events instead of its position.
    five_event_trackers = [t for t in trackers if len(t.events) == 5]
    tracker = five_event_trackers[0]

    expected_events = [
        ActionExecuted("action_listen"),
        UserUttered(
            "simple",
            intent={"name": "simple", "confidence": 1.0},
            parse_data={'text': 'simple',
                        'intent_ranking': [{'confidence': 1.0,
                                            'name': 'simple'}],
                        'intent': {'confidence': 1.0, 'name': 'simple'},
                        'entities': []}),
        ActionExecuted("utter_default"),
        ActionExecuted("utter_greet"),
        ActionExecuted("action_listen"),
    ]
    for position, expected in enumerate(expected_events):
        assert tracker.events[position] == expected
def test_query_form_set_username_directly():
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    store = InMemoryTrackerStore(domain)
    channel = CollectingOutputChannel()
    sender = "test-form"
    dispatcher = Dispatcher(sender, channel, domain)
    tracker = store.get_or_create_tracker(sender)

    # the username slot is filled before the form runs
    username = "******"
    tracker.update(SlotSet('username', username))

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchQuery().run(dispatcher, tracker, domain)
    last_message = dispatcher.latest_bot_messages[-1]

    # the form only asks for the next missing slot
    assert len(events) == 1
    requested = events[0]
    assert isinstance(requested, SlotSet)
    assert requested.key == "requested_slot"
    assert requested.value == "query"
    # the bot reply mentions the prefilled username
    assert username in last_message.text
def test_restaurant_form_skipahead():
    """Given cuisine and number in the first message, the form next requests
    the 'vegetarian' slot."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{"entity": "cuisine", "value": "chinese"},
                {"entity": "number", "value": 8}]
    tracker.update(UserUttered("",
                               intent={"name": "inform"},
                               entities=entities))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)

    # Check the length before indexing so that a short result fails with a
    # clear assertion. The removed debug prints indexed events[0] and
    # events[1] first.
    assert len(events) == 3
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
def test_revert_action_event(default_domain):
    """Reverting an action undoes it, also after recreating the tracker."""
    tracker = DialogueStateTracker("default", default_domain.slots,
                                   default_domain.topics,
                                   default_domain.default_topic)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent, []))
    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 4:
    #   +3 executed actions
    #   +1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_states())) == 4

    tracker.update(ActionReverted())

    # Expecting count of 3:
    #   +3 executed actions
    #   +1 final state
    #   -1 reverted action
    assert tracker.latest_action_name == "my_action"
    assert len(list(tracker.generate_all_prior_states())) == 3

    dialogue = tracker.as_dialogue()

    recovered = DialogueStateTracker("default", default_domain.slots,
                                     default_domain.topics,
                                     default_domain.default_topic)
    recovered.recreate_from_dialogue(dialogue)

    # the checks below must target the recovered tracker - re-checking the
    # original tracker would not exercise the recovery at all
    assert recovered.current_state() == tracker.current_state()
    assert recovered.latest_action_name == "my_action"
    assert len(list(recovered.generate_all_prior_states())) == 3
def test_restart_event(default_domain):
    """A Restarted event resets the tracker, also after recreating it."""
    tracker = DialogueStateTracker("default", default_domain.slots,
                                   default_domain.topics,
                                   default_domain.default_topic)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent, []))
    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    assert len(tracker.events) == 4
    assert tracker.latest_message.text == "/greet"
    assert len(list(tracker.generate_all_prior_states())) == 4

    tracker.update(Restarted())

    # the restart is logged as an event but resets the conversation state
    assert len(tracker.events) == 5
    assert tracker.follow_up_action is not None
    assert tracker.follow_up_action.name() == ACTION_LISTEN_NAME
    assert tracker.latest_message.text is None
    assert len(list(tracker.generate_all_prior_states())) == 1

    dialogue = tracker.as_dialogue()

    recovered = DialogueStateTracker("default", default_domain.slots,
                                     default_domain.topics,
                                     default_domain.default_topic)
    recovered.recreate_from_dialogue(dialogue)

    # every check below targets the recovered tracker - the follow-up action
    # assertions previously re-checked the original tracker by mistake
    assert recovered.current_state() == tracker.current_state()
    assert len(recovered.events) == 5
    assert recovered.follow_up_action is not None
    assert recovered.follow_up_action.name() == ACTION_LISTEN_NAME
    assert recovered.latest_message.text is None
    assert len(list(recovered.generate_all_prior_states())) == 1
def test_policy_priority():
    domain = Domain.load("data/test_domains/default.yml")
    tracker = DialogueStateTracker.from_events("test", [UserUttered("hi")], [])

    low = ConstantPolicy(priority=1, predict_index=0)
    high = ConstantPolicy(priority=2, predict_index=1)

    # Each ensemble is paired with the index of the high-priority policy in
    # it. That policy must win regardless of its position.
    ensembles_and_positions = [
        (SimplePolicyEnsemble([low, high]), 1),
        (SimplePolicyEnsemble([high, low]), 0),
    ]

    expected_probabilities = high.predict_action_probabilities(tracker,
                                                               domain)

    for ensemble, position in ensembles_and_positions:
        result, best_policy = ensemble.probabilities_using_best_policy(
            tracker, domain)
        expected_name = 'policy_{}_{}'.format(position, type(high).__name__)
        assert best_policy == expected_name
        assert result.tolist() == expected_probabilities
def _handle_message_with_tracker(self, message, tracker):
    # type: (UserMessage, DialogueStateTracker) -> None
    """Parse a user message and log it, with its entity slots, on the tracker.

    When ``self.rules`` is set, the parse result is first passed through the
    rules (intent substitution, entity filtering) and validated. If the
    validation returns an error template, ``_utter_error_and_roll_back`` is
    called and the message is not logged.
    """
    parse_data = self._parse_message(message)

    if self.rules is not None:
        self.rules.substitute_intent(parse_data, tracker)
        self.rules.filter_entities(parse_data)
        validation = self.rules.input_validation
        error_template = validation.get_error(parse_data, tracker)
        if error_template is not None:
            self._utter_error_and_roll_back(message, tracker, error_template)
            return

    # don't ever directly mutate the tracker - changes are passed in as
    # events so they end up in the log
    user_event = UserUttered(message.text,
                             parse_data["intent"],
                             parse_data["entities"],
                             parse_data)
    tracker.update(user_event)

    # store all entities as slots
    for slot_event in self.domain.slots_for_entities(parse_data["entities"]):
        tracker.update(slot_event)

    logger.debug("Logged UserUtterance - "
                 "tracker now has {} events".format(len(tracker.events)))
def test_tracker_store_storage_and_retrieval(store):
    tracker = store.get_or_create_tracker("some-id")

    # a newly created tracker only holds the initial action listen event
    assert tracker.sender_id == "some-id"
    assert list(tracker.events) == [ActionExecuted(ACTION_LISTEN_NAME)]

    # log a test message
    greet = {"name": "greet", "confidence": 1.0}
    tracker.update(UserUttered("/greet", greet, []))
    assert tracker.latest_message.intent.get("name") == "greet"
    store.save(tracker)

    # asking for the same id again must return the saved state
    retrieved = store.get_or_create_tracker("some-id")
    assert retrieved.sender_id == "some-id"
    assert len(retrieved.events) == 2
    assert retrieved.latest_message.intent.get("name") == "greet"

    # a different id gets a fresh tracker again
    fresh = store.get_or_create_tracker("some-other-id")
    assert fresh.sender_id == "some-other-id"
    assert len(fresh.events) == 1
import json

from rasa_core import broker, utils
from rasa_core.broker import FileProducer, PikaProducer
from rasa_core.events import Event, Restarted, SlotSet, UserUttered
from rasa_core.utils import EndpointConfig
from tests.conftest import DEFAULT_ENDPOINTS_FILE

# Sample events: a user message, a slot update and a restart. Presumably
# used by the broker tests further down this file - verify.
TEST_EVENTS = [
    UserUttered("/greet", {
        "name": "greet",
        "confidence": 1.0
    }, []),
    SlotSet("name", "rasa"),
    Restarted()
]


def test_pika_broker_from_config():
    """Loading the pika endpoint file yields a PikaProducer with the
    configured host, credentials and queue."""
    cfg = utils.read_endpoint_config(
        'data/test_endpoints/event_brokers/'
        'pika_endpoint.yml',
        "event_broker")
    actual = broker.from_endpoint_config(cfg)

    assert isinstance(actual, PikaProducer)
    assert actual.host == "localhost"
    assert actual.credentials.username == "username"
    assert actual.queue == "queue"


def test_no_broker_in_config():
from __future__ import unicode_literals from datetime import datetime from copy import deepcopy import pytest from rasa_core.events import (Event, UserUttered, TopicSet, SlotSet, Restarted, ActionExecuted, AllSlotsReset, ReminderScheduled, ConversationResumed, ConversationPaused, StoryExported, ActionReverted, BotUttered) @pytest.mark.parametrize("one_event,another_event", [ (UserUttered("/greet", { "name": "greet", "confidence": 1.0 }, []), UserUttered("/goodbye", { "name": "goodbye", "confidence": 1.0 }, [])), (TopicSet("my_topic"), TopicSet("my_other_topic")), (SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")), (Restarted(), None), (AllSlotsReset(), None), (ConversationPaused(), None), (ConversationResumed(), None), (StoryExported(), None), (ActionReverted(), None), (ActionExecuted("my_action"), ActionExecuted("my_other_action")), (BotUttered("my_text", "my_data"), BotUttered("my_other_test", "my_other_data")),
def user_uttered(text: Text, confidence: float) -> UserUttered:
    """Build a ``UserUttered`` event whose intent is named after ``text``.

    The event's own text is always ``'Random'``; ``text`` and ``confidence``
    only populate the intent, which is also wrapped as the parse data.
    """
    intent = {'name': text, 'confidence': confidence}
    return UserUttered(text='Random',
                       intent=intent,
                       parse_data={'intent': intent})
from __future__ import division from __future__ import print_function from __future__ import unicode_literals from copy import deepcopy import pytest from rasa_core.events import UserUttered, TopicSet, SlotSet, Restarted, \ ActionExecuted, AllSlotsReset, \ ReminderScheduled, ConversationResumed, ConversationPaused, StoryExported, \ ActionReverted, BotUttered @pytest.mark.parametrize("one_event,another_event", [ (UserUttered("/greet", "greet", []), UserUttered("/goodbye", "goodbye", [])), (TopicSet("my_topic"), TopicSet("my_other_topic")), (SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")), (Restarted(), None), (AllSlotsReset(), None), (ConversationPaused(), None), (ConversationResumed(), None), (StoryExported(), None), (ActionReverted(), None), (ActionExecuted("my_action"), ActionExecuted("my_other_action")), (BotUttered("my_text", "my_data"), BotUttered("my_other_test", "my_other_data")), (ReminderScheduled("my_action", "now"), ReminderScheduled("my_other_action", "now")), ])
from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals from copy import deepcopy import pytest from rasa_core.events import UserUttered, TopicSet, SlotSet, Restarted, \ ActionExecuted, AllSlotsReset, \ ReminderScheduled @pytest.mark.parametrize("one_event,another_event", [ (UserUttered("_greet", "greet", []), UserUttered("_goodbye", "goodbye", [])), (TopicSet("my_topic"), TopicSet("my_other_topic")), (SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")), (Restarted(), None), (AllSlotsReset(), None), (ActionExecuted("my_action"),
import pytest
from treq.testing import StubTreq

import rasa_core
from rasa_core.agent import Agent
from rasa_core.events import UserUttered, BotUttered, SlotSet, TopicSet
from rasa_core.interpreter import RegexInterpreter
from rasa_core.policies.scoring_policy import ScoringPolicy
from rasa_core.server import RasaCoreServer
from tests.conftest import DEFAULT_STORIES_FILE

# a couple of event instances that we can use for testing
test_events = [
    UserUttered.from_parse_data("/goodbye", {
        "intent": {"confidence": 1.0, "name": "greet"},
        "entities": []}),
    BotUttered("Welcome!", {"test": True}),
    TopicSet("question"),
    SlotSet("cuisine", 34),
    SlotSet("cuisine", "34"),
    SlotSet("location", None),
    SlotSet("location", [34, "34", None]),
]


@pytest.fixture(scope="module")
def app(core_server):
    """Wrap the core server's Klein application in a ``StubTreq`` client.

    The application's IResource interface is used so the tests can mock
    the Rasa Core server.
    """
    klein_resource = core_server.app.resource()
    return StubTreq(klein_resource)