def test_merge_yaml_domains():
    """Merging two YAML domains combines their lists and resolves conflicts.

    Without ``override`` the first domain wins for single-valued attributes
    and conflicting responses; with ``override=True`` the second domain wins.
    """
    test_yaml_1 = f"""config:
  store_entities_as_slots: true
entities: []
intents: []
slots: {{}}
responses:
  utter_greet:
    - text: hey there!
{KEY_E2E_ACTIONS}:
- Hi"""

    # NOTE(review): the session config previously used the key
    # `carry_over_slots`; Rasa's documented domain key is
    # `carry_over_slots_to_new_session`, so the old key never reached the
    # resulting SessionConfig.
    test_yaml_2 = f"""config:
  store_entities_as_slots: false
session_config:
  session_expiration_time: 20
  carry_over_slots_to_new_session: true
entities:
  - cuisine
intents:
  - greet
slots:
  cuisine:
    type: text
{KEY_E2E_ACTIONS}:
- Bye
responses:
  utter_goodbye:
    - text: bye!
  utter_greet:
    - text: hey you!"""

    domain_1 = Domain.from_yaml(test_yaml_1)
    domain_2 = Domain.from_yaml(test_yaml_2)
    domain = domain_1.merge(domain_2)

    # single attribute should be taken from domain_1
    assert domain.store_entities_as_slots

    # conflicts should be taken from domain_1
    assert domain.templates == {
        "utter_greet": [{"text": "hey there!"}],
        "utter_goodbye": [{"text": "bye!"}],
    }

    # lists should be deduplicated and merged
    assert domain.intents == sorted(["greet", *DEFAULT_INTENTS])
    assert domain.entities == ["cuisine"]
    assert isinstance(domain.slots[0], TextSlot)
    assert domain.slots[0].name == "cuisine"
    assert sorted(domain.user_actions) == sorted(["utter_greet", "utter_goodbye"])
    assert domain.session_config == SessionConfig(20, True)

    domain = domain_1.merge(domain_2, override=True)

    # single attribute should be taken from domain_2
    assert not domain.store_entities_as_slots

    # conflicts should take value from domain_2
    assert domain.templates == {
        "utter_greet": [{"text": "hey you!"}],
        "utter_goodbye": [{"text": "bye!"}],
    }
    assert domain.session_config == SessionConfig(20, True)
    assert domain.action_texts == ["Bye", "Hi"]
def test_domain_as_dict_with_session_config():
    """A domain's session config survives an ``as_dict``/``from_dict`` round trip."""
    expected_config = SessionConfig(123, False)

    original_domain = Domain.empty()
    original_domain.session_config = expected_config

    as_dict = original_domain.as_dict()
    restored_domain = Domain.from_dict(as_dict)

    assert restored_domain.session_config == expected_config
async def test_get_story_does_not_update_conversation_session(
    rasa_app: SanicASGITestClient, monkeypatch: MonkeyPatch
):
    """Requesting a conversation's story must not start a new session.

    The stored conversation holds one expired session; after calling the
    ``/story`` endpoint, the tracker in the store must be unchanged.
    """
    conversation_id = "some-conversation-ID"

    # domain with short session expiration time of one second (1/60 minute)
    domain = Domain.empty()
    domain.session_config = SessionConfig(
        session_expiration_time=1 / 60, carry_over_slots=True
    )

    monkeypatch.setattr(rasa_app.app.agent, "domain", domain)

    # conversation contains one session that has expired
    now = time.time()
    conversation_events = [
        ActionExecuted(ACTION_SESSION_START_NAME, timestamp=now - 10),
        SessionStarted(timestamp=now - 9),
        UserUttered("hi", {"name": "greet"}, timestamp=now - 8),
        ActionExecuted("utter_greet", timestamp=now - 7),
    ]

    tracker = DialogueStateTracker.from_events(conversation_id, conversation_events)

    # the conversation session has expired
    assert rasa_app.app.agent.create_processor()._has_session_expired(tracker)

    tracker_store = InMemoryTrackerStore(domain)
    tracker_store.save(tracker)

    monkeypatch.setattr(rasa_app.app.agent, "tracker_store", tracker_store)

    _, response = await rasa_app.get(f"/conversations/{conversation_id}/story")

    assert response.status == 200

    # expected story is returned
    assert (
        response.content.decode().strip()
        == """version: "2.0"
stories:
- story: some-conversation-ID
  steps:
  - intent: greet
    user: |-
      hi
  - action: utter_greet"""
    )

    # The endpoint works on the tracker held by `tracker_store`, not on the
    # local `tracker` object above, so asserting on the local object could
    # never detect a session update. Re-read the stored tracker instead.
    stored_tracker = tracker_store.retrieve(conversation_id)

    # the tracker has the same number of events as were initially added
    assert len(stored_tracker.events) == len(conversation_events)

    # the last event is still the same as before
    assert stored_tracker.events[-1].timestamp == conversation_events[-1].timestamp
