Ejemplo n.º 1
0
def test_dispatcher_template_invalid_vars():
    """A template whose ``{variable}`` cannot be filled is still uttered.

    The tracker has no slots, so the placeholder has no value; the collected
    output is expected to start with the raw, unformatted template text.
    """
    raw_text = "a template referencing an invalid {variable}."
    templates = {"my_made_up_template": [{"text": raw_text}]}

    output_channel = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(templates)
    dispatcher = Dispatcher("my-sender", output_channel, generator)
    tracker = DialogueStateTracker("my-sender", slots=[])

    dispatcher.utter_template("my_made_up_template", tracker)

    latest = dispatcher.output_channel.latest_output()
    assert latest['text'].startswith(raw_text)
Ejemplo n.º 2
0
def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """``utter_greet`` from the moodbot domain is sent with its button data.

    The buttons are expected under the message's ``'data'`` key.
    """
    domain = TemplateDomain.load("examples/moodbot/domain.yml")
    output_channel = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", output_channel, generator)

    dispatcher.utter_template("utter_greet", default_tracker)

    # each button uses its label as its payload
    expected_buttons = [{'payload': label, 'title': label}
                        for label in ('great', 'super sad')]
    assert len(output_channel.messages) == 1
    sent = output_channel.messages[0]
    assert sent['text'] == "Hey! How are you?"
    assert sent['data'] == expected_buttons
Ejemplo n.º 3
0
def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """``utter_greet`` from the moodbot domain is sent with its buttons.

    The buttons are expected under the message's ``'buttons'`` key.
    """
    domain = Domain.load("examples/moodbot/domain.yml")
    output_channel = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", output_channel, generator)

    dispatcher.utter_template("utter_greet", default_tracker)

    # each button uses its label as its payload
    expected_buttons = [{'payload': label, 'title': label}
                        for label in ('great', 'super sad')]
    assert len(output_channel.messages) == 1
    sent = output_channel.messages[0]
    assert sent['text'] == "Hey! How are you?"
    assert sent['buttons'] == expected_buttons
Ejemplo n.º 4
0
    def _run_next_action(self, action_name, message):
        # type: (Text, UserMessage) -> Dict[Text, Any]
        """Run ``action_name`` locally and report the outcome to remote core.

        The sender's tracker is fetched through ``self.core_client``, and the
        action is executed with a dispatcher bound to the message's output
        channel. The bot messages it produced (as ``BotUttered`` events),
        followed by the action's own events, are then passed to
        ``self.core_client.continue_core``. If the action raises, the error is
        logged and only the bot messages are reported.

        Returns whatever ``self.core_client.continue_core`` returns.
        """
        sender_id = message.sender_id
        tracker = self.core_client.tracker(sender_id, self.domain)
        dispatcher = Dispatcher(sender_id, message.output_channel,
                                self.domain)
        action = self.domain.action_for_name(action_name)

        try:
            # these events are used to update the tracker state afterwards
            action_events = action.run(dispatcher, tracker, self.domain)
        except Exception:
            # a broken custom action must not stop the bot; its events are
            # dropped and the traceback is logged
            logger.exception("Encountered an exception while running action "
                             "'{}'. Bot will continue, but the actions "
                             "events are lost. Make sure to fix the "
                             "exception in your custom code."
                             "".format(action.name()))
            action_events = []

        # like the processor, bot messages become BotUttered events -- but
        # they are handed back to the remote core instead of being logged on
        # a local tracker
        events = [BotUttered(text=bot_message.text, data=bot_message.data)
                  for bot_message in dispatcher.latest_bot_messages]
        events += action_events
        return self.core_client.continue_core(action_name, events, sender_id)
