示例#1
0
    async def _predict_and_execute_next_action(self, message, tracker):
        """Predict and run actions for ``tracker`` until told to stop.

        Each round asks ``self.predict_next_action`` for an action and runs
        it through ``self._run_action``; the awaited result of that call
        decides whether another round is attempted. The loop also ends when
        ``self._should_handle_message(tracker)`` is falsy or when
        ``self.max_number_of_predictions`` rounds have been executed.

        If the round limit is hit while another prediction was still
        requested, a warning is logged and ``self.on_circuit_break`` (when
        set) is called with the tracker and the dispatcher.
        """
        # the dispatcher is built from the message's sender id and output
        # channel, and is handed to every action that is run below
        dispatcher = Dispatcher(message.sender_id, message.output_channel,
                                self.nlg)

        keep_predicting = True
        actions_run = 0

        # action loop: every exit condition is checked before each round
        while True:
            if not keep_predicting:
                break
            if not self._should_handle_message(tracker):
                break
            within_limit = actions_run < self.max_number_of_predictions
            if not within_limit:
                break

            # this actually just calls the policy's method by the same name
            action, policy, confidence = self.predict_next_action(tracker)

            keep_predicting = await self._run_action(
                action, tracker, dispatcher, policy, confidence)
            actions_run += 1

        limit_reached = (actions_run == self.max_number_of_predictions
                         and keep_predicting)
        if not limit_reached:
            return

        # circuit breaker was tripped
        logger.warning("Circuit breaker tripped. Stopped predicting "
                       "more actions for sender '{}'".format(
                           tracker.sender_id))
        if self.on_circuit_break:
            # notify the registered callback
            self.on_circuit_break(tracker, dispatcher)
示例#2
0
async def test_reminder_scheduled(default_processor):
    """Handling a stored reminder appends the expected events to the tracker."""
    channel = CollectingOutputChannel()
    sender_id = uuid.uuid4().hex

    dispatcher = Dispatcher(sender_id, channel, default_processor.nlg)
    reminder = ReminderScheduled("utter_greet", datetime.datetime.now())
    tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

    # seed the conversation: a user message, the scheduling action and the
    # reminder event itself
    for event in (UserUttered("test"),
                  ActionExecuted("action_reminder_reminder"),
                  reminder):
        tracker.update(event)

    default_processor.tracker_store.save(tracker)
    await default_processor.handle_reminder(reminder, dispatcher)

    # re-read the tracker from the store to see what the reminder produced
    stored = default_processor.tracker_store.retrieve(sender_id)
    expected_bot_data = {
        'elements': None,
        'buttons': None,
        'attachment': None
    }
    assert stored.events[-4] == UserUttered(None)
    assert stored.events[-3] == ActionExecuted("utter_greet")
    assert stored.events[-2] == BotUttered("hey there None!",
                                           expected_bot_data)
    assert stored.events[-1] == ActionExecuted("action_listen")
示例#3
0
文件: agent.py 项目: tgalery/rasa_nlu
    async def execute_action(self, sender_id: Text, action: Text,
                             output_channel: OutputChannel, policy: Text,
                             confidence: float) -> DialogueStateTracker:
        """Execute one action for ``sender_id`` via a newly created processor.

        A processor is obtained from ``self.create_processor()`` and a
        ``Dispatcher`` is built from the sender id, the output channel and
        ``self.nlg``. The call is then delegated to the processor's
        ``execute_action`` and its awaited result is returned unchanged.
        """
        message_processor = self.create_processor()
        reply_dispatcher = Dispatcher(sender_id, output_channel, self.nlg)
        result = await message_processor.execute_action(
            sender_id, action, reply_dispatcher, policy, confidence)
        return result
示例#4
0
    def log_bot_utterances_on_tracker(tracker: DialogueStateTracker,
                                      dispatcher: Dispatcher) -> None:
        """Copy the dispatcher's latest bot messages onto the tracker.

        Every entry of ``dispatcher.latest_bot_messages`` becomes a
        ``BotUttered`` event (built from its ``text`` and ``data``) that is
        debug-logged and applied to ``tracker``. Afterwards the dispatcher's
        list is reset to empty. Nothing happens when the list is falsy.
        """
        pending = dispatcher.latest_bot_messages
        if not pending:
            return

        for message in pending:
            utterance = BotUttered(text=message.text, data=message.data)
            logger.debug("Bot utterance '{}'".format(utterance))
            tracker.update(utterance)

        # clear the buffer so the same messages are not recorded again
        dispatcher.latest_bot_messages = []
