Example #1
0
    def _reset(self):
        # type: () -> None
        """Return the tracker's conversation state to its initial values.

        Slots are reset, the paused flag is cleared, the latest action /
        user message / bot utterance are blanked, and the follow-up action
        is set to ``ACTION_LISTEN_NAME``. The event history itself is left
        untouched.
        """

        self._reset_slots()

        # clear the paused flag
        self._paused = False

        # blank out the most recent turn on both sides of the conversation
        self.latest_action_name = None
        self.latest_message = UserUttered.empty()
        self.latest_bot_utterance = BotUttered.empty()

        # the follow-up action falls back to the listen action
        self.follow_up_action = ACTION_LISTEN_NAME
Example #2
0
def test_tracker_entity_retrieval(default_domain):
    """Entity values of the latest user message can be looked up by name."""
    tracker = DialogueStateTracker("default", default_domain.slots)

    # a freshly constructed tracker has neither events nor entities
    assert len(tracker.events) == 0
    assert list(tracker.get_latest_entity_values("entity_name")) == []

    greet_intent = {"name": "greet", "confidence": 1.0}
    greet_entity = {
        "start": 1,
        "end": 5,
        "value": "greet",
        "entity": "entity_name",
        "extractor": "manual",
    }
    tracker.update(UserUttered("/greet", greet_intent, [greet_entity]))

    # known entity names yield their values, unknown ones yield nothing
    assert list(tracker.get_latest_entity_values("entity_name")) == ["greet"]
    assert list(tracker.get_latest_entity_values("unknown")) == []
Example #3
0
    def _handle_message_with_tracker(self, message, tracker):
        # type: (UserMessage, DialogueStateTracker) -> None
        """Record a user message and its entity-derived slots on ``tracker``.

        The message's own ``parse_data`` is used when it is truthy; otherwise
        the message is parsed with ``self._parse_message``.
        """

        parse_data = (message.parse_data if message.parse_data
                      else self._parse_message(message))

        # the tracker is never mutated directly - every change is an event
        user_event = UserUttered(message.text, parse_data["intent"],
                                 parse_data["entities"], parse_data)
        tracker.update(user_event)

        # apply each event the domain derives from the message's entities
        for slot_event in self.domain.slots_for_entities(
                parse_data["entities"]):
            tracker.update(slot_event)

        logger.debug("Logged UserUtterance - "
                     "tracker now has {} events".format(len(tracker.events)))
def test_reminder_aborted(default_processor):
    """A reminder with ``kill_on_user_message`` is dropped after user input."""
    output_channel = CollectingOutputChannel()
    sender_id = uuid.uuid4().hex
    store = default_processor.tracker_store

    dispatcher = Dispatcher(sender_id, output_channel, default_processor.nlg)
    reminder = ReminderScheduled("utter_greet", datetime.datetime.now(),
                                 kill_on_user_message=True)
    tracker = store.get_or_create_tracker(sender_id)

    tracker.update(reminder)
    # a user message logged after scheduling should cancel the reminder
    tracker.update(UserUttered("test"))

    store.save(tracker)
    default_processor.handle_reminder(reminder, dispatcher)

    # reload what handle_reminder persisted
    reloaded = store.retrieve(sender_id)
    # no additional events: nothing should have been executed
    assert len(reloaded.events) == 3
Example #5
0
    def __init__(self,
                 nlu_threshold=0.1,  # type: float
                 confirmation_action_name="action_default_confirmation",  # type: Text
                 confirmation_rephrase_name="action_default_rephrase",  # type: Text
                 affirm_intent_name="affirm",  # type: Text
                 deny_intent_name="deny"  # type: Text
                 ):
        # type: (...) -> None
        """Store the configuration of the confirmation policy.

        :param nlu_threshold: numeric threshold kept on the policy (how it
            is used is not visible here).
        :param confirmation_action_name: name of the confirmation action.
        :param confirmation_rephrase_name: name of the rephrase action.
        :param affirm_intent_name: name of the affirm intent.
        :param deny_intent_name: name of the deny intent.
        """
        super(ConfirmationPolicy, self).__init__()

        self.nlu_threshold = nlu_threshold

        # action names
        self.confirmation_action_name = confirmation_action_name
        self.confirmation_rephrase_name = confirmation_rephrase_name

        # intent names (presumably the yes / no answers - confirm in callers)
        self.affirm_intent_name = affirm_intent_name
        self.deny_intent_name = deny_intent_name

        # starts out as a user utterance without text
        self.cache_message = UserUttered(None)
