Ejemplo n.º 1
0
async def test_update_tracker_session_with_metadata(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Metadata passed to a session update ends up on the `SessionStarted` event."""
    conversation_id = uuid.uuid4().hex
    processor = default_processor
    tracker = processor.tracker_store.get_or_create_tracker(conversation_id)

    # Force `_has_session_expired()` to report an expired session so that
    # `_update_tracker_session()` actually starts a new one.
    monkeypatch.setattr(processor, "_has_session_expired", lambda _: True)

    expected_metadata = {"metadataTestKey": "metadataTestValue"}
    await processor._update_tracker_session(
        tracker, default_channel, expected_metadata
    )

    # `_update_tracker_session()` does not persist the tracker, so do it here.
    processor._save_tracker(tracker)

    # Reload the tracker: exactly one SessionStarted event must exist and it
    # must carry the metadata.
    reloaded_events = processor.tracker_store.retrieve(conversation_id).events
    assert reloaded_events.count(SessionStarted()) == 1

    first_session_start = reloaded_events[reloaded_events.index(SessionStarted())]
    assert first_session_start.metadata == expected_metadata
Ejemplo n.º 2
0
def test_tracker_store_retrieve_with_session_started_events(
    tracker_store_type: Type[TrackerStore],
    tracker_store_kwargs: Dict,
    default_domain: Domain,
):
    """Retrieving a tracker yields only the events since the latest SessionStarted."""
    store = tracker_store_type(default_domain, **tracker_store_kwargs)
    all_events = [
        UserUttered("Hola", {"name": "greet"}),
        BotUttered("Hi"),
        SessionStarted(),
        UserUttered("Ciao", {"name": "greet"}),
    ]
    conversation_id = "test_sql_tracker_store_with_session_events"
    store.save(DialogueStateTracker.from_events(conversation_id, all_events))

    # A second conversation in the same store must not interfere with retrieval.
    store.save(DialogueStateTracker.from_events("other-sender", [SessionStarted()]))

    # Only the events from the latest SessionStarted onwards should come back.
    retrieved = store.retrieve(conversation_id)
    latest_session = all_events[2:]

    assert len(retrieved.events) == 2
    for position, expected in enumerate(latest_session):
        assert expected == retrieved.events[position]
Ejemplo n.º 3
0
def _saved_tracker_with_multiple_session_starts(
    tracker_store: TrackerStore, sender_id: Text
) -> DialogueStateTracker:
    """Save a tracker holding two session-start sequences, then reload it.

    The stored events are: session start, a user "hi", and a second session
    start. The tracker returned is whatever `tracker_store.retrieve` yields for
    `sender_id`.
    """

    def session_start_sequence():
        # Fresh event objects on every call, so no instance is shared.
        return [ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted()]

    stored_events = [
        *session_start_sequence(),
        UserUttered("hi"),
        *session_start_sequence(),
    ]

    tracker_store.save(DialogueStateTracker.from_events(sender_id, stored_events))
    return tracker_store.retrieve(sender_id)
Ejemplo n.º 4
0
async def test_update_tracker_session(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """An expired session is followed by a fresh session-start sequence."""
    conversation_id = uuid.uuid4().hex
    store = default_processor.tracker_store
    tracker = store.get_or_create_tracker(conversation_id)

    # Pretend the session has expired so the update has an effect.
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)

    await default_processor._update_tracker_session(tracker, default_channel)

    # Persisting is the caller's job, not `_update_tracker_session()`'s.
    default_processor._save_tracker(tracker)

    # Reload and compare the complete event history.
    actual_events = list(store.retrieve(conversation_id).events)
    assert actual_events == [
        ActionExecuted(ACTION_LISTEN_NAME),
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
Ejemplo n.º 5
0
    async def _update_tracker_session(
        self,
        tracker: DialogueStateTracker,
        output_channel: OutputChannel,
        metadata: Optional[Dict] = None,
    ) -> None:
        """Check the current session in `tracker` and update it if expired.

        An `action_session_start` is run if the latest tracker session has
        expired, or if the tracker does not yet contain any events (only those
        after the last restart are considered).

        Args:
            tracker: Tracker to inspect.
            output_channel: Output channel for potential utterances in a custom
                `ActionSessionStart`.
            metadata: Data sent from client associated with the incoming user
                message. It is set as the `metadata` attribute of the session
                start action, which is expected to attach it to the
                `SessionStarted` event it emits.
        """
        if tracker.applied_events() and not self._has_session_expired(tracker):
            # The current session is still active: nothing to do.
            return

        logger.debug(
            f"Starting a new session for conversation ID '{tracker.sender_id}'."
        )

        action_session_start = self._get_action(
            ACTION_SESSION_START_NAME, tracker
        )  # bf

        # Hand the metadata to the action instead of appending a
        # `SessionStarted` event to `tracker.events` by hand. A manual append
        # has three problems:
        # - it bypasses `tracker.update()`, so the event is never applied;
        # - it lands before the `action_session_start` `ActionExecuted` event;
        # - it duplicates the `SessionStarted` event the action is expected to
        #   emit itself.
        # The assignment is unconditional so that, if `_get_action` ever returns
        # a shared instance, metadata from an earlier session cannot leak into
        # this one.
        action_session_start.metadata = metadata

        await self._run_action(
            action=action_session_start,
            tracker=tracker,
            output_channel=output_channel,
            nlg=self.nlg,
        )
Ejemplo n.º 6
0
def test_session_start(default_domain: Domain):
    """Updating a fresh tracker with `SessionStarted` records exactly one event."""
    tracker = DialogueStateTracker("default", default_domain.slots)

    # A newly created tracker has no events yet.
    assert not tracker.events

    tracker.update(SessionStarted())

    # The session start is now the only recorded event.
    assert len(tracker.events) == 1
Ejemplo n.º 7
0
async def test_action_session_start_without_slots(
    default_channel: CollectingOutputChannel,
    template_nlg: TemplatedNaturalLanguageGenerator,
    template_sender_tracker: DialogueStateTracker,
    default_domain: Domain,
):
    """With no slots involved, the action yields only SessionStarted and a listen."""
    action = ActionSessionStart()
    produced = await action.run(
        default_channel, template_nlg, template_sender_tracker, default_domain
    )

    expected = [SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME)]
    assert produced == expected
Ejemplo n.º 8
0
def test_tracker_store_retrieve_with_events_from_previous_sessions(
    tracker_store_type: Type[TrackerStore], tracker_store_kwargs: Dict
):
    """With previous-session loading enabled, every stored event is retrieved."""
    store = tracker_store_type(Domain.empty(), **tracker_store_kwargs)
    store.load_events_from_previous_conversation_sessions = True

    conversation_id = uuid.uuid4().hex
    stored = DialogueStateTracker.from_events(
        conversation_id,
        [
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
            UserUttered("hi"),
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
        ],
    )
    store.save(stored)

    # Events from both sessions must come back, not just the latest one.
    retrieved = store.retrieve(conversation_id)
    assert len(retrieved.events) == len(stored.events)
Ejemplo n.º 9
0
async def test_update_tracker_session_with_slots(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Slots set before a session expires are re-applied in the new session."""
    sender_id = uuid.uuid4().hex
    tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

    # apply a user utterance followed by five slots
    user_event = UserUttered("some utterance")
    tracker.update(user_event)

    slot_set_events = [SlotSet(f"slot key {i}", f"test value {i}") for i in range(5)]

    for event in slot_set_events:
        tracker.update(event)

    # patch `_has_session_expired()` so the `_update_tracker_session()` call actually
    # does something
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)

    await default_processor._update_tracker_session(tracker, default_channel)

    # the save is not called in _update_tracker_session()
    default_processor._save_tracker(tracker)

    # inspect tracker and make sure all events are present
    tracker = default_processor.tracker_store.retrieve(sender_id)
    events = list(tracker.events)

    # the first two events are the initial action listen and the user utterance
    assert events[:2] == [ActionExecuted(ACTION_LISTEN_NAME), user_event]

    # next come the five slots (events 2-6)
    assert events[2:7] == slot_set_events

    # events 7-8 are the session start sequence
    assert events[7:9] == [ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted()]

    # the five slots are re-applied after the session start (events 9-13)
    assert events[9:14] == slot_set_events

    # finally an action listen, which must also be the last event; the length
    # check makes "last" explicit rather than relying on event equality alone
    assert len(events) == 15
    assert events[14] == events[-1] == ActionExecuted(ACTION_LISTEN_NAME)