def test_merge_session_config_if_first_is_not_default():
    """A non-default session config in the first domain is kept on merge.

    Without ``override`` the first domain's session config wins; with
    ``override=True`` the second domain's session config replaces it.
    """
    # Both configs use Rasa's documented key `carry_over_slots_to_new_session`
    # (the previous `carry_over_slots` key never reached the SessionConfig) and
    # differ in BOTH fields, so each assertion shows which domain supplied
    # the expiration time and the carry-over flag.
    yaml1 = """
session_config:
  session_expiration_time: 20
  carry_over_slots_to_new_session: true"""

    yaml2 = """
session_config:
  session_expiration_time: 40
  carry_over_slots_to_new_session: false
"""

    domain1 = Domain.from_yaml(yaml1)
    domain2 = Domain.from_yaml(yaml2)

    merged = domain1.merge(domain2)
    assert merged.session_config == SessionConfig(20, True)

    merged = domain1.merge(domain2, override=True)
    assert merged.session_config == SessionConfig(40, False)
async def test_has_session_expired(
    event_to_apply: Optional[Event],
    session_expiration_time_in_minutes: float,
    has_expired: bool,
    default_processor: MessageProcessor,
):
    """``_has_session_expired`` matches the expectation for the given event
    and session expiration time (values supplied by parametrization)."""
    processor = default_processor
    processor.domain.session_config = SessionConfig(
        session_expiration_time_in_minutes, True
    )

    # begin from an empty tracker for a brand-new sender
    fresh_tracker = processor.tracker_store.get_or_create_tracker(uuid.uuid4().hex)
    fresh_tracker.events.clear()

    # optionally apply the parametrized event
    if event_to_apply:
        fresh_tracker.update(event_to_apply)

    # noinspection PyProtectedMember
    expired = processor._has_session_expired(fresh_tracker)
    assert expired == has_expired
async def test_fetch_tracker_with_initial_session_does_not_update_session(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Fetching a tracker whose session has expired leaves its events alone."""
    sender = uuid.uuid4().hex

    # sessions expire after one second (1/60 of a minute)
    short_session = SessionConfig(carry_over_slots=True, session_expiration_time=1 / 60)
    monkeypatch.setattr(
        default_processor.tracker_store.domain, "session_config", short_session
    )

    now = time.time()

    # pre-existing conversation: session start, listen, then a user message
    initial_events = [
        ActionExecuted(ACTION_SESSION_START_NAME, timestamp=now - 10),
        SessionStarted(timestamp=now - 9),
        ActionExecuted(ACTION_LISTEN_NAME, timestamp=now - 8),
        UserUttered(
            "/greet", {INTENT_NAME_KEY: "greet", "confidence": 1.0}, timestamp=now - 7
        ),
    ]

    stored = DialogueStateTracker.from_events(sender, initial_events)
    default_processor.tracker_store.save(stored)

    fetched = await default_processor.fetch_tracker_with_initial_session(
        sender, default_channel
    )

    # the session is expired, yet `fetch_tracker_with_initial_session()`
    # did not start a new one
    assert default_processor._has_session_expired(fetched)

    fetched_dicts = [e.as_dict() for e in fetched.events]
    expected_dicts = [e.as_dict() for e in initial_events]
    assert fetched_dicts == expected_dicts
default_channel: CollectingOutputChannel, template_nlg: TemplatedNaturalLanguageGenerator, template_sender_tracker: DialogueStateTracker, default_domain: Domain, ): events = await ActionSessionStart().run(default_channel, template_nlg, template_sender_tracker, default_domain) assert events == [SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME)] @pytest.mark.parametrize( "session_config, expected_events", [ ( SessionConfig(123, True), [ SessionStarted(), SlotSet("my_slot", "value"), SlotSet("another-slot", "value2"), ActionExecuted(action_name=ACTION_LISTEN_NAME), ], ), ( SessionConfig(123, False), [SessionStarted(), ActionExecuted(action_name=ACTION_LISTEN_NAME)], ), ], ) async def test_action_session_start_with_slots(
def reset_conversation_state(agent: Agent) -> Agent:
    """Give ``agent`` a fresh in-memory tracker store and default session config.

    Meant to run between tests so stored conversations and session settings
    from one test do not affect the next. The agent is mutated in place and
    also returned.
    """
    fresh_store = InMemoryTrackerStore(agent.domain)
    agent.tracker_store = fresh_store

    agent.domain.session_config = SessionConfig.default()

    return agent
def test_are_sessions_enabled(session_config: SessionConfig, enabled: bool):
    """``are_sessions_enabled()`` agrees with the expected flag."""
    actual = session_config.are_sessions_enabled()
    assert actual == enabled
def test_domain_as_dict_with_session_config(): session_config = SessionConfig(123, False) domain = Domain.empty() domain.session_config = session_config serialized = domain.as_dict() deserialized = Domain.from_dict(serialized) assert deserialized.session_config == session_config @pytest.mark.parametrize( "session_config, enabled", [ (SessionConfig(0, True), False), (SessionConfig(1, True), True), (SessionConfig(-1, False), False), ], ) def test_are_sessions_enabled(session_config: SessionConfig, enabled: bool): assert session_config.are_sessions_enabled() == enabled def test_domain_from_dict_does_not_change_input(): input_before = { "intents": [ { "greet": { USE_ENTITIES_KEY: ["name"] }