Example #6
0
 def receive_nlu_message(self, message, parse_data):
     """Log already-parsed NLU results for ``message`` and act on them.

     Returns the collected bot messages when the message's output channel
     is a ``CollectingOutputChannel``; returns ``None`` otherwise, and also
     when no tracker could be obtained for the sender.
     """
     tracker = self.message_processor._get_tracker(message.sender_id)
     if not tracker:
         return None

     user_event = UserUttered(message.text,
                              parse_data["intent"],
                              parse_data["entities"],
                              parse_data,
                              input_channel=message.input_channel)
     tracker.update(user_event)

     # apply each event the domain derives from the entities
     slot_events = self.agent.domain.slots_for_entities(
         parse_data["entities"])
     for slot_event in slot_events:
         tracker.update(slot_event)

     self.predict_and_execute_next_action(message, tracker)
     self.message_processor._save_tracker(tracker)

     channel = message.output_channel
     if not isinstance(channel, CollectingOutputChannel):
         return None
     return channel.messages
Example #7
0
    def handle_reminder(self, reminder_event, dispatcher):
        # type: (ReminderScheduled, Dispatcher) -> None
        """Handle a reminder that is triggered asynchronously.

        A reminder created with ``kill_on_user_message`` is cancelled when
        the user has said something since it was scheduled. Otherwise an
        empty user utterance is logged, the reminder's action is run, the
        next action is predicted if the reminder action says to continue,
        and the tracker is saved.
        """

        def user_spoke_since_reminder(events):
            """Return True if a user message with text follows the reminder.

            Events are scanned newest-first. If this reminder's scheduling
            event is never reached (probably because the tracker was
            restarted), True is returned as well.
            """
            for event in reversed(events):
                scheduled_this_reminder = (
                    isinstance(event, ReminderScheduled) and
                    event.name == reminder_event.name)
                if scheduled_this_reminder:
                    return False
                # utterances without text do not count as user input
                if isinstance(event, UserUttered) and event.text:
                    return True
            return True

        tracker = self._get_tracker(dispatcher.sender_id)
        if not tracker:
            logger.warning("Failed to retrieve or create tracker for sender "
                           "'{}'.".format(dispatcher.sender_id))
            return None

        if (reminder_event.kill_on_user_message and
                user_spoke_since_reminder(tracker.events)):
            logger.debug("Canceled reminder because it is outdated. "
                         "(event: {} id: {})".format(reminder_event.action_name,
                                                     reminder_event.name))
            return None

        # necessary for proper featurization, otherwise the previous
        # unrelated message would influence featurization
        tracker.update(UserUttered.empty())
        action = self._get_action(reminder_event.action_name)
        if self._run_action(action, tracker, dispatcher):
            user_msg = UserMessage(None, dispatcher.output_channel,
                                   dispatcher.sender_id)
            self._predict_and_execute_next_action(user_msg, tracker)

        # persist so the conversation continues from this state
        self._save_tracker(tracker)
