def test_load_first_n(marker_trackerstore: TrackerStore):
    """Loading with the 'first_n' strategy yields exactly `count` stored trackers."""
    requested = 3
    loader = MarkerTrackerLoader(marker_trackerstore, STRATEGY_FIRST_N, requested)
    loaded_trackers = list(loader.load())

    assert len(loaded_trackers) == requested
    # Every loaded tracker must originate from the store.
    for tracker in loaded_trackers:
        assert marker_trackerstore.exists(tracker.sender_id)
def test_load_all(marker_trackerstore: TrackerStore):
    """Loading with the 'all' strategy yields one tracker per key in the store."""
    loader = MarkerTrackerLoader(marker_trackerstore, STRATEGY_ALL)
    loaded_trackers = list(loader.load())

    stored_keys = list(marker_trackerstore.keys())
    assert len(loaded_trackers) == len(stored_keys)

    for tracker in loaded_trackers:
        assert marker_trackerstore.exists(tracker.sender_id)
def test_load_sample(marker_trackerstore: TrackerStore):
    """Sampling with the 'sample' strategy yields `count` distinct stored trackers."""
    loader = MarkerTrackerLoader(marker_trackerstore, STRATEGY_SAMPLE_N, 3)
    sampled_trackers = list(loader.load())

    assert len(sampled_trackers) == 3

    seen_ids = set()
    for tracker in sampled_trackers:
        sender_id = tracker.sender_id
        assert marker_trackerstore.exists(sender_id)
        # No sender may be sampled more than once.
        assert sender_id not in seen_ids
        seen_ids.add(sender_id)
def test_load_sample_with_seed(marker_trackerstore: TrackerStore):
    """Sampling with a fixed seed yields a reproducible sequence of trackers."""
    loader = MarkerTrackerLoader(
        marker_trackerstore, STRATEGY_SAMPLE_N, 3, seed=3
    )
    sampled_trackers = list(loader.load())

    assert len(sampled_trackers) == 3

    # With seed=3 the sample is deterministic, so its exact order is pinned here.
    for expected_id, tracker in zip(["1", "4", "3"], sampled_trackers):
        assert tracker.sender_id == expected_id
        assert marker_trackerstore.exists(tracker.sender_id)
def _create_tracker_loader(
    endpoint_config: Text,
    strategy: Text,
    domain: Domain,
    count: Optional[int],
    seed: Optional[int],
) -> MarkerTrackerLoader:
    """Build a tracker loader for the tracker store defined in an endpoints file.

    Args:
        endpoint_config: Path to the endpoints configuration whose tracker
            store definition is used.
        strategy: Name of the strategy the loader uses to select trackers.
        domain: Domain passed to the tracker store when it is created.
        count: (Optional) Number of trackers to load; applies to every
            strategy except 'all'.
        seed: (Optional) Seed for the random number generator used by the
            'sample_n' strategy.

    Returns:
        A MarkerTrackerLoader that reads from the configured tracker store
        using the given strategy.
    """
    store_config = AvailableEndpoints.read_endpoints(endpoint_config).tracker_store
    store = TrackerStore.create(store_config, domain=domain)
    return MarkerTrackerLoader(store, strategy, count, seed)