Ejemplo n.º 5
0
    def __init__(
            self,
            interpreter,  # type: NaturalLanguageInterpreter
            policy_ensemble,  # type: PolicyEnsemble
            domain,  # type: Domain
            tracker_store,  # type: TrackerStore
            generator,  # type: NaturalLanguageGenerator
            action_endpoint=None,  # type: Optional[EndpointConfig]
            max_number_of_predictions=10,  # type: int
            message_preprocessor=None,  # type: Optional[LambdaType]
            on_circuit_break=None,  # type: Optional[LambdaType]
            create_dispatcher=None,  # type: Optional[LambdaType]
            rules=None  # type: Optional[Rules]
    ):
        """Message processor that additionally carries a ``rules`` object.

        ``rules`` is stored as-is before the parent initializer runs. Every
        other argument except ``create_dispatcher`` is forwarded positionally
        to the parent processor. ``create_dispatcher`` defaults to a factory
        that builds a plain ``Dispatcher(sender_id, output_channel, nlg)``.
        """
        self.rules = rules
        super(SuperMessageProcessor, self).__init__(
            interpreter,
            policy_ensemble,
            domain,
            tracker_store,
            generator,
            action_endpoint,
            max_number_of_predictions,
            message_preprocessor,
            on_circuit_break,
        )

        if create_dispatcher is None:
            def create_dispatcher(sender_id, output_channel, nlg):
                return Dispatcher(sender_id, output_channel, nlg)
        self.create_dispatcher = create_dispatcher
Ejemplo n.º 6
0
def test_travel_form():
    """ActionSearchTravel asks for an origin, then stores it and asks for a destination."""
    domain = TemplateDomain.load("data/test_domains/travel_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-travel"
    dispatcher = Dispatcher(sender_id, out, domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance carries no entities: the form requests GPE_origin
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "GPE_origin"
    tracker.update(events[0])

    # second user utterance supplies the GPE entity for the requested slot
    entities = [{"entity": "GPE", "value": "Berlin"}]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))
    events = ActionSearchTravel().run(dispatcher, tracker, domain)
    assert len(events) == 2
    assert isinstance(events[0], SlotSet)
    assert isinstance(events[1], SlotSet)
    assert events[0].key == "GPE_origin"
    assert events[0].value == "Berlin"
    assert events[1].key == "requested_slot"
    assert events[1].value == "GPE_destination"
Ejemplo n.º 7
0
    def __init__(self,
                 interpreter,  # type: NaturalLanguageInterpreter
                 policy_ensemble,  # type: PolicyEnsemble
                 domain,  # type: Domain
                 tracker_store,  # type: TrackerStore
                 max_number_of_predictions=10,  # type: int
                 message_preprocessor=None,  # type: Optional[LambdaType]
                 on_circuit_break=None,  # type: Optional[LambdaType]
                 create_dispatcher=None,  # type: Optional[LambdaType]
                 rules_file=None  # type: Optional[str]
                 ):
        """Message processor that can load ``Rules`` from a file.

        When ``rules_file`` is given it is loaded via ``Rules(rules_file)``;
        otherwise ``self.rules`` is ``None``. Every other argument except
        ``create_dispatcher`` is forwarded positionally to the parent
        processor. ``create_dispatcher`` defaults to a factory that builds a
        plain ``Dispatcher(sender_id, output_channel, dom)``.
        """
        if rules_file is None:
            self.rules = None
        else:
            self.rules = Rules(rules_file)

        super(SuperMessageProcessor, self).__init__(
            interpreter, policy_ensemble, domain, tracker_store,
            max_number_of_predictions, message_preprocessor, on_circuit_break)

        if create_dispatcher is None:
            def create_dispatcher(sender_id, output_channel, dom):
                return Dispatcher(sender_id, output_channel, dom)
        self.create_dispatcher = create_dispatcher