Ejemplo n.º 10
0
async def test_get_tracker_with_session_start(
    default_channel: CollectingOutputChannel, default_processor: MessageProcessor
):
    """A brand-new conversation begins with the session start sequence."""
    tracker = await default_processor.get_tracker_with_session_start(
        uuid.uuid4().hex, default_channel
    )

    expected_sequence = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    assert list(tracker.events) == expected_sequence
Ejemplo n.º 11
0
    async def run(
        self,
        output_channel: "OutputChannel",
        nlg: "NaturalLanguageGenerator",
        tracker: "DialogueStateTracker",
        domain: "Domain",
    ) -> List[Event]:
        """Build the events that open a new conversation session.

        Returns, in order:
        - a `SessionStarted` event carrying `self.metadata`;
        - the slot events from `self._slot_set_events_from_tracker(tracker)`,
          only when `domain.session_config.carry_over_slots` is set;
        - an `ActionExecuted` event for the listen action.
        """
        # Create the session start first so event creation order is unchanged.
        session_started = SessionStarted(metadata=self.metadata)

        if domain.session_config.carry_over_slots:
            carried_slots = list(self._slot_set_events_from_tracker(tracker))
        else:
            carried_slots = []

        return [
            session_started,
            *carried_slots,
            ActionExecuted(ACTION_LISTEN_NAME),
        ]
