def test_tracker_store_retrieve_without_session_started_events(
    tracker_store_type: Type[TrackerStore],
    tracker_store_kwargs: Dict,
    default_domain: Domain,
):
    """A tracker that has no `SessionStarted` event is retrieved with all its events.

    A tracker for a second sender is saved too, so the test also checks that
    events of other conversations do not end up in the retrieved tracker.
    """
    tracker_store = tracker_store_type(default_domain, **tracker_store_kwargs)

    # Create a tracker WITHOUT any SessionStarted event
    events = [
        UserUttered("Hola", {"name": "greet"}),
        BotUttered("Hi"),
        UserUttered("Ciao", {"name": "greet"}),
        BotUttered("Hi2"),
    ]
    sender_id = "test_sql_tracker_store_retrieve_without_session_started_events"
    tracker = DialogueStateTracker.from_events(sender_id, events)
    tracker_store.save(tracker)

    # Save other tracker to ensure that we don't run into problems with other senders
    other_tracker = DialogueStateTracker.from_events("other-sender", [SessionStarted()])
    tracker_store.save(other_tracker)

    tracker = tracker_store.retrieve(sender_id)

    # Fail with a clear message rather than an AttributeError below
    assert tracker is not None
    # With no session start, every saved event must come back, in order
    assert len(tracker.events) == len(events)
    assert all(event == tracker.events[i] for i, event in enumerate(events))
def test_tracker_store_retrieve_with_session_started_events(
    tracker_store_type: Type[TrackerStore],
    tracker_store_kwargs: Dict,
    domain: Domain,
):
    """Retrieving a tracker returns only the events since its latest `SessionStarted`.

    A tracker for a second sender is saved as well, so that other
    conversations in the same store cannot influence the result.
    """
    store = tracker_store_type(domain, **tracker_store_kwargs)

    saved_events = [
        UserUttered("Hola", {"name": "greet"}, timestamp=1),
        BotUttered("Hi", timestamp=2),
        SessionStarted(timestamp=3),
        UserUttered("Ciao", {"name": "greet"}, timestamp=4),
    ]
    sender_id = "test_sql_tracker_store_with_session_events"
    store.save(DialogueStateTracker.from_events(sender_id, saved_events))

    # A tracker belonging to another sender must not affect our retrieval
    store.save(DialogueStateTracker.from_events("other-sender", [SessionStarted()]))

    # Expected: the SessionStarted event (index 2) and everything after it
    expected_events = saved_events[2:]
    retrieved = store.retrieve(sender_id)

    assert len(retrieved.events) == 2
    assert all(
        expected == actual
        for expected, actual in zip(expected_events, retrieved.events)
    )
async def prepare_token_serialisation(
    tracker_store: TrackerStore, response_selector_agent: Agent, sender_id: Text
):
    """Check that token offsets in parse data survive a tracker store round trip.

    The message is parsed with `response_selector_agent`, stored as a
    `UserUttered` event via `tracker_store`, and read back. The `text_tokens`
    in the retrieved parse data must equal the `[start, end]` offsets that a
    `WhitespaceTokenizer` produces for the same text.

    Args:
        tracker_store: Store used for the save/retrieve round trip.
        response_selector_agent: Agent used to parse the message.
        sender_id: Conversation ID of the tracker that is created and retrieved.
    """
    text = "Good morning"

    # Expected token offsets, computed directly with a whitespace tokenizer
    tokenizer = WhitespaceTokenizer(WhitespaceTokenizer.get_default_config())
    tokens = tokenizer.tokenize(Message(data={"text": text}), "text")
    expected_indices = [[token.start, token.end] for token in tokens]

    tracker = tracker_store.get_or_create_tracker(sender_id=sender_id)
    parse_data = await response_selector_agent.parse_message(text)
    # Reuse `text` so the stored utterance can never drift from the parsed one
    user_event = UserUttered(
        text,
        parse_data.get("intent"),
        parse_data.get("entities", []),
        parse_data,
    )
    tracker.update(user_event)
    tracker_store.save(tracker)

    retrieved_tracker = tracker_store.retrieve(sender_id=sender_id)
    assert retrieved_tracker is not None
    retrieved_event = retrieved_tracker.get_last_event_for(event_type=UserUttered)
    assert retrieved_event is not None

    event_tokens = retrieved_event.as_dict().get("parse_data").get("text_tokens")
    assert event_tokens == expected_indices