async def test_markers_cli_results_save_correctly(tmp_path: Path):
    """Tests that marker evaluation results are written to a CSV file correctly.

    Five trackers are stored. Each one sets slots "0"-"4", then runs the
    session start action and one user utterance, then sets slots "5"-"9".
    ``marker1`` (slot "2") is therefore expected in session 0 with no
    preceding user turns. ``marker2`` (slot "7") is expected in session 1
    with one preceding user turn.
    """
    domain = Domain.empty()
    store = InMemoryTrackerStore(domain)

    for i in range(5):
        tracker = DialogueStateTracker(str(i), None)
        tracker.update_with_events([SlotSet(str(j), "slot") for j in range(5)], domain)
        tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
        tracker.update(UserUttered("hello"))
        tracker.update_with_events(
            [SlotSet(str(5 + j), "slot") for j in range(5)], domain
        )
        await store.save(tracker)

    tracker_loader = MarkerTrackerLoader(store, "all")
    results_path = tmp_path / "results.csv"

    markers = OrMarker(
        markers=[
            SlotSetMarker("2", name="marker1"),
            SlotSetMarker("7", name="marker2"),
        ]
    )
    await markers.evaluate_trackers(tracker_loader.load(), results_path)

    # The csv module docs specify that files handed to a reader should be
    # opened with newline="" so quoted fields with line breaks parse correctly.
    with open(results_path, "r", newline="") as results:
        result_reader = csv.DictReader(results)
        senders = set()

        for row in result_reader:
            senders.add(row["sender_id"])
            # The two markers are mutually exclusive per row.
            if row["marker"] == "marker1":
                # Slot "2" is set before the session start action.
                assert row["session_idx"] == "0"
                assert int(row["event_idx"]) >= 2
                assert row["num_preceding_user_turns"] == "0"
            elif row["marker"] == "marker2":
                # Slot "7" is set after the session start and one user turn.
                assert row["session_idx"] == "1"
                assert int(row["event_idx"]) >= 3
                assert row["num_preceding_user_turns"] == "1"

    # Every stored tracker must appear in the results.
    assert len(senders) == 5
def test_load_sessions(tmp_path):
    """Loading a tracker that spans several sessions keeps all of its events."""
    domain = Domain.empty()
    db_path = os.path.join(tmp_path, "temp.db")
    store = SQLTrackerStore(domain, db=db_path)

    tracker = DialogueStateTracker("test123", None)
    multi_session_events = [
        UserUttered("0"),
        UserUttered("1"),
        SessionStarted(),
        UserUttered("2"),
        UserUttered("3"),
    ]
    tracker.update_with_events(multi_session_events, domain)
    # NOTE(review): save() is called synchronously here. If the tracker store
    # API is async in this Rasa version, this needs `await` and an async test
    # to persist anything. Confirm.
    store.save(tracker)

    loader = MarkerTrackerLoader(store, STRATEGY_ALL)
    loaded_trackers = list(loader.load())

    # Only one tracker was stored, and none of its events may be dropped.
    assert len(loaded_trackers) == 1
    assert len(loaded_trackers[0].events) == len(tracker.events)
def test_warn_count_exceeds_store(marker_trackerstore: TrackerStore):
    """Requesting more trackers than the store holds issues a UserWarning."""
    oversized_loader = MarkerTrackerLoader(marker_trackerstore, STRATEGY_SAMPLE_N, 6)
    with pytest.warns(UserWarning):
        # load() is lazy, so the warning only appears once the generator is
        # consumed.
        list(oversized_loader.load())
def test_warn_count_all_unnecessary(marker_trackerstore: TrackerStore):
    """Passing a 'count' together with the 'all' strategy issues a UserWarning."""
    with pytest.warns(UserWarning):
        MarkerTrackerLoader(marker_trackerstore, STRATEGY_ALL, 3)
def test_warn_seed_unnecessary(marker_trackerstore: TrackerStore):
    """Passing a 'seed' to a non-sampling strategy issues a UserWarning."""
    with pytest.warns(UserWarning):
        MarkerTrackerLoader(
            marker_trackerstore,
            STRATEGY_FIRST_N,
            3,
            seed=5,
        )
def test_exception_negative_count(marker_trackerstore: TrackerStore):
    """A negative 'count' is rejected with a RasaException."""
    negative_count = -1
    with pytest.raises(RasaException):
        MarkerTrackerLoader(marker_trackerstore, STRATEGY_SAMPLE_N, negative_count)
def test_exception_no_count(marker_trackerstore: TrackerStore):
    """Omitting 'count' for a strategy other than 'all' raises a RasaException."""
    with pytest.raises(RasaException):
        MarkerTrackerLoader(
            marker_trackerstore,
            STRATEGY_SAMPLE_N,
        )
def test_exception_invalid_strategy(marker_trackerstore: TrackerStore):
    """An unknown strategy name is rejected with a RasaException."""
    unknown_strategy = "summon"
    with pytest.raises(RasaException):
        MarkerTrackerLoader(marker_trackerstore, unknown_strategy)