Ejemplo n.º 12
0
def test_session_start_is_not_serialised(default_domain: Domain):
    """Session start events are left out of the story written for a tracker."""
    tracker = DialogueStateTracker("default", default_domain.slots)

    # A fresh tracker has no events yet.
    assert len(tracker.events) == 0

    # Add a slot event, then the session start sequence (action + event),
    # then a user message.
    for event in (
        SlotSet("slot", "value"),
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        UserUttered("say something"),
    ):
        tracker.update(event)

    # Only the slot and the user message should show up in the story string.
    story = Story.from_events(tracker.events, "some-story01")
    expected = """## some-story01
    - slot{"slot": "value"}
* say something
"""
    assert story.as_story_string(flat=True) == expected
Ejemplo n.º 13
0
def test_fetch_events_within_time_range_with_session_events():
    """Session events are included when fetching events within a time range."""
    import contextlib
    import os

    conversation_id = "test_fetch_events_within_time_range_with_sessions"

    # Throw-away SQLite database file. The path is relative, presumably
    # resolved against the working directory; it is removed again below.
    db_path = f"{uuid.uuid4().hex}.db"

    try:
        tracker_store = SQLTrackerStore(
            dialect="sqlite", db=db_path, domain=Domain.empty()
        )

        events = [
            random_user_uttered_event(1),
            SessionStarted(2),
            ActionExecuted(timestamp=3, action_name=ACTION_SESSION_START_NAME),
            random_user_uttered_event(4),
        ]
        tracker = DialogueStateTracker.from_events(conversation_id, evts=events)
        tracker_store.save(tracker)

        exporter = MockExporter(tracker_store=tracker_store)

        # noinspection PyProtectedMember
        fetched_events = exporter._fetch_events_within_time_range()

        assert len(fetched_events) == len(events)
    finally:
        # Best-effort cleanup so each run does not leave a stray `*.db` file
        # behind. The error is ignored if the file was never created or is
        # still held open.
        with contextlib.suppress(OSError):
            os.remove(db_path)
Ejemplo n.º 14
0
def test_json_parse_session_started():
    """The `session_started` event name deserialises to `SessionStarted`."""
    parsed = Event.from_parameters({"event": "session_started"})
    assert parsed == SessionStarted()
Ejemplo n.º 15
0
    template_sender_tracker: DialogueStateTracker,
    default_domain: Domain,
):
    events = await ActionSessionStart().run(default_channel, template_nlg,
                                            template_sender_tracker,
                                            default_domain)
    assert events == [SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME)]