示例#5
0
async def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """A domain template with buttons is sent as one message with buttons."""
    domain = Domain.load("examples/moodbot/domain.yml")
    channel = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", channel, generator)

    await dispatcher.utter_template("utter_greet", default_tracker)

    expected_buttons = [
        {"payload": "great", "title": "great"},
        {"payload": "super sad", "title": "super sad"},
    ]
    # exactly one message, carrying both the text and the buttons
    assert len(channel.messages) == 1
    assert channel.messages[0]["text"] == "Hey! How are you?"
    assert channel.messages[0]["buttons"] == expected_buttons
示例#6
0
async def test_dispatcher_template_invalid_vars():
    """A template naming an unknown variable is uttered with it unfilled."""
    template_text = "a template referencing an invalid {variable}."
    channel = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(
        {"my_made_up_template": [{"text": template_text}]}
    )
    dispatcher = Dispatcher("my-sender", channel, generator)
    empty_tracker = DialogueStateTracker("my-sender", slots=[])

    await dispatcher.utter_template("my_made_up_template", empty_tracker)

    # the placeholder must still be present in the emitted text
    latest = dispatcher.output_channel.latest_output()
    assert latest["text"].startswith(
        "a template referencing an invalid {variable}."
    )
示例#7
0
async def test_reminder_aborted(default_processor):
    """A reminder followed by a user message must not execute anything."""
    channel = CollectingOutputChannel()
    sender_id = uuid.uuid4().hex

    dispatcher = Dispatcher(sender_id, channel, default_processor.nlg)
    reminder = ReminderScheduled(
        "utter_greet",
        datetime.datetime.now(),
        kill_on_user_message=True,
    )
    tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

    tracker.update(reminder)
    # a user message arriving after the reminder cancels it
    tracker.update(UserUttered("test"))

    default_processor.tracker_store.save(tracker)
    await default_processor.handle_reminder(reminder, dispatcher)

    # re-read the tracker from the store
    stored = default_processor.tracker_store.retrieve(sender_id)
    # nothing should have been executed
    assert len(stored.events) == 3
示例#8
0
async def test_reminder_cancelled(default_processor):
    """Cancelling one user's reminder must leave the other user's intact.

    Two trackers each get a scheduled ``utter_greet`` reminder; the first
    one additionally receives a ``ReminderCancelled`` event. After the
    reminders are scheduled and the cancellations processed, only the second
    sender's tracker should contain the executed ``utter_greet`` action.
    """
    out = CollectingOutputChannel()
    sender_ids = [uuid.uuid4().hex, uuid.uuid4().hex]
    trackers = []
    for sender_id in sender_ids:
        t = default_processor.tracker_store.get_or_create_tracker(sender_id)

        t.update(UserUttered("test"))
        t.update(ActionExecuted("action_reminder_reminder"))
        t.update(
            ReminderScheduled(
                "utter_greet", datetime.datetime.now(), kill_on_user_message=True
            )
        )
        trackers.append(t)

    # cancel reminder for the first user
    trackers[0].update(ReminderCancelled("utter_greet"))

    # Pair every tracker with its own sender id. Relying on the variable
    # left over from the loop above would give every dispatcher the *last*
    # sender's id.
    for sender_id, t in zip(sender_ids, trackers):
        default_processor.tracker_store.save(t)
        d = Dispatcher(sender_id, out, default_processor.nlg)
        await default_processor._schedule_reminders(t.events, t, d)
    # check that the jobs were added
    assert len((await jobs.scheduler()).get_jobs()) == 2

    for t in trackers:
        await default_processor._cancel_reminders(t.events, t)
    # check that only one job was removed
    assert len((await jobs.scheduler()).get_jobs()) == 1

    # give the scheduler time to execute the remaining job
    await asyncio.sleep(3)

    t = default_processor.tracker_store.retrieve(sender_ids[0])
    # there should be no utter_greet action
    assert ActionExecuted("utter_greet") not in t.events

    t = default_processor.tracker_store.retrieve(sender_ids[1])
    # there should be utter_greet action
    assert ActionExecuted("utter_greet") in t.events
示例#9
0
文件: conftest.py 项目: rajanala/rasa
def default_dispatcher_collecting(default_nlg):
    """Build a ``Dispatcher`` for "my-sender" backed by a collecting channel."""
    return Dispatcher("my-sender", CollectingOutputChannel(), default_nlg)