Example #8
0
    def handle_reminder(self, reminder_event, dispatcher):
        # type: (ReminderScheduled, Dispatcher) -> None
        """Handle a reminder that is triggered asynchronously.

        If the reminder was scheduled with ``kill_on_user_message`` and the
        user has sent a message since it was scheduled, the reminder is
        cancelled. Otherwise an empty user utterance is logged, the
        reminder's action is run, the next action is predicted when the
        reminder action says to continue, and the tracker is saved.
        """

        def has_message_after_reminder(evts):
            """If the user sent a message after the reminder got scheduled -
            it might be better to cancel it.

            Only utterances that carry text count. The empty ``UserUttered``
            this method logs when a reminder fires must not cancel other
            reminders that are still pending.
            """

            for e in reversed(evts):
                if (isinstance(e, ReminderScheduled) and
                        e.name == reminder_event.name):
                    return False
                elif isinstance(e, UserUttered) and e.text:
                    return True
            return True  # tracker has probably been restarted

        tracker = self._get_tracker(dispatcher.sender_id)

        if not tracker:
            logger.warning("Failed to retrieve or create tracker for sender "
                           "'{}'.".format(dispatcher.sender_id))
            return None

        if (reminder_event.kill_on_user_message and
                has_message_after_reminder(tracker.events)):
            logger.debug("Canceled reminder because it is outdated. "
                         "(event: {} id: {})".format(reminder_event.action_name,
                                                     reminder_event.name))
        else:
            # necessary for proper featurization, otherwise the previous
            # unrelated message would influence featurization
            tracker.update(UserUttered.empty())
            action = self._get_action(reminder_event.action_name)
            should_continue = self._run_action(action, tracker, dispatcher)
            if should_continue:
                user_msg = UserMessage(None,
                                       dispatcher.output_channel,
                                       dispatcher.sender_id)
                self._predict_and_execute_next_action(user_msg, tracker)
            # save tracker state to continue conversation from this state
            self._save_tracker(tracker)
Example #9
0
 def add_user_messages(self, messages, line_num):
     """Parse the user messages of a story step and add them to the step.

     Each message is parsed with ``self.interpreter`` and turned into a
     ``UserUttered`` event. A warning is logged for messages written in the
     deprecated leading-underscore format, and for intents that are not
     listed in the domain.

     :param messages: raw user message strings from the story file.
     :param line_num: story-file line number, used in warnings.
     :raises StoryParseError: if no story step is currently being built.
     """
     if not self.current_step_builder:
         raise StoryParseError("User message '{}' at invalid location. "
                               "Expected story start.".format(messages))
     parsed_messages = []
     for m in messages:
         parse_data = self.interpreter.parse(m)
         utterance = UserUttered.from_parse_data(m, parse_data)
         if m.startswith("_"):
             c = utterance.as_story_string()
             # Logger.warn is a deprecated alias of Logger.warning
             logger.warning("Stating user intents with a leading '_' is "
                            "deprecated. The new format is "
                            "'* {}'. Please update "
                            "your example '{}' to the new format.".format(c, m))
         intent_name = utterance.intent.get("name")
         if intent_name not in self.domain.intents:
             logger.warning("Found unknown intent '{}' on line {}. Please, "
                            "make sure that all intents are listed in your "
                            "domain yaml.".format(intent_name, line_num))
         parsed_messages.append(utterance)
     self.current_step_builder.add_user_messages(parsed_messages)
Example #10
0
def test_can_read_test_story(default_domain):
    """The test stories file is parsed into the expected trackers."""
    trackers = extract_trackers_from_file("data/test_stories/stories.md",
                                          default_domain,
                                          featurizer=BinaryFeaturizer())
    assert len(trackers) == 7

    # The wanted story is simple_story_with_only_end -> show_it_all. The
    # generated stories come out in a non-stable order, so select it by its
    # number of events instead of by position.
    five_event_trackers = [t for t in trackers if len(t.events) == 5]
    tracker = five_event_trackers[0]
    events = tracker.events

    assert events[0] == ActionExecuted("action_listen")
    assert events[1] == UserUttered(
        "simple",
        intent={"name": "simple", "confidence": 1.0},
        parse_data={'text': 'simple',
                    'intent_ranking': [{'confidence': 1.0,
                                        'name': 'simple'}],
                    'intent': {'confidence': 1.0, 'name': 'simple'},
                    'entities': []})
    assert events[2] == ActionExecuted("utter_default")
    assert events[3] == ActionExecuted("utter_greet")
    assert events[4] == ActionExecuted("action_listen")
