Example #1
0
def marker_sqlite_tracker(tmp_path: Path) -> Tuple[SQLTrackerStore, Text]:
    """Create a SQLite tracker store pre-filled with five conversations.

    Each conversation (sender IDs ``"0"`` to ``"4"``) first sets slots
    ``"0"``-``"4"`` to ``"slot"``, then records the session-start action and a
    ``"hello"`` user message, then sets slots ``"5"``-``"9"`` to ``"slot"``,
    and is finally saved to the store.

    Args:
        tmp_path: Directory in which the ``rasa.db`` SQLite file is created.

    Returns:
        The populated tracker store and the path of its SQLite database file.
    """
    empty_domain = Domain.empty()
    sqlite_file = str(tmp_path / "rasa.db")
    store = SQLTrackerStore(dialect="sqlite", db=sqlite_file)

    for sender_index in range(5):
        conversation = DialogueStateTracker(str(sender_index), None)

        leading_slots = [SlotSet(str(slot), "slot") for slot in range(5)]
        conversation.update_with_events(leading_slots, empty_domain)

        conversation.update(ActionExecuted(ACTION_SESSION_START_NAME))
        conversation.update(UserUttered("hello"))

        # Slots "5".."9" are set only after the session-start action.
        trailing_slots = [SlotSet(str(slot), "slot") for slot in range(5, 10)]
        conversation.update_with_events(trailing_slots, empty_domain)

        store.save(conversation)

    return store, sqlite_file
Example #2
0
def test_load_sessions(tmp_path):
    """Tests loading a tracker with multiple sessions.

    A single conversation with a ``SessionStarted`` event between its user
    messages is saved to a SQLite-backed tracker store. Loading it back with
    ``STRATEGY_ALL`` must yield exactly one tracker that carries the same
    number of events as the tracker that was saved.
    """
    domain = Domain.empty()
    # `tmp_path` is a pathlib.Path: build the file path with pathlib and hand
    # the store a plain string.
    store = SQLTrackerStore(domain=domain, db=str(tmp_path / "temp.db"))

    tracker = DialogueStateTracker("test123", None)
    tracker.update_with_events(
        [
            UserUttered("0"),
            UserUttered("1"),
            SessionStarted(),
            UserUttered("2"),
            UserUttered("3"),
        ],
        domain,
    )
    store.save(tracker)

    loader = MarkerTrackerLoader(store, STRATEGY_ALL)
    result = list(loader.load())

    # Only one conversation was saved, so exactly one tracker comes back...
    assert len(result) == 1
    # ...and it keeps every event, including both sessions.
    assert len(result[0].events) == len(tracker.events)
Example #3
0
def test_fetch_events_within_time_range_with_session_events():
    """Session-related events are exported alongside regular user events.

    Saves a tracker whose events include ``SessionStarted`` and the
    session-start ``ActionExecuted`` to a SQLite tracker store, then checks
    that the exporter fetches as many events as were saved.
    """
    import os
    import shutil
    import tempfile

    conversation_id = "test_fetch_events_within_time_range_with_sessions"

    # Keep the SQLite file in a private temporary directory instead of the
    # current working directory, and delete it afterwards so repeated runs
    # don't leave stray `*.db` files behind.
    db_dir = tempfile.mkdtemp()
    try:
        tracker_store = SQLTrackerStore(
            dialect="sqlite",
            db=os.path.join(db_dir, "rasa.db"),
            domain=Domain.empty(),
        )

        events = [
            random_user_uttered_event(1),
            SessionStarted(2),
            ActionExecuted(timestamp=3, action_name=ACTION_SESSION_START_NAME),
            random_user_uttered_event(4),
        ]
        tracker = DialogueStateTracker.from_events(conversation_id, evts=events)
        tracker_store.save(tracker)

        exporter = MockExporter(tracker_store=tracker_store)

        # noinspection PyProtectedMember
        fetched_events = exporter._fetch_events_within_time_range()

        assert len(fetched_events) == len(events)
    finally:
        # Best effort: the store's database connection may still hold the
        # file open (e.g. on Windows); a failed cleanup must not turn a
        # passing test into an error.
        shutil.rmtree(db_dir, ignore_errors=True)