Ejemplo n.º 8
0
def test_people_form():
    """ActionSearchPeople requests ``person_name``, then fills it from the message text."""
    domain = TemplateDomain.load("data/test_domains/people_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    sender_id = "test-people"
    dispatcher = Dispatcher(sender_id, CollectingOutputChannel(), domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    def run_form():
        # a fresh action instance per turn, exactly like a real dialogue
        return ActionSearchPeople().run(dispatcher, tracker, domain)

    # first user utterance provides nothing, so a name is requested
    tracker.update(UserUttered("", intent={"name": "inform"}))
    prompt_events = run_form()
    assert len(prompt_events) == 1
    prompt = prompt_events[0]
    assert isinstance(prompt, SlotSet)
    assert prompt.key == "requested_slot"
    assert prompt.value == "person_name"
    tracker.update(prompt)

    # second user utterance: the plain message text is the requested name
    name = "Rasa Due"
    tracker.update(UserUttered(name, intent={"name": "inform"}))

    filled_events = run_form()
    assert len(filled_events) == 1
    filled = filled_events[0]
    assert isinstance(filled, SlotSet)
    assert filled.key == "person_name"
    assert filled.value == name
Ejemplo n.º 9
0
    def _predict_and_execute_next_action(self, message, tracker):
        """Predict and execute actions for ``tracker`` until the policy stops.

        Loops while the last ``_run_action`` call asked for another
        prediction, ``_should_handle_message(tracker)`` is truthy and fewer
        than ``self.max_number_of_predictions`` actions have run. If the cap
        is reached while another action was still requested, the circuit
        breaker trips: a warning is logged and ``self.on_circuit_break`` is
        called with ``(tracker, dispatcher)`` when it is set.
        """
        # responses produced by the executed actions are sent through this
        dispatcher = Dispatcher(message.sender_id,
                                message.output_channel,
                                self.nlg)
        wants_another = True
        executed = 0

        self._log_slots(tracker)

        while (wants_another
               and self._should_handle_message(tracker)
               and executed < self.max_number_of_predictions):
            # calls the policy's method of the same name
            action, policy, confidence = self.predict_next_action(tracker)
            wants_another = self._run_action(
                action, tracker, dispatcher, policy, confidence)
            executed += 1

        breaker_tripped = (executed == self.max_number_of_predictions
                           and wants_another)
        if not breaker_tripped:
            return

        logger.warning(
                "Circuit breaker tripped. Stopped predicting "
                "more actions for sender '{}'".format(tracker.sender_id))
        if self.on_circuit_break:
            self.on_circuit_break(tracker, dispatcher)
Ejemplo n.º 10
0
def test_restaurant_form():
    """ActionSearchRestaurants asks for cuisine, then stores it and asks for people."""
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, CollectingOutputChannel(), domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    def run_form():
        return ActionSearchRestaurants().run(dispatcher, tracker, domain)

    # first user utterance carries no entities
    tracker.update(UserUttered("", intent={"name": "inform"}))
    first_events = run_form()
    assert len(first_events) == 1
    request = first_events[0]
    assert isinstance(request, SlotSet)
    assert request.key == "requested_slot"
    assert request.value == "cuisine"
    tracker.update(request)

    # second user utterance answers with a cuisine entity
    cuisine_entity = {"entity": "cuisine", "value": "chinese"}
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=[cuisine_entity]))

    filled, next_request = second_events = run_form()[:2] if False else (None, None)
Ejemplo n.º 11
0
def test_restaurant_form_skipahead():
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{
        "entity": "cuisine",
        "value": "chinese"
    }, {
        "entity": "number",
        "value": 8
    }]
    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    s = events[0].as_story_string()
    print(events[0].as_story_string())
    print(events[1].as_story_string())
    assert len(events) == 3
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
Ejemplo n.º 12
0
def test_restaurant_form_unhappy_1():
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
    tracker.update(events[0])

    # second user utterance does not provide what's asked
    tracker.update(UserUttered("", intent={"name": "inform"}))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    print([(e.key, e.value) for e in events])
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)

    # same slot requested again
    assert events[0].key == "requested_slot"
    assert events[0].value == "cuisine"
Ejemplo n.º 13
0
    def _utter_error_and_roll_back(self, latest_bot_message, tracker, template):
        dispatcher = Dispatcher(latest_bot_message.sender_id,
                                latest_bot_message.output_channel,
                                self.domain)

        action = ActionInvalidUtterance(template)

        self._run_action(action, tracker, dispatcher)