Example #11
0
def test_query_form_set_username_directly():
    """With the username pre-filled, the form requests the 'query' slot."""
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    sender_id = "test-form"
    dispatcher = Dispatcher(sender_id, CollectingOutputChannel(), domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # fill the username slot before the form action runs
    username = "******"
    tracker.update(SlotSet('username', username))

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchQuery().run(dispatcher, tracker, domain)
    last_message = dispatcher.latest_bot_messages[-1]

    # exactly one event: the next requested slot is "query"
    assert len(events) == 1
    requested = events[0]
    assert isinstance(requested, SlotSet)
    assert requested.key == "requested_slot"
    assert requested.value == "query"
    # the bot's latest message mentions the pre-filled username
    assert username in last_message.text
Example #12
0
def test_restaurant_form_skipahead():
    """Supplying cuisine and number up front makes the form ask for
    'vegetarian' next.

    The action is expected to return three events, the last of which
    requests the ``vegetarian`` slot.
    """
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance already carries two entities
    entities = [{"entity": "cuisine", "value": "chinese"},
                {"entity": "number", "value": 8}]
    tracker.update(UserUttered("",
                               intent={"name": "inform"},
                               entities=entities))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)

    # Check the length before indexing, so a short result fails with a
    # clear assertion instead of an IndexError. (Leftover debug prints and
    # an unused local were removed.)
    assert len(events) == 3
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
def test_revert_action_event(default_domain):
    """ActionReverted undoes the latest action and survives a round-trip
    through ``as_dialogue`` / ``recreate_from_dialogue``."""
    tracker = DialogueStateTracker("default", default_domain.slots,
                                   default_domain.topics,
                                   default_domain.default_topic)
    # a freshly created tracker has no events
    assert len(tracker.events) == 0

    intent = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent, []))
    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 4:
    #   +3 executed actions
    #   +1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_states())) == 4

    tracker.update(ActionReverted())

    # Expecting count of 3:
    #   +3 executed actions
    #   +1 final state
    #   -1 reverted action
    assert tracker.latest_action_name == "my_action"
    assert len(list(tracker.generate_all_prior_states())) == 3

    dialogue = tracker.as_dialogue()

    recovered = DialogueStateTracker("default", default_domain.slots,
                                     default_domain.topics,
                                     default_domain.default_topic)
    recovered.recreate_from_dialogue(dialogue)

    # the recovered tracker (not the original again) must reflect the revert
    assert recovered.current_state() == tracker.current_state()
    assert recovered.latest_action_name == "my_action"
    assert len(list(recovered.generate_all_prior_states())) == 3
def test_restart_event(default_domain):
    """Restarted resets the conversation state and survives a round-trip
    through ``as_dialogue`` / ``recreate_from_dialogue``."""
    tracker = DialogueStateTracker("default", default_domain.slots,
                                   default_domain.topics,
                                   default_domain.default_topic)
    # a freshly created tracker has no events
    assert len(tracker.events) == 0

    intent = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent, []))
    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    assert len(tracker.events) == 4
    assert tracker.latest_message.text == "/greet"
    assert len(list(tracker.generate_all_prior_states())) == 4

    tracker.update(Restarted())

    # the Restarted event stays in the history but resets the state
    assert len(tracker.events) == 5
    assert tracker.follow_up_action is not None
    assert tracker.follow_up_action.name() == ACTION_LISTEN_NAME
    assert tracker.latest_message.text is None
    assert len(list(tracker.generate_all_prior_states())) == 1

    dialogue = tracker.as_dialogue()

    recovered = DialogueStateTracker("default", default_domain.slots,
                                     default_domain.topics,
                                     default_domain.default_topic)
    recovered.recreate_from_dialogue(dialogue)

    # every check below targets the recovered tracker, not the original
    assert recovered.current_state() == tracker.current_state()
    assert len(recovered.events) == 5
    assert recovered.follow_up_action is not None
    assert recovered.follow_up_action.name() == ACTION_LISTEN_NAME
    assert recovered.latest_message.text is None
    assert len(list(recovered.generate_all_prior_states())) == 1
Example #15
0
def test_policy_priority():
    """The ensemble uses the higher-priority policy regardless of its
    position in the policy list."""
    domain = Domain.load("data/test_domains/default.yml")
    tracker = DialogueStateTracker.from_events("test", [UserUttered("hi")], [])

    low_priority = ConstantPolicy(priority=1, predict_index=0)
    high_priority = ConstantPolicy(priority=2, predict_index=1)

    ensemble_low_first = SimplePolicyEnsemble([low_priority, high_priority])
    ensemble_high_first = SimplePolicyEnsemble([high_priority, low_priority])

    expected_probabilities = high_priority.predict_action_probabilities(
        tracker, domain)
    high_priority_type = type(high_priority).__name__

    # (ensemble, index of the high-priority policy inside that ensemble)
    cases = ((ensemble_low_first, 1), (ensemble_high_first, 0))
    for ensemble, position in cases:
        result, best_policy = ensemble.probabilities_using_best_policy(
            tracker, domain)
        assert best_policy == 'policy_{}_{}'.format(position,
                                                    high_priority_type)
        assert result.tolist() == expected_probabilities