@pytest.mark.parametrize(
    "session_config, expected_events",
    [
        (
            SessionConfig(123, True),
            [
                SessionStarted(),
                SlotSet("my_slot", "value"),
                SlotSet("another-slot", "value2"),
                ActionExecuted(action_name=ACTION_LISTEN_NAME),
            ],
        ),
        (
            SessionConfig(123, False),
            [SessionStarted(),
             ActionExecuted(action_name=ACTION_LISTEN_NAME)],
        ),
    ],
)
async def test_action_session_start_with_slots(
    default_channel: CollectingOutputChannel,
    template_nlg: TemplatedNaturalLanguageGenerator,
Ejemplo n.º 16
0
async def test_handle_message_with_session_start(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """A message after session expiry starts a new session before being handled.

    Slots from the earlier session are re-applied right after the new
    session's SessionStarted event.
    """
    sender_id = uuid.uuid4().hex
    entity = "name"

    first_slots = {entity: "Core"}
    first_text = f"/greet{json.dumps(first_slots)}"
    await default_processor.handle_message(
        UserMessage(first_text, default_channel, sender_id)
    )

    assert default_channel.latest_output() == {
        "recipient_id": sender_id,
        "text": "hey there Core!",
    }

    # From now on every session check reports an expired session.
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)

    second_slots = {entity: "post-session start hello"}
    second_text = f"/greet{json.dumps(second_slots)}"
    await default_processor.handle_message(
        UserMessage(second_text, default_channel, sender_id)
    )

    tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

    greet_intent = {INTENT_NAME_KEY: "greet", "confidence": 1.0}
    first_greet = UserUttered(
        first_text,
        greet_intent,
        [{"entity": entity, "start": 6, "end": 22, "value": "Core"}],
    )
    second_greet = UserUttered(
        second_text,
        greet_intent,
        [
            {
                "entity": entity,
                "start": 6,
                "end": 42,
                "value": "post-session start hello",
            }
        ],
    )

    expected_events = [
        # the first message opens a session
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        first_greet,
        SlotSet(entity, first_slots[entity]),
        ActionExecuted("utter_greet"),
        BotUttered("hey there Core!", metadata={"template_name": "utter_greet"}),
        ActionExecuted(ACTION_LISTEN_NAME),
        # the forced expiry starts a second session before the next message
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        # the slot from the first session is re-applied after the restart
        SlotSet(entity, first_slots[entity]),
        ActionExecuted(ACTION_LISTEN_NAME),
        second_greet,
        SlotSet(entity, second_slots[entity]),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    assert list(tracker.events) == expected_events
Ejemplo n.º 17
0
def test_session_started_event_is_not_serialised():
    """`SessionStarted` has no story-string representation."""
    story_line = SessionStarted().as_story_string()
    assert story_line is None
Ejemplo n.º 18
0
def test_json_parse_session_started():
    """The documented `session_started` payload parses into `SessionStarted`."""
    # The two marker comments below delimit a snippet that is presumably pulled
    # into the documentation; keep them and the line between them unchanged.
    # DOCS MARKER SessionStarted
    evt = {"event": "session_started"}
    # DOCS END
    parsed_event = Event.from_parameters(evt)
    assert parsed_event == SessionStarted()
Ejemplo n.º 19
0
             "confidence": 1.0
         }, []),
         UserUttered("/goodbye", {
             "name": "goodbye",
             "confidence": 1.0
         }, []),
     ),
     (SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")),
     (Restarted(), None),
     (AllSlotsReset(), None),
     (ConversationPaused(), None),
     (ConversationResumed(), None),
     (StoryExported(), None),
     (ActionReverted(), None),
     (UserUtteranceReverted(), None),
     (SessionStarted(), None),
     (ActionExecuted("my_action"), ActionExecuted("my_other_action")),
     (FollowupAction("my_action"), FollowupAction("my_other_action")),
     (
         BotUttered("my_text", {"my_data": 1}),
         BotUttered("my_other_test", {"my_other_data": 1}),
     ),
     (
         AgentUttered("my_text", "my_data"),
         AgentUttered("my_other_test", "my_other_data"),
     ),
     (
         ReminderScheduled("my_action", datetime.now()),
         ReminderScheduled("my_other_action", datetime.now()),
     ),
 ],