Ejemplo n.º 14
0
def test_action():
    domain = Domain.load('domain.yml')
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), nlg)
    uid = str(uuid.uuid1())
    tracker = DialogueStateTracker(uid, domain.slots)
    # print ("dispatcher,uid,tracker ===", dispatcher, uid, tracker)
    action = QuoraSearch()
    action.run(dispatcher, tracker, domain)
Ejemplo n.º 15
0
    def execute_action(self, sender_id: Text, action: Text,
                       output_channel: OutputChannel, policy: Text,
                       confidence: float) -> DialogueStateTracker:
        """Handle a single message."""

        processor = self.create_processor()
        dispatcher = Dispatcher(sender_id, output_channel, self.nlg)
        return processor.execute_action(sender_id, action, dispatcher, policy,
                                        confidence)
Ejemplo n.º 16
0
    def log_bot_utterances_on_tracker(tracker: DialogueStateTracker,
                                      dispatcher: Dispatcher) -> None:

        if dispatcher.latest_bot_messages:
            for m in dispatcher.latest_bot_messages:
                bot_utterance = BotUttered(text=m.text, data=m.data)
                logger.debug("Bot utterance '{}'".format(bot_utterance))
                tracker.update(bot_utterance)

            dispatcher.latest_bot_messages = []
Ejemplo n.º 17
0
def test_action():
    domain = Domain.load('domain.yml')
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), nlg)
    uid = str(uuid.uuid1())
    tracker = DialogueStateTracker(uid, domain.slots)

    action = ActionJoke()
    action.run(dispatcher, tracker, domain)

    assert 'norris' in dispatcher.output_channel.latest_output()['text'].lower()
Ejemplo n.º 18
0
    def modified_predict_and_execute_next_action(self, message, tracker):
        # this will actually send the response to the user
        response = {}
        dispatcher = Dispatcher(message.sender_id, message.output_channel,
                                self.domain)
        # keep taking actions decided by the policy until it chooses to 'listen'
        should_predict_another_action = True
        num_predicted_actions = 0

        self._log_slots(tracker)

        # action loop. predicts actions until we hit action listen
        while (should_predict_another_action
               and self._should_handle_message(tracker)
               and num_predicted_actions < self.max_number_of_predictions):
            # this actually just calls the policy's method by the same name
            action = self._get_next_action(tracker)
            if action.name() != "action_listen":
                response["next_action"] = action.name()
            should_predict_another_action = self._run_action(
                action, tracker, dispatcher)
            num_predicted_actions += 1

        if (num_predicted_actions == self.max_number_of_predictions
                and should_predict_another_action):
            # circuit breaker was tripped
            logger.warn("Circuit breaker tripped. Stopped predicting "
                        "more actions for sender '{}'".format(
                            tracker.sender_id))
            if self.on_circuit_break:
                # call a registered callback
                self.on_circuit_break(tracker, dispatcher)
            response["tracker"] = tracker.current_state()

            # we used contexts and next action has changed so we need to change
            # intent accordingly
            tmp = response["tracker"]["latest_message"]["intent_ranking"][0][
                "name"]
            response["tracker"]["latest_message"]["intent_ranking"][0][
                "name"] = response["tracker"]["latest_message"][
                    "intent_ranking"][1]["name"]
            response["tracker"]["latest_message"]["intent_ranking"][1][
                "name"] = tmp
            response["tracker"]["latest_message"]["intent"] = response[
                "tracker"]["latest_message"]["intent_ranking"][0]

            # added to restart the bot when it becomes unstable when a followup action is triggered
            # due to contexts
            should_predict_another_action = self._run_action(
                ActionRestart(), tracker, dispatcher)
            return response

        response["tracker"] = tracker.current_state()
        return response
