def test_load_first_n(marker_trackerstore: TrackerStore):
    """Checks that the 'first_n' strategy yields exactly `count` stored trackers."""
    requested = 3
    loader = MarkerTrackerLoader(marker_trackerstore, STRATEGY_FIRST_N, requested)
    loaded_trackers = list(loader.load())

    assert len(loaded_trackers) == requested
    # Every loaded tracker must come from the backing store.
    assert all(marker_trackerstore.exists(t.sender_id) for t in loaded_trackers)
def test_load_all(marker_trackerstore: TrackerStore):
    """Checks that the 'all' strategy yields as many trackers as the store holds."""
    loaded = list(MarkerTrackerLoader(marker_trackerstore, STRATEGY_ALL).load())
    stored_keys = list(marker_trackerstore.keys())

    assert len(loaded) == len(stored_keys)
    for tracker in loaded:
        assert marker_trackerstore.exists(tracker.sender_id)
def test_load_sample(marker_trackerstore: TrackerStore):
    """Checks that the 'sample' strategy returns `count` distinct stored trackers."""
    loader = MarkerTrackerLoader(marker_trackerstore, STRATEGY_SAMPLE_N, 3)
    sampled = list(loader.load())
    assert len(sampled) == 3

    seen_ids = set()
    for tracker in sampled:
        sender = tracker.sender_id
        assert marker_trackerstore.exists(sender)
        # The test expects no sender to be returned twice.
        assert sender not in seen_ids
        seen_ids.add(sender)
def test_load_sample_with_seed(marker_trackerstore: TrackerStore):
    """Checks that a seeded 'sample' strategy returns a fixed, reproducible order."""
    loader = MarkerTrackerLoader(marker_trackerstore, STRATEGY_SAMPLE_N, 3, seed=3)
    sampled_ids = [tracker.sender_id for tracker in loader.load()]

    assert len(sampled_ids) == 3
    # NOTE(review): these ids depend on the fixture's contents and on the RNG
    # behind `seed=3`; update them if either changes.
    assert sampled_ids == ["1", "4", "3"]
    for sender in sampled_ids:
        assert marker_trackerstore.exists(sender)
async def test_markers_cli_results_save_correctly(tmp_path: Path):
    """Tests that marker evaluation results are written correctly to a CSV file.

    Five trackers are stored. In each one, slots "0"-"4" are set, then a
    session-start action and one user utterance occur, then slots "5"-"9" are
    set. The test expects:

    * `marker1` (slot "2") in session 0, with no preceding user turns;
    * `marker2` (slot "7") in session 1, after one user turn;
    * rows for all five senders.
    """
    domain = Domain.empty()
    store = InMemoryTrackerStore(domain)

    for i in range(5):
        tracker = DialogueStateTracker(str(i), None)
        tracker.update_with_events([SlotSet(str(j), "slot") for j in range(5)], domain)
        tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
        tracker.update(UserUttered("hello"))
        tracker.update_with_events(
            [SlotSet(str(5 + j), "slot") for j in range(5)], domain
        )
        await store.save(tracker)

    # Use the shared strategy constant (as the other loader tests do) instead
    # of a bare string literal.
    tracker_loader = MarkerTrackerLoader(store, STRATEGY_ALL)
    results_path = tmp_path / "results.csv"

    markers = OrMarker(
        markers=[
            SlotSetMarker("2", name="marker1"),
            SlotSetMarker("7", name="marker2"),
        ]
    )
    await markers.evaluate_trackers(tracker_loader.load(), results_path)

    # The csv module requires file objects to be opened with newline="" so
    # that embedded or platform-specific line endings are parsed correctly.
    with open(results_path, "r", newline="") as results:
        result_reader = csv.DictReader(results)
        senders = set()
        for row in result_reader:
            senders.add(row["sender_id"])
            if row["marker"] == "marker1":
                assert row["session_idx"] == "0"
                assert int(row["event_idx"]) >= 2
                assert row["num_preceding_user_turns"] == "0"
            elif row["marker"] == "marker2":
                assert row["session_idx"] == "1"
                assert int(row["event_idx"]) >= 3
                assert row["num_preceding_user_turns"] == "1"

    assert len(senders) == 5
def test_load_sessions(tmp_path):
    """Checks that a tracker spanning several sessions loads as one full tracker."""
    domain = Domain.empty()
    store = SQLTrackerStore(domain, db=str(tmp_path / "temp.db"))

    original = DialogueStateTracker("test123", None)
    events = [
        UserUttered("0"),
        UserUttered("1"),
        SessionStarted(),
        UserUttered("2"),
        UserUttered("3"),
    ]
    original.update_with_events(events, domain)
    # NOTE(review): another test in this file awaits `store.save(...)`. If
    # `save` is a coroutine for this store as well, this call never executes.
    # Confirm which TrackerStore API version this file targets.
    store.save(original)

    loaded = list(MarkerTrackerLoader(store, STRATEGY_ALL).load())
    # The session boundary must not split the tracker into multiple results.
    assert len(loaded) == 1
    assert len(loaded[0].events) == len(original.events)
def test_warn_count_exceeds_store(marker_trackerstore: TrackerStore):
    """Checks that a UserWarning is emitted when `count` exceeds the stored trackers."""
    oversized_count = 6
    loader = MarkerTrackerLoader(
        marker_trackerstore, STRATEGY_SAMPLE_N, oversized_count
    )
    with pytest.warns(UserWarning):
        # The generator must be consumed for the warning to be produced.
        for _ in loader.load():
            pass