예제 #1
0
async def test_update_tracker_session_with_metadata(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Metadata passed to `_update_tracker_session()` ends up on `SessionStarted`.

    The session is forced to count as expired, a new session is started with a
    metadata dict, and after saving and re-loading the tracker exactly one
    `SessionStarted` event must exist and carry that metadata.
    """
    conversation_id = uuid.uuid4().hex
    current_tracker = default_processor.tracker_store.get_or_create_tracker(
        conversation_id
    )

    # Pretend the session has expired so that `_update_tracker_session()`
    # actually starts a new one.
    monkeypatch.setattr(
        default_processor, "_has_session_expired", lambda _: True
    )

    session_metadata = {"metadataTestKey": "metadataTestValue"}
    await default_processor._update_tracker_session(
        current_tracker, default_channel, session_metadata
    )

    # `_update_tracker_session()` does not persist the tracker on its own.
    default_processor._save_tracker(current_tracker)

    reloaded = default_processor.tracker_store.retrieve(conversation_id)
    session_start = SessionStarted()
    assert reloaded.events.count(session_start) == 1

    position = reloaded.events.index(session_start)
    assert reloaded.events[position].metadata == session_metadata
예제 #2
0
def test_tracker_store_retrieve_with_session_started_events(
    tracker_store_type: Type[TrackerStore],
    tracker_store_kwargs: Dict,
    default_domain: Domain,
):
    """`retrieve()` returns only the events from the latest session onward.

    The stored conversation has a `SessionStarted` event in the middle, and the
    retrieved tracker must begin with it. A second conversation is stored too,
    to make sure events of other senders do not leak into the result.
    """
    store = tracker_store_type(default_domain, **tracker_store_kwargs)
    conversation_id = "test_sql_tracker_store_with_session_events"

    stored_events = [
        UserUttered("Hola", {"name": "greet"}, timestamp=1),
        BotUttered("Hi", timestamp=2),
        SessionStarted(timestamp=3),
        UserUttered("Ciao", {"name": "greet"}, timestamp=4),
    ]
    store.save(DialogueStateTracker.from_events(conversation_id, stored_events))

    # A second sender, so the store holds more than one conversation.
    store.save(
        DialogueStateTracker.from_events("other-sender", [SessionStarted()])
    )

    retrieved = store.retrieve(conversation_id)

    latest_session = stored_events[2:]
    assert len(retrieved.events) == 2
    assert all(
        expected == retrieved.events[index]
        for index, expected in enumerate(latest_session)
    )
예제 #3
0
async def test_restart_triggers_session_start(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """A `/restart` message is followed by a brand-new session start.

    After one greeting turn and a restart, the tracker must hold the first
    session, the restart, and then a fresh session-start sequence with no
    slots re-applied.
    """
    # ActionRestart is a default action. A trained RulePolicy, placed in front
    # of the existing policies, is what allows it to be predicted.
    restart_policy = RulePolicy()
    restart_policy.train([], default_processor.domain, RegexInterpreter())
    ensemble = default_processor.policy_ensemble
    monkeypatch.setattr(
        ensemble, "policies", [restart_policy, *ensemble.policies]
    )

    conversation_id = uuid.uuid4().hex
    entity_name = "name"
    greet_slots = {entity_name: "name1"}
    greet_text = f"/greet{json.dumps(greet_slots)}"

    await default_processor.handle_message(
        UserMessage(greet_text, default_channel, conversation_id)
    )
    assert default_channel.latest_output() == {
        "recipient_id": conversation_id,
        "text": "hey there name1!",
    }

    # Restart the chat.
    await default_processor.handle_message(
        UserMessage("/restart", default_channel, conversation_id)
    )

    tracker = default_processor.tracker_store.get_or_create_tracker(
        conversation_id
    )

    expected_events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(
            greet_text,
            {INTENT_NAME_KEY: "greet", "confidence": 1.0},
            [{"entity": entity_name, "start": 6, "end": 23, "value": "name1"}],
        ),
        SlotSet(entity_name, greet_slots[entity_name]),
        DefinePrevUserUtteredFeaturization(use_text_for_featurization=False),
        ActionExecuted("utter_greet"),
        BotUttered("hey there name1!", metadata={"template_name": "utter_greet"}),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("/restart", {INTENT_NAME_KEY: "restart", "confidence": 1.0}),
        DefinePrevUserUtteredFeaturization(use_text_for_featurization=False),
        ActionExecuted(ACTION_RESTART_NAME),
        Restarted(),
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        # Slots set before the restart are not re-applied in the new session.
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    for actual_event, expected_event in zip(tracker.events, expected_events):
        assert actual_event == expected_event
예제 #4
0
def _saved_tracker_with_multiple_session_starts(
    tracker_store: TrackerStore, sender_id: Text
) -> DialogueStateTracker:
    """Persist a two-session conversation and return it as re-loaded from the store.

    The conversation is two session-start sequences with one user message
    ("hi") between them.

    Args:
        tracker_store: Store that the conversation is saved to and read back from.
        sender_id: Conversation ID to use.

    Returns:
        Whatever `tracker_store.retrieve(sender_id)` returns after the save.
    """
    conversation = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        UserUttered("hi"),
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
    ]
    tracker_store.save(DialogueStateTracker.from_events(sender_id, conversation))
    return tracker_store.retrieve(sender_id)
예제 #5
0
async def test_update_tracker_session_with_metadata(
    default_processor: MessageProcessor, monkeypatch: MonkeyPatch,
):
    """Message metadata is stored on `SessionStarted` and in its metadata slot.

    The first message of a new conversation carries metadata. That metadata
    must be recorded both on the `SessionStarted` event and in the
    `SESSION_START_METADATA_SLOT` slot.
    """
    conversation_id = uuid.uuid4().hex
    message_metadata = {"metadataTestKey": "metadataTestValue"}

    await default_processor.handle_message(
        UserMessage(
            text="hi",
            output_channel=CollectingOutputChannel(),
            sender_id=conversation_id,
            metadata=message_metadata,
        )
    )

    tracker = default_processor.tracker_store.retrieve(conversation_id)
    logged = list(tracker.events)
    metadata_slot_event = SlotSet(SESSION_START_METADATA_SLOT, message_metadata)

    assert logged[0] == metadata_slot_event
    assert tracker.slots[SESSION_START_METADATA_SLOT].value == message_metadata

    assert logged[1] == ActionExecuted(ACTION_SESSION_START_NAME)
    assert logged[2] == SessionStarted()
    assert logged[2].metadata == message_metadata
    assert logged[3] == metadata_slot_event
    assert logged[4] == ActionExecuted(ACTION_LISTEN_NAME)
    assert isinstance(logged[5], UserUttered)
예제 #6
0
async def test_update_tracker_session(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Starting a new session on a fresh tracker appends the session-start sequence."""
    conversation_id = uuid.uuid4().hex
    current_tracker = default_processor.tracker_store.get_or_create_tracker(
        conversation_id
    )

    # Treat the session as expired so `_update_tracker_session()` has an effect.
    monkeypatch.setattr(
        default_processor, "_has_session_expired", lambda _: True
    )

    await default_processor._update_tracker_session(
        current_tracker, default_channel
    )

    # Persisting is the caller's job, not `_update_tracker_session()`'s.
    default_processor._save_tracker(current_tracker)

    reloaded = default_processor.tracker_store.retrieve(conversation_id)

    expected_events = [
        ActionExecuted(ACTION_LISTEN_NAME),
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    assert list(reloaded.events) == expected_events
예제 #7
0
def test_tracker_store_retrieve_without_session_started_events(
    tracker_store_type: Type[TrackerStore],
    tracker_store_kwargs: Dict,
    domain,
):
    """`retrieve()` returns every stored event when there is no `SessionStarted`.

    The stored conversation contains no session-start event at all, so all four
    events must come back. A second conversation (which does contain a
    `SessionStarted`) is stored to make sure other senders do not interfere.
    """
    tracker_store = tracker_store_type(domain, **tracker_store_kwargs)

    # Create a tracker WITHOUT any SessionStarted event
    events = [
        UserUttered("Hola", {"name": "greet"}),
        BotUttered("Hi"),
        UserUttered("Ciao", {"name": "greet"}),
        BotUttered("Hi2"),
    ]

    sender_id = "test_sql_tracker_store_retrieve_without_session_started_events"
    tracker = DialogueStateTracker.from_events(sender_id, events)
    tracker_store.save(tracker)

    # Save other tracker to ensure that we don't run into problems with other senders
    other_tracker = DialogueStateTracker.from_events("other-sender",
                                                     [SessionStarted()])
    tracker_store.save(other_tracker)

    tracker = tracker_store.retrieve(sender_id)

    assert len(tracker.events) == 4
    assert all(event == tracker.events[i] for i, event in enumerate(events))
예제 #8
0
def test_restart_event(default_domain: Domain):
    """`Restarted` wipes the conversation and asks for a session start.

    It also checks that the resulting state survives a round trip through
    `as_dialogue()` / `recreate_from_dialogue()`.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # A freshly created tracker holds no events.
    assert len(tracker.events) == 0

    greet_intent = {"name": "greet", PREDICTED_CONFIDENCE_KEY: 1.0}
    for event in (
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("/greet", greet_intent, []),
        ActionExecuted("my_action"),
        ActionExecuted(ACTION_LISTEN_NAME),
    ):
        tracker.update(event)

    assert len(tracker.events) == 4
    assert tracker.latest_message.text == "/greet"
    assert len(list(tracker.generate_all_prior_trackers())) == 4

    # A restart schedules the session start as the follow-up action.
    tracker.update(Restarted())
    assert len(tracker.events) == 5
    assert tracker.followup_action == ACTION_SESSION_START_NAME

    # Once the session has started, the bot waits for the user again.
    tracker.update(SessionStarted())
    assert tracker.followup_action == ACTION_LISTEN_NAME
    assert tracker.latest_message.text is None
    assert len(list(tracker.generate_all_prior_trackers())) == 1

    serialised = tracker.as_dialogue()
    rebuilt = DialogueStateTracker("default", default_domain.slots)
    rebuilt.recreate_from_dialogue(serialised)

    assert rebuilt.current_state() == tracker.current_state()
    assert len(rebuilt.events) == 6
    assert rebuilt.latest_message.text is None
    assert len(list(rebuilt.generate_all_prior_trackers())) == 1
예제 #9
0
def test_session_start_is_not_serialised(domain: Domain):
    """Session-start events are left out of a story written by `YAMLStoryWriter`.

    The tracker holds a slot event, the session-start sequence, and one user
    turn. The serialised story must contain only the slot and the intent.
    """
    tracker = DialogueStateTracker("default", domain.slots)
    # a freshly created tracker holds no events
    assert len(tracker.events) == 0

    # add SlotSet event
    tracker.update(SlotSet("slot", "value"))

    # add the session-start sequence (session-start action + SessionStarted),
    # then a user message and its featurization event
    tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
    tracker.update(SessionStarted())
    tracker.update(
        UserUttered("say something", intent={INTENT_NAME_KEY: "some_intent"}))
    tracker.update(DefinePrevUserUtteredFeaturization(False))

    expected = """version: "2.0"
stories:
- story: some-story01
  steps:
  - slot_was_set:
    - slot: value
  - intent: some_intent
"""

    # make sure the session start is not serialised
    actual = YAMLStoryWriter().dumps(
        Story.from_events(tracker.events, "some-story01").story_steps)
    assert actual == expected
예제 #10
0
async def test_agent_handle_message_only_core(trained_core_model: Text):
    """A Core-only agent handles `/greet` and logs the whole greeting turn."""
    agent = await load_agent(model_path=trained_core_model)
    model_id = agent.model_id
    conversation_id = uuid.uuid4().hex

    await agent.handle_message(UserMessage("/greet", sender_id=conversation_id))
    tracker = agent.tracker_store.get_or_create_tracker(conversation_id)

    # Every rich-response field of the templated reply is empty.
    empty_response_data = dict.fromkeys(
        ("elements", "quick_replies", "buttons", "attachment", "image", "custom")
    )
    greeting_turn = [
        ActionExecuted(action_name="action_session_start"),
        SessionStarted(),
        ActionExecuted(action_name="action_listen"),
        UserUttered(text="/greet", intent={"name": "greet"},),
        DefinePrevUserUtteredFeaturization(False),
        ActionExecuted(action_name="utter_greet"),
        BotUttered(
            "hey there None!",
            empty_response_data,
            {"utter_action": "utter_greet"},
        ),
        ActionExecuted(action_name="action_listen"),
    ]
    expected_events = with_model_ids(greeting_turn, model_id)

    assert len(tracker.events) == len(expected_events)
    for logged, wanted in zip(tracker.events, expected_events):
        assert logged == wanted
예제 #11
0
async def test_fetch_events_within_time_range_with_session_events(
        tmp_path: Path):
    """Session events are returned when the exporter fetches stored events."""
    database_file = tmp_path / f"{uuid.uuid4().hex}.db"
    tracker_store = SQLTrackerStore(
        dialect="sqlite",
        db=str(database_file),
        domain=Domain.empty(),
    )

    conversation_events = [
        random_user_uttered_event(1),
        SessionStarted(2),
        ActionExecuted(timestamp=3, action_name=ACTION_SESSION_START_NAME),
        random_user_uttered_event(4),
    ]
    await tracker_store.save(
        DialogueStateTracker.from_events(
            "test_fetch_events_within_time_range_with_sessions",
            evts=conversation_events,
        )
    )

    exporter = MockExporter(tracker_store=tracker_store)

    # noinspection PyProtectedMember
    fetched = await exporter._fetch_events_within_time_range()

    assert len(fetched) == len(conversation_events)
예제 #12
0
async def test_policy_events_are_applied_to_tracker(
        default_processor: MessageProcessor, monkeypatch: MonkeyPatch):
    """Events attached to a policy prediction are applied before the action runs.

    The patched ensemble always predicts the listen action together with a
    `LoopInterrupted` event. The patched `ActionListen.run` records whether
    that event was already on the tracker when the action executed.
    """
    predicted_action = ACTION_LISTEN_NAME
    prediction_events = [LoopInterrupted(True)]
    conversation_id = "test_policy_events_are_applied_to_tracker"
    text = "/greet"

    # What the tracker must contain at the moment the predicted action runs.
    events_before_action = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(text, intent={"name": "greet"}),
        *prediction_events,
    ]

    class ConstantEnsemble(PolicyEnsemble):
        def probabilities_using_best_policy(
            self,
            tracker: DialogueStateTracker,
            domain: Domain,
            interpreter: NaturalLanguageInterpreter,
            **kwargs: Any,
        ) -> PolicyPrediction:
            """Always predict `predicted_action`, carrying `prediction_events`."""
            fixed = PolicyPrediction.for_action_name(
                default_processor.domain, predicted_action, "some policy")
            fixed.events = prediction_events
            return fixed

    monkeypatch.setattr(default_processor, "policy_ensemble",
                        ConstantEnsemble([]))

    saw_policy_events = False

    async def mocked_run(
        self,
        output_channel: "OutputChannel",
        nlg: "NaturalLanguageGenerator",
        tracker: "DialogueStateTracker",
        domain: "Domain",
    ) -> List[Event]:
        # Record whether the policy events were already applied to the tracker.
        nonlocal saw_policy_events
        saw_policy_events = list(tracker.events) == events_before_action
        return []

    monkeypatch.setattr(ActionListen, ActionListen.run.__name__, mocked_run)

    await default_processor.handle_message(
        UserMessage(text, sender_id=conversation_id))

    assert saw_policy_events

    tracker = default_processor.get_tracker(conversation_id)
    # The executed listen action is logged after the policy events.
    events_after_action = [
        *events_before_action,
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    for logged, wanted in zip(tracker.events, events_after_action):
        assert logged == wanted
예제 #13
0
async def test_action_session_start_without_slots(
    default_channel: CollectingOutputChannel,
    template_nlg: TemplatedNaturalLanguageGenerator,
    template_sender_tracker: DialogueStateTracker,
    domain: Domain,
):
    """`ActionSessionStart` yields only the bare session-start events.

    NOTE(review): this assumes the `template_sender_tracker` fixture has no slots
    to carry over. Confirm against the fixture.
    """
    session_start = ActionSessionStart()
    produced = await session_start.run(
        default_channel, template_nlg, template_sender_tracker, domain
    )

    bare_sequence = [SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME)]
    assert produced == bare_sequence
예제 #14
0
def test_session_start(default_domain: Domain):
    """A `SessionStarted` event applied to an empty tracker is recorded."""
    tracker = DialogueStateTracker("default", default_domain.slots)
    assert not tracker.events, "a new tracker should start without events"

    tracker.update(SessionStarted())

    # exactly the one event that was just applied
    assert len(tracker.events) == 1
예제 #15
0
파일: test_server.py 프로젝트: attgua/Geco
async def test_get_story_does_not_update_conversation_session(
    rasa_app: SanicASGITestClient, monkeypatch: MonkeyPatch
):
    """Fetching a conversation's story must not start a new session.

    The stored conversation's only session has already expired. Requesting
    `/conversations/<id>/story` returns the story and leaves the tracker's
    events unchanged.
    """
    conversation_id = "some-conversation-ID"

    # Sessions expire after 1/60 of a minute (one second); slots carry over.
    short_session_domain = Domain.empty()
    short_session_domain.session_config = SessionConfig(
        session_expiration_time=1 / 60, carry_over_slots=True
    )
    monkeypatch.setattr(rasa_app.app.agent, "domain", short_session_domain)

    # One session whose events are all several seconds old.
    now_ts = time.time()
    conversation_events = [
        ActionExecuted(ACTION_SESSION_START_NAME, timestamp=now_ts - 10),
        SessionStarted(timestamp=now_ts - 9),
        UserUttered("hi", {"name": "greet"}, timestamp=now_ts - 8),
        ActionExecuted("utter_greet", timestamp=now_ts - 7),
    ]
    tracker = DialogueStateTracker.from_events(conversation_id, conversation_events)

    # Precondition: the session really has expired.
    assert rasa_app.app.agent.create_processor()._has_session_expired(tracker)

    store = InMemoryTrackerStore(short_session_domain)
    store.save(tracker)
    monkeypatch.setattr(rasa_app.app.agent, "tracker_store", store)

    _, response = await rasa_app.get(f"/conversations/{conversation_id}/story")
    assert response.status == 200

    expected_story = """version: "2.0"
stories:
- story: some-conversation-ID
  steps:
  - intent: greet
    user: |-
      hi
  - action: utter_greet"""
    assert response.content.decode().strip() == expected_story

    # Neither the event count nor the last event changed.
    assert len(tracker.events) == len(conversation_events)
    assert tracker.events[-1].timestamp == conversation_events[-1].timestamp
예제 #16
0
def test_split_sessions(tmp_path):
    """Tests splitting the events of a tracker that contains a single session.

    The events are one session-start sequence followed by a user message, so
    `Marker._split_sessions` must return exactly one session that holds all of
    the events.
    """

    events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        UserUttered(intent={"name": "this-intent"}),
    ]
    sessions = Marker._split_sessions(events)
    assert len(sessions) == 1
    # NOTE(review): each entry appears to be a sequence whose first item holds
    # the session's events. Confirm against `Marker._split_sessions`.
    assert len(sessions[0][0]) == len(events)
예제 #17
0
def test_tracker_store_retrieve_with_events_from_previous_sessions(
        tracker_store_type: Type[TrackerStore], tracker_store_kwargs: Dict):
    """`retrieve_full_tracker()` returns the events of every session.

    The stored conversation has two sessions. All of its events, including
    those from the earlier session, must come back.
    """
    store = tracker_store_type(Domain.empty(), **tracker_store_kwargs)

    conversation_id = uuid.uuid4().hex
    two_sessions = DialogueStateTracker.from_events(
        conversation_id,
        [
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
            UserUttered("hi"),
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
        ],
    )
    store.save(two_sessions)

    full_tracker = store.retrieve_full_tracker(conversation_id)
    assert len(full_tracker.events) == len(two_sessions.events)
예제 #18
0
async def test_get_tracker_with_session_start(
        default_channel: CollectingOutputChannel,
        default_processor: MessageProcessor):
    """A new conversation's tracker begins with the session-start sequence."""
    new_conversation_id = uuid.uuid4().hex
    tracker = await default_processor.get_tracker_with_session_start(
        new_conversation_id, default_channel)

    session_start_sequence = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    assert list(tracker.events) == session_start_sequence
예제 #19
0
async def test_update_tracker_session_with_slots(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Slots set in an expired session are re-applied after the new session starts.

    The tracker gets a user utterance and five `SlotSet` events. After a forced
    session update, the same five slot events must appear again right after
    the session-start sequence, followed by a final listen action.
    """
    sender_id = uuid.uuid4().hex
    tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

    # apply a user utterance and five slots
    user_event = UserUttered("some utterance")
    tracker.update(user_event)

    slot_set_events = [
        SlotSet(f"slot key {i}", f"test value {i}") for i in range(5)
    ]

    for event in slot_set_events:
        tracker.update(event)

    # patch `_has_session_expired()` so the `_update_tracker_session()` call actually
    # does something
    monkeypatch.setattr(default_processor, "_has_session_expired",
                        lambda _: True)

    await default_processor._update_tracker_session(tracker, default_channel)

    # the save is not called in _update_tracker_session()
    default_processor._save_tracker(tracker)

    # inspect tracker and make sure all events are present
    tracker = default_processor.tracker_store.retrieve(sender_id)
    events = list(tracker.events)

    # the first two events are the initial listen action and the user utterance
    assert events[:2] == [ActionExecuted(ACTION_LISTEN_NAME), user_event]

    # next come the five slots
    assert events[2:7] == slot_set_events

    # the next two events are the session start sequence
    assert events[7:9] == [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted()
    ]

    # the five slots should be reapplied
    assert events[9:14] == slot_set_events

    # finally an action listen, this should also be the last event
    assert events[14] == events[-1] == ActionExecuted(ACTION_LISTEN_NAME)
예제 #20
0
파일: action.py 프로젝트: karen-white/rasa
    async def run(
        self,
        output_channel: "OutputChannel",
        nlg: "NaturalLanguageGenerator",
        tracker: "DialogueStateTracker",
        domain: "Domain",
    ) -> List[Event]:
        """Start a new session, optionally carry slots over, then listen.

        Returns:
            A `SessionStarted` event carrying `self.metadata`. When
            `domain.session_config.carry_over_slots` is set, the events from
            `self._slot_set_events_from_tracker(tracker)` follow it. The list
            ends with an `ActionExecuted` event for the listen action.
        """
        # Created first, so it is constructed before any carried-over slot events.
        session_started = SessionStarted(metadata=self.metadata)
        carried_over_slots = (
            self._slot_set_events_from_tracker(tracker)
            if domain.session_config.carry_over_slots
            else []
        )
        return [
            session_started,
            *carried_over_slots,
            ActionExecuted(ACTION_LISTEN_NAME),
        ]
예제 #21
0
파일: action.py 프로젝트: ChenHuaYou/rasa
    async def run(
        self,
        output_channel: "OutputChannel",
        nlg: "NaturalLanguageGenerator",
        tracker: "DialogueStateTracker",
        domain: "Domain",
    ) -> List[Event]:
        """Start a new conversation session and hand control back to the user.

        Returns:
            A `SessionStarted` event. When `domain.session_config.carry_over_slots`
            is enabled, the events produced by
            `self._slot_set_events_from_tracker(tracker)` follow it. The list
            ends with an `ActionExecuted` event for the listen action.
        """
        session_events: List[Event] = [SessionStarted()]
        if domain.session_config.carry_over_slots:
            session_events += self._slot_set_events_from_tracker(tracker)
        session_events += [ActionExecuted(ACTION_LISTEN_NAME)]
        return session_events
예제 #22
0
async def test_policy_events_not_applied_if_rejected(
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
    reject_fn: Callable[[], List[Event]],
):
    """Policy events are dropped when the predicted action is rejected.

    The patched ensemble always predicts the listen action with a
    `LoopInterrupted` event attached. The patched `ActionListen.run` returns
    whatever `reject_fn()` produces. The tracker must then log an
    `ActionExecutionRejected` where the policy event would have been.
    """
    predicted_action = ACTION_LISTEN_NAME
    prediction_events = [LoopInterrupted(True)]
    conversation_id = "test_policy_events_are_applied_to_tracker"
    text = "/greet"

    class ConstantEnsemble(PolicyEnsemble):
        def probabilities_using_best_policy(
            self,
            tracker: DialogueStateTracker,
            domain: Domain,
            interpreter: NaturalLanguageInterpreter,
            **kwargs: Any,
        ) -> PolicyPrediction:
            """Always predict `predicted_action`, carrying `prediction_events`."""
            fixed = PolicyPrediction.for_action_name(
                default_processor.domain, predicted_action, "some policy"
            )
            fixed.events = prediction_events
            return fixed

    monkeypatch.setattr(default_processor, "policy_ensemble", ConstantEnsemble([]))

    async def mocked_run(*args: Any, **kwargs: Any) -> List[Event]:
        # Delegate to the parametrised rejection behaviour.
        return reject_fn()

    monkeypatch.setattr(ActionListen, ActionListen.run.__name__, mocked_run)

    await default_processor.handle_message(
        UserMessage(text, sender_id=conversation_id)
    )

    tracker = default_processor.get_tracker(conversation_id)
    logged_expectation = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered(text, intent={"name": "greet"}),
        ActionExecutionRejected(ACTION_LISTEN_NAME),
    ]
    for logged, wanted in zip(tracker.events, logged_expectation):
        assert logged == wanted
예제 #23
0
async def test_agent_handle_message_only_nlu(trained_nlu_model: Text):
    """An NLU-only agent logs the session start and the classified user message."""
    agent = await load_agent(model_path=trained_nlu_model)
    model_id = agent.model_id
    conversation_id = uuid.uuid4().hex

    await agent.handle_message(UserMessage("hello", sender_id=conversation_id))
    tracker = agent.tracker_store.get_or_create_tracker(conversation_id)

    first_turn = [
        ActionExecuted(action_name="action_session_start"),
        SessionStarted(),
        ActionExecuted(action_name="action_listen"),
        UserUttered(text="hello", intent={"name": "greet"},),
    ]
    expected_events = with_model_ids(first_turn, model_id)

    assert len(tracker.events) == len(expected_events)
    for logged, wanted in zip(tracker.events, expected_events):
        assert logged == wanted
예제 #24
0
def test_load_sessions(tmp_path):
    """Tests loading a tracker whose events span two sessions.

    The tracker is stored in a temporary SQLite-backed store and re-loaded with
    `MarkerTrackerLoader` using `STRATEGY_ALL`. Exactly one tracker, with all of
    its events, must come back.
    """
    empty_domain = Domain.empty()
    store = SQLTrackerStore(empty_domain, db=str(tmp_path / "temp.db"))

    first_session = [UserUttered("0"), UserUttered("1")]
    second_session = [SessionStarted(), UserUttered("2"), UserUttered("3")]
    tracker = DialogueStateTracker("test123", None)
    tracker.update_with_events(first_session + second_session, empty_domain)
    store.save(tracker)

    loaded = list(MarkerTrackerLoader(store, STRATEGY_ALL).load())
    assert len(loaded) == 1  # contains only one tracker
    assert len(loaded[0].events) == len(tracker.events)
예제 #25
0
async def test_fetch_tracker_with_initial_session_does_not_update_session(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """`fetch_tracker_with_initial_session()` leaves an expired session as it is.

    The stored session is older than the one-second expiration time. After
    fetching, it must still count as expired and the events must be unchanged.
    """
    conversation_id = uuid.uuid4().hex

    # Sessions expire after 1/60 of a minute (one second); slots carry over.
    monkeypatch.setattr(
        default_processor.tracker_store.domain,
        "session_config",
        SessionConfig(carry_over_slots=True, session_expiration_time=1 / 60),
    )

    now_ts = time.time()
    greet_intent = {INTENT_NAME_KEY: "greet", "confidence": 1.0}
    initial_events = [
        ActionExecuted(ACTION_SESSION_START_NAME, timestamp=now_ts - 10),
        SessionStarted(timestamp=now_ts - 9),
        ActionExecuted(ACTION_LISTEN_NAME, timestamp=now_ts - 8),
        UserUttered("/greet", greet_intent, timestamp=now_ts - 7),
    ]
    default_processor.tracker_store.save(
        DialogueStateTracker.from_events(conversation_id, initial_events)
    )

    fetched = await default_processor.fetch_tracker_with_initial_session(
        conversation_id, default_channel
    )

    # Still expired, and no session-start events were appended.
    assert default_processor._has_session_expired(fetched)
    fetched_dicts = [event.as_dict() for event in fetched.events]
    assert fetched_dicts == [event.as_dict() for event in initial_events]
예제 #26
0
def test_session_start_is_not_serialised(default_domain: Domain):
    """Session-start events are left out of the Markdown story string.

    The tracker holds a slot event, the session-start sequence, and a user
    message. The flat story string must contain only the slot and the message.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # a freshly created tracker holds no events
    assert len(tracker.events) == 0

    # add SlotSet event
    tracker.update(SlotSet("slot", "value"))

    # add the session-start sequence (session-start action + SessionStarted)
    # and a user event
    tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
    tracker.update(SessionStarted())
    tracker.update(UserUttered("say something"))

    # make sure session start is not serialised
    story = Story.from_events(tracker.events, "some-story01")

    expected = """## some-story01
    - slot{"slot": "value"}
* say something
"""

    assert story.as_story_string(flat=True) == expected
예제 #27
0
def test_tracker_store_deprecated_session_retrieval_kwarg():
    """The deprecated previous-sessions flag makes `retrieve()` load the full tracker.

    With `retrieve_events_from_previous_conversation_sessions=True`, a plain
    `retrieve()` must delegate to `retrieve_full_tracker()` exactly once.
    """
    tracker_store = SQLTrackerStore(
        Domain.empty(),
        retrieve_events_from_previous_conversation_sessions=True,
    )

    conversation_id = uuid.uuid4().hex
    tracker = DialogueStateTracker.from_events(
        conversation_id,
        [
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
            UserUttered("hi"),
        ],
    )

    # Swap in a mock so the delegation can be observed.
    full_retrieval_mock = Mock()
    tracker_store.retrieve_full_tracker = full_retrieval_mock

    tracker_store.save(tracker)
    tracker_store.retrieve(conversation_id)

    full_retrieval_mock.assert_called_once()
예제 #28
0
    template_sender_tracker: DialogueStateTracker,
    default_domain: Domain,
):
    events = await ActionSessionStart().run(default_channel, template_nlg,
                                            template_sender_tracker,
                                            default_domain)
    assert events == [SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME)]


@pytest.mark.parametrize(
    "session_config, expected_events",
    [
        (
            SessionConfig(123, True),
            [
                SessionStarted(),
                SlotSet("my_slot", "value"),
                SlotSet("another-slot", "value2"),
                ActionExecuted(action_name=ACTION_LISTEN_NAME),
            ],
        ),
        (
            SessionConfig(123, False),
            [SessionStarted(),
             ActionExecuted(action_name=ACTION_LISTEN_NAME)],
        ),
    ],
)
async def test_action_session_start_with_slots(
    default_channel: CollectingOutputChannel,
    template_nlg: TemplatedNaturalLanguageGenerator,
예제 #29
0
def test_json_parse_session_started():
    """A `session_started` parameter dict deserialises to `SessionStarted`."""
    parsed = Event.from_parameters({"event": "session_started"})
    assert parsed == SessionStarted()
예제 #30
0
@pytest.mark.parametrize(
    "one_event,another_event",
    [
        (
            UserUttered("/greet", {"name": "greet", "confidence": 1.0}, []),
            UserUttered("/goodbye", {"name": "goodbye", "confidence": 1.0}, []),
        ),
        (SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")),
        (Restarted(), None),
        (AllSlotsReset(), None),
        (ConversationPaused(), None),
        (ConversationResumed(), None),
        (StoryExported(), None),
        (ActionReverted(), None),
        (UserUtteranceReverted(), None),
        (SessionStarted(), None),
        (ActionExecuted("my_action"), ActionExecuted("my_other_action")),
        (FollowupAction("my_action"), FollowupAction("my_other_action")),
        (
            BotUttered("my_text", {"my_data": 1}),
            BotUttered("my_other_test", {"my_other_data": 1}),
        ),
        (
            AgentUttered("my_text", "my_data"),
            AgentUttered("my_other_test", "my_other_data"),
        ),
        (
            ReminderScheduled("my_intent", datetime.now()),
            ReminderScheduled("my_other_intent", datetime.now()),
        ),
    ],