Example #16
0
    def _handle_message_with_tracker(self, message, tracker):
        # type: (UserMessage, DialogueStateTracker) -> None
        """Parse ``message``, apply the optional rules, then log it.

        When ``self.rules`` is set, ``substitute_intent`` and
        ``filter_entities`` are applied to the parse data. If input
        validation then returns an error template, the message is handed to
        ``_utter_error_and_roll_back`` and nothing is logged for it.
        """
        parse_data = self._parse_message(message)

        rules = self.rules
        if rules is not None:
            rules.substitute_intent(parse_data, tracker)
            rules.filter_entities(parse_data)

            error_template = rules.input_validation.get_error(parse_data,
                                                              tracker)
            if error_template is not None:
                self._utter_error_and_roll_back(message, tracker,
                                                error_template)
                return

        # the tracker is never mutated directly - every change is an event
        user_event = UserUttered(message.text, parse_data["intent"],
                                 parse_data["entities"], parse_data)
        tracker.update(user_event)

        # apply each event the domain derives from the message's entities
        for slot_event in self.domain.slots_for_entities(
                parse_data["entities"]):
            tracker.update(slot_event)

        logger.debug("Logged UserUtterance - "
                     "tracker now has {} events".format(len(tracker.events)))
def test_tracker_store_storage_and_retrieval(store):
    """A tracker saved to ``store`` can be retrieved again by sender id."""
    tracker = store.get_or_create_tracker("some-id")
    # a newly created tracker belongs to the requested sender ...
    assert tracker.sender_id == "some-id"

    # ... and holds only the initial action_listen event (it is not empty)
    assert list(tracker.events) == [ActionExecuted(ACTION_LISTEN_NAME)]

    # let's log a test message
    intent = {"name": "greet", "confidence": 1.0}
    tracker.update(UserUttered("/greet", intent, []))
    assert tracker.latest_message.intent.get("name") == "greet"
    store.save(tracker)

    # retrieving the same tracker should result in the same tracker
    retrieved_tracker = store.get_or_create_tracker("some-id")
    assert retrieved_tracker.sender_id == "some-id"
    assert len(retrieved_tracker.events) == 2
    assert retrieved_tracker.latest_message.intent.get("name") == "greet"

    # getting another tracker should result in a fresh tracker that again
    # holds only the initial action_listen event
    other_tracker = store.get_or_create_tracker("some-other-id")
    assert other_tracker.sender_id == "some-other-id"
    assert len(other_tracker.events) == 1
Example #18
0
import json

from rasa_core import broker, utils
from rasa_core.broker import FileProducer, PikaProducer
from rasa_core.events import Event, Restarted, SlotSet, UserUttered
from rasa_core.utils import EndpointConfig
from tests.conftest import DEFAULT_ENDPOINTS_FILE

# a small fixed set of events of different types
TEST_EVENTS = [
    UserUttered("/greet", {"name": "greet", "confidence": 1.0}, []),
    SlotSet("name", "rasa"),
    Restarted(),
]


def test_pika_broker_from_config():
    """A pika endpoint config is turned into a configured PikaProducer."""
    endpoint_config = utils.read_endpoint_config(
        'data/test_endpoints/event_brokers/pika_endpoint.yml',
        "event_broker")
    producer = broker.from_endpoint_config(endpoint_config)

    assert isinstance(producer, PikaProducer)
    assert producer.host == "localhost"
    assert producer.credentials.username == "username"
    assert producer.queue == "queue"


def test_no_broker_in_config():
Example #19
0
from __future__ import unicode_literals

from datetime import datetime
from copy import deepcopy

import pytest

from rasa_core.events import (Event, UserUttered, TopicSet, SlotSet, Restarted,
                              ActionExecuted, AllSlotsReset, ReminderScheduled,
                              ConversationResumed, ConversationPaused,
                              StoryExported, ActionReverted, BotUttered)


