async def test_update_tracker_session_with_metadata(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Metadata given to `_update_tracker_session()` ends up on `SessionStarted`.

    After a forced session restart with metadata, the persisted tracker must
    contain exactly one `SessionStarted` event and that event must carry the
    metadata that was passed in.
    """
    conversation_id = uuid.uuid4().hex
    tracker = default_processor.tracker_store.get_or_create_tracker(conversation_id)

    # Treat every session as expired so the update really starts a new one.
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)

    session_metadata = {"metadataTestKey": "metadataTestValue"}
    await default_processor._update_tracker_session(
        tracker, default_channel, session_metadata
    )

    # `_update_tracker_session()` does not persist the tracker; do it here.
    default_processor._save_tracker(tracker)

    # Reload from the store and inspect the session start event.
    stored_events = default_processor.tracker_store.retrieve(conversation_id).events
    assert stored_events.count(SessionStarted()) == 1
    session_started = stored_events[stored_events.index(SessionStarted())]
    assert session_started.metadata == session_metadata
def test_tracker_store_retrieve_with_session_started_events(
    tracker_store_type: Type[TrackerStore],
    tracker_store_kwargs: Dict,
    default_domain: Domain,
):
    """Retrieval returns only the events from the latest `SessionStarted` onward.

    A second conversation is stored alongside to make sure events of other
    senders do not leak into the result.
    """
    store = tracker_store_type(default_domain, **tracker_store_kwargs)

    events = [
        UserUttered("Hola", {"name": "greet"}),
        BotUttered("Hi"),
        SessionStarted(),
        UserUttered("Ciao", {"name": "greet"}),
    ]
    sender_id = "test_sql_tracker_store_with_session_events"
    store.save(DialogueStateTracker.from_events(sender_id, events))

    # an unrelated conversation in the same store must not interfere
    store.save(DialogueStateTracker.from_events("other-sender", [SessionStarted()]))

    # only the most recent session (SessionStarted + the following utterance)
    retrieved = store.retrieve(sender_id)
    assert len(retrieved.events) == 2
    assert events[2:] == list(retrieved.events)
def _saved_tracker_with_multiple_session_starts(
    tracker_store: TrackerStore, sender_id: Text
) -> DialogueStateTracker:
    """Save a tracker that contains two session starts and read it back.

    Args:
        tracker_store: Store the tracker is saved to and retrieved from.
        sender_id: Conversation ID of the tracker.

    Returns:
        The tracker as returned by `tracker_store.retrieve(sender_id)`.
    """
    session_events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        UserUttered("hi"),
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
    ]
    tracker_store.save(DialogueStateTracker.from_events(sender_id, session_events))

    return tracker_store.retrieve(sender_id)
async def test_update_tracker_session(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """A forced session restart appends the full session start sequence."""
    conversation_id = uuid.uuid4().hex
    store = default_processor.tracker_store
    tracker = store.get_or_create_tracker(conversation_id)

    # Make the session count as expired so `_update_tracker_session()` acts.
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)
    await default_processor._update_tracker_session(tracker, default_channel)

    # Persisting is left to the caller of `_update_tracker_session()`.
    default_processor._save_tracker(tracker)

    # initial listen, then session start action, SessionStarted, and listen
    assert list(store.retrieve(conversation_id).events) == [
        ActionExecuted(ACTION_LISTEN_NAME),
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
async def _update_tracker_session(
    self,
    tracker: DialogueStateTracker,
    output_channel: OutputChannel,
    metadata: Optional[Dict] = None,
) -> None:
    """Check the current session in `tracker` and update it if expired.

    An 'action_session_start' is run if the latest tracker session has expired,
    or if the tracker does not yet contain any events (only those after the last
    restart are considered).

    Args:
        tracker: Tracker to inspect.
        output_channel: Output channel for potential utterances in a custom
            `ActionSessionStart`.
        metadata: Data sent from client associated with the incoming user
            message. When given, it is handed to the session start action so
            the `SessionStarted` event that action emits carries it.
    """
    if not tracker.applied_events() or self._has_session_expired(tracker):
        logger.debug(
            f"Starting a new session for conversation ID '{tracker.sender_id}'."
        )

        action_session_start = self._get_action(ACTION_SESSION_START_NAME, tracker)

        if metadata:
            # The session start action emits the `SessionStarted` event itself
            # (`SessionStarted(metadata=self.metadata)`), so attach the metadata
            # to the action instead of appending an extra event to
            # `tracker.events`, which would bypass `tracker.update()` and leave
            # two `SessionStarted` events in the wrong order.
            action_session_start.metadata = metadata

        await self._run_action(
            action=action_session_start,
            tracker=tracker,
            output_channel=output_channel,
            nlg=self.nlg,
        )
def test_session_start(default_domain: Domain):
    """Applying `SessionStarted` to an empty tracker records exactly one event."""
    tracker = DialogueStateTracker("default", default_domain.slots)

    # a newly constructed tracker has no events yet
    assert len(tracker.events) == 0

    # the session start is recorded as a single event
    tracker.update(SessionStarted())
    assert len(tracker.events) == 1
async def test_action_session_start_without_slots(
    default_channel: CollectingOutputChannel,
    template_nlg: TemplatedNaturalLanguageGenerator,
    template_sender_tracker: DialogueStateTracker,
    default_domain: Domain,
):
    """`ActionSessionStart` yields `SessionStarted` followed by `action_listen`."""
    session_start_action = ActionSessionStart()
    produced = await session_start_action.run(
        default_channel, template_nlg, template_sender_tracker, default_domain
    )

    assert produced == [SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME)]
def test_tracker_store_retrieve_with_events_from_previous_sessions(
    tracker_store_type: Type[TrackerStore], tracker_store_kwargs: Dict
):
    """With previous-session loading enabled, every saved event is retrieved."""
    store = tracker_store_type(Domain.empty(), **tracker_store_kwargs)
    # opt in to loading events from earlier sessions as well
    store.load_events_from_previous_conversation_sessions = True

    conversation_id = uuid.uuid4().hex
    saved = DialogueStateTracker.from_events(
        conversation_id,
        [
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
            UserUttered("hi"),
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
        ],
    )
    store.save(saved)

    loaded = store.retrieve(conversation_id)
    assert len(loaded.events) == len(saved.events)
async def test_update_tracker_session_with_slots(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Slots set before a session restart are re-applied in the new session."""
    sender_id = uuid.uuid4().hex
    tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

    # apply a user uttered and five slots
    user_event = UserUttered("some utterance")
    tracker.update(user_event)

    slot_set_events = [SlotSet(f"slot key {i}", f"test value {i}") for i in range(5)]
    for event in slot_set_events:
        tracker.update(event)

    # patch `_has_session_expired()` so the `_update_tracker_session()` call actually
    # does something
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)
    await default_processor._update_tracker_session(tracker, default_channel)

    # the save is not called in _update_tracker_session()
    default_processor._save_tracker(tracker)

    # inspect tracker and make sure all events are present
    tracker = default_processor.tracker_store.retrieve(sender_id)
    events = list(tracker.events)

    # the first two events are the initial action listen and the user utterance
    assert events[:2] == [ActionExecuted(ACTION_LISTEN_NAME), user_event]

    # next come the five slots
    assert events[2:7] == slot_set_events

    # the next two events are the session start sequence
    assert events[7:9] == [ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted()]

    # the five slots should be reapplied
    assert events[9:14] == slot_set_events

    # finally an action listen, which must also be the last event; the length
    # check pins this (value equality alone would accept a trailing duplicate)
    assert len(events) == 15
    assert events[14] == events[-1] == ActionExecuted(ACTION_LISTEN_NAME)