Ejemplo n.º 19
0
    def execute_action(
            self,
            sender_id,  # type: Text
            action,  # type: Text
            output_channel  # type: OutputChannel
    ):
        # type: (...) -> DialogueStateTracker
        """Handle a single message."""

        processor = self.create_processor()
        dispatcher = Dispatcher(sender_id, output_channel, self.nlg)
        return processor.execute_action(sender_id, action, dispatcher)
Ejemplo n.º 20
0
    def setUp(self):
        self.not_undestood = ActionNotUnderstood()
        # set Interpreter (NLU) to Rasa NLU
        self.interpreter = 'rasa-nlu/models/rasa-nlu/default/socialcompanionnlu'

        # load the trained agent model
        self.agent = Agent.load('./models/dialogue', self.interpreter)
        self.agent.handle_channel(ConsoleInputChannel())

        # TODO mock dispatcher, tracker and domain
        self.dispatcher = Dispatcher(output_channel=ConsoleOutputChannel())
        self.tracker = DialogueStateTracker()
        self.domain = Domain()
Ejemplo n.º 21
0
async def test_dispatcher_template_invalid_vars():
    templates = {
        "my_made_up_template": [{
            "text":
            "a template referencing an invalid {variable}."
        }]
    }
    bot = CollectingOutputChannel()
    nlg = TemplatedNaturalLanguageGenerator(templates)
    dispatcher = Dispatcher("my-sender", bot, nlg)
    tracker = DialogueStateTracker("my-sender", slots=[])
    await dispatcher.utter_template("my_made_up_template", tracker)
    collected = dispatcher.output_channel.latest_output()
    assert collected['text'].startswith(
        "a template referencing an invalid {variable}.")
Ejemplo n.º 22
0
    def _set_and_execute_next_action(self, action_name, message, tracker):
        """
        Sets the next action to action_name then execute
        :param action_name: String
        :param tracker: DialogueStateTracker
        :return: void
        """
        # this will actually send the response to the user

        dispatcher = Dispatcher(message.sender_id, message.output_channel,
                                self.domain)

        action = self.domain.action_map[action_name][1]
        self._log_slots(tracker)
        self._run_action(action, tracker, dispatcher)
        self._run_action(ActionListen(), tracker, dispatcher)
Ejemplo n.º 23
0
def test_restaurant_form_unhappy_2():
    domain = TemplateDomain.load("data/test_domains/restaurant_form.yml")
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-restaurant"
    dispatcher = Dispatcher(sender_id, out, nlg)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # first user utterance
    entities = [{
        "entity": "cuisine",
        "value": "chinese"
    }, {
        "entity": "number",
        "value": 8
    }]

    tracker.update(
        UserUttered("", intent={"name": "inform"}, entities=entities))

    # store all entities as slots
    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)

    for e in events:
        tracker.update(e)

    cuisine = tracker.get_slot("cuisine")
    people = tracker.get_slot("people")
    assert cuisine == "chinese"
    assert people == 8

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    assert len(events) == 3
    assert isinstance(events[0], SlotSet)
    assert events[2].key == "requested_slot"
    assert events[2].value == "vegetarian"
    tracker.update(events[2])

    # second user utterance does not provide what's asked
    tracker.update(UserUttered("", intent={"name": "random"}))

    events = ActionSearchRestaurants().run(dispatcher, tracker, domain)
    s = events[0].as_story_string()
    assert len(events) == 1
    assert events[0].key == "requested_slot"
    assert events[0].value == "vegetarian"
Ejemplo n.º 24
0
def test_reminder_aborted(default_processor):
    out = CollectingOutputChannel()
    sender_id = uuid.uuid4().hex

    d = Dispatcher(sender_id, out, default_processor.nlg)
    r = ReminderScheduled("utter_greet", datetime.datetime.now(),
                          kill_on_user_message=True)
    t = default_processor.tracker_store.get_or_create_tracker(sender_id)

    t.update(r)
    t.update(UserUttered("test"))  # cancels the reminder

    default_processor.tracker_store.save(t)
    default_processor.handle_reminder(r, d)

    # retrieve the updated tracker
    t = default_processor.tracker_store.retrieve(sender_id)
    assert len(t.events) == 3  # nothing should have been executed