@pytest.mark.parametrize("one_event,another_event", [
    (UserUttered("/greet", {
        "name": "greet",
        "confidence": 1.0
    }, []), UserUttered("/goodbye", {
        "name": "goodbye",
        "confidence": 1.0
    }, [])),
    (TopicSet("my_topic"), TopicSet("my_other_topic")),
    (SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")),
    (Restarted(), None),
    (AllSlotsReset(), None),
    (ConversationPaused(), None),
    (ConversationResumed(), None),
    (StoryExported(), None),
    (ActionReverted(), None),
    (ActionExecuted("my_action"), ActionExecuted("my_other_action")),
    (BotUttered("my_text",
                "my_data"), BotUttered("my_other_test", "my_other_data")),
Example #20
0
def user_uttered(text: Text, confidence: float) -> UserUttered:
    """Build a ``UserUttered`` whose intent name is ``text``.

    The utterance text itself is always ``'Random'``. The same intent dict
    is passed both as ``intent`` and inside ``parse_data``.
    """
    intent = {'name': text, 'confidence': confidence}
    return UserUttered(text='Random', intent=intent,
                       parse_data={'intent': intent})
Example #21
0
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from copy import deepcopy

import pytest

from rasa_core.events import UserUttered, TopicSet, SlotSet, Restarted, \
    ActionExecuted, AllSlotsReset, \
    ReminderScheduled, ConversationResumed, ConversationPaused, StoryExported, \
    ActionReverted, BotUttered


@pytest.mark.parametrize("one_event,another_event", [
    (UserUttered("/greet", "greet", []), UserUttered("/goodbye", "goodbye",
                                                     [])),
    (TopicSet("my_topic"), TopicSet("my_other_topic")),
    (SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")),
    (Restarted(), None),
    (AllSlotsReset(), None),
    (ConversationPaused(), None),
    (ConversationResumed(), None),
    (StoryExported(), None),
    (ActionReverted(), None),
    (ActionExecuted("my_action"), ActionExecuted("my_other_action")),
    (BotUttered("my_text",
                "my_data"), BotUttered("my_other_test", "my_other_data")),
    (ReminderScheduled("my_action",
                       "now"), ReminderScheduled("my_other_action", "now")),
])
Example #22
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from copy import deepcopy

import pytest

from rasa_core.events import UserUttered, TopicSet, SlotSet, Restarted, \
    ActionExecuted, AllSlotsReset, \
    ReminderScheduled


@pytest.mark.parametrize("one_event,another_event", [
    (UserUttered("_greet", "greet", []),
     UserUttered("_goodbye", "goodbye", [])),

    (TopicSet("my_topic"),
     TopicSet("my_other_topic")),

    (SlotSet("my_slot", "value"),
     SlotSet("my__other_slot", "value")),

    (Restarted(),
     None),

    (AllSlotsReset(),
     None),

    (ActionExecuted("my_action"),
import pytest
from treq.testing import StubTreq

import rasa_core
from rasa_core.agent import Agent
from rasa_core.events import UserUttered, BotUttered, SlotSet, TopicSet
from rasa_core.interpreter import RegexInterpreter
from rasa_core.policies.scoring_policy import ScoringPolicy
from rasa_core.server import RasaCoreServer
from tests.conftest import DEFAULT_STORIES_FILE

# a handful of event instances of several types, for use in tests
test_events = [
    # NOTE(review): the text is "/goodbye" while the intent is "greet";
    # confirm whether that mismatch is intentional
    UserUttered.from_parse_data(
        "/goodbye",
        {"intent": {"confidence": 1.0, "name": "greet"}, "entities": []}),
    BotUttered("Welcome!", {"test": True}),
    TopicSet("question"),
    # slot values of different types: int, str, None and a mixed list
    SlotSet("cuisine", 34),
    SlotSet("cuisine", "34"),
    SlotSet("location", None),
    SlotSet("location", [34, "34", None]),
]


@pytest.fixture(scope="module")
def app(core_server):
    """Provide a ``StubTreq`` client bound to the core server's Klein app.

    The app is exposed through its IResource interface so the Rasa Core
    server can be mocked for the tests.
    """
    resource = core_server.app.resource()
    return StubTreq(resource)