def test_fail_safe_tracker_store_with_retrieve_error():
    """If the wrapped store raises on `retrieve`, None is returned and the callback fires."""
    primary_store = Mock()
    primary_store.retrieve = Mock(side_effect=Exception())
    backup_store = Mock()
    error_callback = Mock()

    fail_safe_store = FailSafeTrackerStore(primary_store, error_callback, backup_store)

    assert fail_safe_store.retrieve("sender_id") is None
    error_callback.assert_called_once()
def _saved_tracker_with_multiple_session_starts(
    tracker_store: TrackerStore, sender_id: Text
) -> DialogueStateTracker:
    """Save a tracker containing two session starts, then retrieve it again.

    Each session start is an `ActionExecuted(ACTION_SESSION_START_NAME)`
    followed by a `SessionStarted`, and a single `UserUttered("hi")` sits
    between the two.

    Args:
        tracker_store: Store to save the tracker to and retrieve it from.
        sender_id: Conversation ID of the saved tracker.

    Returns:
        The result of `tracker_store.retrieve(sender_id)` after saving.
    """

    def session_start_events():
        # Fresh instances each time, built in the same order as a literal list
        return [ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted()]

    all_events = [*session_start_events(), UserUttered("hi"), *session_start_events()]
    tracker_store.save(DialogueStateTracker.from_events(sender_id, all_events))
    return tracker_store.retrieve(sender_id)
def test_tracker_store_retrieve_with_events_from_previous_sessions(
    tracker_store_type: Type[TrackerStore], tracker_store_kwargs: Dict
):
    """With `load_events_from_previous_conversation_sessions`, no events are cut off.

    The saved tracker spans two session starts. Every event must be returned,
    not only the events since the latest `SessionStarted`.
    """
    store = tracker_store_type(Domain.empty(), **tracker_store_kwargs)
    store.load_events_from_previous_conversation_sessions = True

    conversation_id = uuid.uuid4().hex
    saved = DialogueStateTracker.from_events(
        conversation_id,
        [
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
            UserUttered("hi"),
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
        ],
    )
    store.save(saved)

    retrieved = store.retrieve(conversation_id)
    assert len(retrieved.events) == len(saved.events)
def test_fail_safe_tracker_store_if_no_errors():
    """Without errors, `FailSafeTrackerStore` delegates to the wrapped store.

    Checks that `save`, `retrieve` and `keys` are each forwarded exactly once,
    with their arguments, and that the wrapped store's return values are
    passed through unchanged.
    """
    mocked_tracker_store = Mock()
    tracker_store = FailSafeTrackerStore(mocked_tracker_store, None)

    # test save: the tracker argument must be forwarded, not just "some call"
    mocked_tracker_store.save = Mock()
    tracker_store.save(None)
    mocked_tracker_store.save.assert_called_once_with(None)

    # test retrieve
    expected_tracker = [1]
    mocked_tracker_store.retrieve = Mock(return_value=expected_tracker)
    sender_id = "10"
    assert tracker_store.retrieve(sender_id) == expected_tracker
    mocked_tracker_store.retrieve.assert_called_once_with(sender_id)

    # test keys
    expected_keys = ["sender 1", "sender 2"]
    mocked_tracker_store.keys = Mock(return_value=expected_keys)
    assert tracker_store.keys() == expected_keys
    mocked_tracker_store.keys.assert_called_once()
def test_tracker_store_deprecated_session_retrieval_kwarg():
    """`retrieve_events_from_previous_conversation_sessions=True` routes to full retrieval.

    The store is built with that keyword argument, and `retrieve` is expected
    to call `retrieve_full_tracker` exactly once.
    """
    store = SQLTrackerStore(
        Domain.empty(), retrieve_events_from_previous_conversation_sessions=True
    )

    conversation_id = uuid.uuid4().hex
    session_events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        UserUttered("hi"),
    ]
    tracker = DialogueStateTracker.from_events(conversation_id, session_events)

    # Replace the full-tracker loader with a spy so the delegation is observable
    full_tracker_spy = Mock()
    store.retrieve_full_tracker = full_tracker_spy

    store.save(tracker)
    store.retrieve(conversation_id)

    full_tracker_spy.assert_called_once()