Ejemplo n.º 25
0
def test_query_form_set_username_directly():
    domain = TemplateDomain.load("data/test_domains/query_form.yml")
    tracker_store = InMemoryTrackerStore(domain)
    out = CollectingOutputChannel()
    sender_id = "test-form"
    dispatcher = Dispatcher(sender_id, out, domain)
    tracker = tracker_store.get_or_create_tracker(sender_id)

    # pre-fill username slot
    username = "******"
    tracker.update(SlotSet('username', username))

    # first user utterance
    tracker.update(UserUttered("", intent={"name": "inform"}))
    events = ActionSearchQuery().run(dispatcher, tracker, domain)
    last_message = dispatcher.latest_bot_messages[-1]
    assert len(events) == 1
    assert isinstance(events[0], SlotSet)
    assert events[0].key == "requested_slot"
    assert events[0].value == "query"
    assert username in last_message.text
Ejemplo n.º 26
0
    def predict_and_execute_next_action(self, message, tracker):
        dispatcher = Dispatcher(message.sender_id, message.output_channel,
                                self.message_processor.nlg)
        # keep taking actions decided by the policy until it chooses to 'listen'
        should_predict_another_action = True
        num_predicted_actions = 0

        self.log_slots(tracker)
        # action loop. predicts actions until we hit action listen
        while (should_predict_another_action
               and self.should_handle_message(tracker)
               and num_predicted_actions <
               self.message_processor.max_number_of_predictions):
            # this actually just calls the policy's method by the same name
            probabilities, policy = self.message_processor._get_next_action_probabilities(
                tracker)
            max_index = int(np.argmax(probabilities))
            if self.message_processor.domain.num_actions <= max_index or max_index < 0:
                raise Exception("Can not access action at index {}. "
                                "Domain has {} actions.".format(
                                    max_index,
                                    self.message_processor.domain.num_actions))

            action = self.ask_for_action(
                self.message_processor.domain.action_names[max_index],
                self.message_processor.action_endpoint)
            confidence = probabilities[max_index]
            # action, policy, confidence = self.agent.predict_next_action(tracker)

            should_predict_another_action = self.run_action(
                action, tracker, dispatcher, policy, confidence)
            num_predicted_actions += 1

        if (num_predicted_actions
                == self.message_processor.max_number_of_predictions
                and should_predict_another_action):
            # circuit breaker was tripped
            if self.message_processor.on_circuit_break:
                # call a registered callback
                self.message_processor.on_circuit_break(tracker, dispatcher)
Ejemplo n.º 27
0
def test_reminder_scheduled(default_processor):
    out = CollectingOutputChannel()
    sender_id = uuid.uuid4().hex

    d = Dispatcher(sender_id, out, default_processor.nlg)
    r = ReminderScheduled("utter_greet", datetime.datetime.now())
    t = default_processor.tracker_store.get_or_create_tracker(sender_id)

    t.update(UserUttered("test"))
    t.update(ActionExecuted("action_reminder_reminder"))
    t.update(r)

    default_processor.tracker_store.save(t)
    default_processor.handle_reminder(r, d)

    # retrieve the updated tracker
    t = default_processor.tracker_store.retrieve(sender_id)
    assert t.events[-4] == UserUttered(None)
    assert t.events[-3] == ActionExecuted("utter_greet")
    assert t.events[-2] == BotUttered("hey there None!", {'elements': None,
                                                          'buttons': None,
                                                          'attachment': None})
    assert t.events[-1] == ActionExecuted("action_listen")
Ejemplo n.º 28
0
def default_dispatcher_collecting(default_nlg):
    bot = CollectingOutputChannel()
    return Dispatcher("my-sender", bot, default_nlg)
Ejemplo n.º 29
0
def default_dispatcher_cmd(default_domain):
    bot = ConsoleOutputChannel()
    return Dispatcher("my-sender", bot, default_domain)