async def test_get_tracker_with_session_start(
    default_channel: CollectingOutputChannel, default_processor: MessageProcessor
):
    """A tracker for a new conversation begins with the session start sequence."""
    tracker = await default_processor.get_tracker_with_session_start(
        uuid.uuid4().hex, default_channel
    )

    # session start action, the SessionStarted marker, then `action_listen`
    assert list(tracker.events) == [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
async def run(
    self,
    output_channel: "OutputChannel",
    nlg: "NaturalLanguageGenerator",
    tracker: "DialogueStateTracker",
    domain: "Domain",
) -> List[Event]:
    """Build the events that open a new conversation session.

    The result is a `SessionStarted` event carrying `self.metadata`. It is
    followed by the slot-setting events taken from `tracker` when
    `domain.session_config.carry_over_slots` is enabled, and ends with an
    `action_listen` execution.
    """
    # Created first, so it precedes any carried-over events.
    session_started = SessionStarted(metadata=self.metadata)

    carried_over = (
        self._slot_set_events_from_tracker(tracker)
        if domain.session_config.carry_over_slots
        else []
    )

    return [session_started, *carried_over, ActionExecuted(ACTION_LISTEN_NAME)]
def test_session_start_is_not_serialised(default_domain: Domain):
    """The session start sequence is omitted from the story string."""
    tracker = DialogueStateTracker("default", default_domain.slots)
    # a newly constructed tracker has no events
    assert len(tracker.events) == 0

    # add SlotSet event
    tracker.update(SlotSet("slot", "value"))

    # add the session start sequence (`action_session_start` followed by a single
    # `SessionStarted`) and a user event
    tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
    tracker.update(SessionStarted())
    tracker.update(UserUttered("say something"))

    # make sure session start is not serialised
    story = Story.from_events(tracker.events, "some-story01")

    # NOTE(review): this literal looks whitespace-collapsed in the reviewed copy;
    # story strings are usually multi-line — confirm against the original file.
    expected = """## some-story01 - slot{"slot": "value"} * say something """

    assert story.as_story_string(flat=True) == expected
def test_fetch_events_within_time_range_with_session_events():
    """The exporter fetches every stored event, including session start events."""
    import os
    import tempfile

    conversation_id = "test_fetch_events_within_time_range_with_sessions"

    # Place the SQLite file in a fresh temporary directory. A bare relative
    # file name would be created in the process's current working directory,
    # leaving a stray `<uuid>.db` there after every run.
    db_path = os.path.join(tempfile.mkdtemp(), f"{uuid.uuid4().hex}.db")
    tracker_store = SQLTrackerStore(
        dialect="sqlite", db=db_path, domain=Domain.empty()
    )

    events = [
        random_user_uttered_event(1),
        SessionStarted(2),
        ActionExecuted(timestamp=3, action_name=ACTION_SESSION_START_NAME),
        random_user_uttered_event(4),
    ]
    tracker = DialogueStateTracker.from_events(conversation_id, evts=events)
    tracker_store.save(tracker)

    exporter = MockExporter(tracker_store=tracker_store)

    # noinspection PyProtectedMember
    fetched_events = exporter._fetch_events_within_time_range()

    assert len(fetched_events) == len(events)
def test_json_parse_session_started():
    """A `session_started` payload deserialises to a `SessionStarted` event."""
    parsed = Event.from_parameters({"event": "session_started"})
    assert parsed == SessionStarted()
template_sender_tracker: DialogueStateTracker, default_domain: Domain, ): events = await ActionSessionStart().run(default_channel, template_nlg, template_sender_tracker, default_domain) assert events == [SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME)] @pytest.mark.parametrize( "session_config, expected_events", [ ( SessionConfig(123, True), [ SessionStarted(), SlotSet("my_slot", "value"), SlotSet("another-slot", "value2"), ActionExecuted(action_name=ACTION_LISTEN_NAME), ], ), ( SessionConfig(123, False), [SessionStarted(), ActionExecuted(action_name=ACTION_LISTEN_NAME)], ), ], ) async def test_action_session_start_with_slots( default_channel: CollectingOutputChannel, template_nlg: TemplatedNaturalLanguageGenerator,
async def test_handle_message_with_session_start(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """A message arriving after session expiry triggers a fresh session start.

    The slot set in the first session is re-applied after the second session
    start, and both greetings are processed normally.
    """
    sender_id = uuid.uuid4().hex
    entity = "name"

    # first message: handled in the initial session and sets the slot
    first_slot = {entity: "Core"}
    first_text = f"/greet{json.dumps(first_slot)}"
    await default_processor.handle_message(
        UserMessage(first_text, default_channel, sender_id)
    )
    assert default_channel.latest_output() == {
        "recipient_id": sender_id,
        "text": "hey there Core!",
    }

    # from now on the session always counts as expired
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)

    # second message: arrives after the forced session restart
    second_slot = {entity: "post-session start hello"}
    second_text = f"/greet{json.dumps(second_slot)}"
    await default_processor.handle_message(
        UserMessage(second_text, default_channel, sender_id)
    )

    events = list(
        default_processor.tracker_store.get_or_create_tracker(sender_id).events
    )

    expected_events = [
        # initial session start
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        # first greeting
        UserUttered(
            first_text,
            {INTENT_NAME_KEY: "greet", "confidence": 1.0},
            [{"entity": entity, "start": 6, "end": 22, "value": "Core"}],
        ),
        SlotSet(entity, first_slot[entity]),
        ActionExecuted("utter_greet"),
        BotUttered("hey there Core!", metadata={"template_name": "utter_greet"}),
        ActionExecuted(ACTION_LISTEN_NAME),
        # second session start; the earlier SlotSet is re-applied right after it
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        SlotSet(entity, first_slot[entity]),
        ActionExecuted(ACTION_LISTEN_NAME),
        # second greeting
        UserUttered(
            second_text,
            {INTENT_NAME_KEY: "greet", "confidence": 1.0},
            [
                {
                    "entity": entity,
                    "start": 6,
                    "end": 42,
                    "value": "post-session start hello",
                }
            ],
        ),
        SlotSet(entity, second_slot[entity]),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    assert events == expected_events
def test_session_started_event_is_not_serialised():
    """`SessionStarted` produces no story-string representation."""
    story_line = SessionStarted().as_story_string()
    assert story_line is None
def test_json_parse_session_started():
    """Parsing the documented `session_started` payload yields `SessionStarted`."""
    # DOCS MARKER SessionStarted
    evt = {"event": "session_started"}
    # DOCS END
    parsed_event = Event.from_parameters(evt)
    assert parsed_event == SessionStarted()
"confidence": 1.0 }, []), UserUttered("/goodbye", { "name": "goodbye", "confidence": 1.0 }, []), ), (SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")), (Restarted(), None), (AllSlotsReset(), None), (ConversationPaused(), None), (ConversationResumed(), None), (StoryExported(), None), (ActionReverted(), None), (UserUtteranceReverted(), None), (SessionStarted(), None), (ActionExecuted("my_action"), ActionExecuted("my_other_action")), (FollowupAction("my_action"), FollowupAction("my_other_action")), ( BotUttered("my_text", {"my_data": 1}), BotUttered("my_other_test", {"my_other_data": 1}), ), ( AgentUttered("my_text", "my_data"), AgentUttered("my_other_test", "my_other_data"), ), ( ReminderScheduled("my_action", datetime.now()), ReminderScheduled("my_other_action", datetime.now()), ), ],