def test_get_tracker_store_from_endpoint_config_error_exit(tmp_path: Path):
    """Exit when the endpoints config does not define a tracker store.

    An empty endpoints config is written to disk, so fetching the tracker
    store via ``export._get_tracker_store`` is expected to raise
    ``SystemExit``.
    """
    # write an empty config (no tracker store and no event broker) to file
    endpoints_path = write_endpoint_config_to_yaml(tmp_path, {})
    available_endpoints = rasa_core_utils.read_endpoints_from_path(endpoints_path)

    with pytest.raises(SystemExit):
        # The call itself must raise; asserting on its return value would be
        # dead code on the expected path.
        # noinspection PyProtectedMember
        export._get_tracker_store(available_endpoints)
def test_get_event_broker_from_endpoint_config_error_exit(tmp_path: Path):
    """Exit when the endpoints config does not define an event broker.

    Only a SQL tracker store is configured, so fetching the event broker via
    ``export._get_event_broker`` is expected to raise ``SystemExit``.
    """
    config_without_broker = {"tracker_store": {"type": "sql"}}
    endpoints_file = write_endpoint_config_to_yaml(tmp_path, config_without_broker)
    endpoints = rasa_core_utils.read_endpoints_from_path(endpoints_file)

    with pytest.raises(SystemExit):
        # noinspection PyProtectedMember
        assert export._get_event_broker(endpoints)
def prepare_namespace_and_mocked_tracker_store_with_events(
    temporary_path: Path, monkeypatch: MonkeyPatch
) -> Tuple[List[UserUttered], argparse.Namespace]:
    """Build export CLI arguments and patch in a mocked tracker store.

    Writes an endpoints config (pika event broker, SQL tracker store), creates
    a namespace requesting conversations ``id-1`` and ``id-2`` within the
    timestamp window ``[1.0, 10.0]``, and replaces
    ``export._get_tracker_store`` with a function returning a ``Mock`` whose
    ``keys()`` lists three conversation IDs and whose ``retrieve()`` builds a
    tracker from that conversation's events.

    Args:
        temporary_path: Directory the endpoints YAML is written to.
        monkeypatch: Fixture used to patch ``export._get_tracker_store``.

    Returns:
        The generated user events (in creation order) and the CLI namespace.
    """
    endpoint_config = {
        "event_broker": {"type": "pika"},
        "tracker_store": {"type": "sql"},
    }
    endpoints_path = write_endpoint_config_to_yaml(temporary_path, endpoint_config)

    conversation_ids = ["id-1", "id-2", "id-3"]
    # only the first two conversations are requested for export
    wanted_ids = conversation_ids[:2]

    namespace = argparse.Namespace(
        endpoints=endpoints_path,
        conversation_ids=",".join(wanted_ids),
        minimum_timestamp=1.0,
        maximum_timestamp=10.0,
    )

    # one event per timestamp; 11 deliberately falls outside the window above
    timestamps = [1, 2, 3, 4, 11, 5]
    events = list(map(random_user_uttered_event, timestamps))

    # events 0-1 -> id-1, events 2-4 -> id-2, event 5 -> id-3
    first_id, second_id, third_id = conversation_ids
    events_by_sender = {
        first_id: events[0:2],
        second_id: events[2:5],
        third_id: events[5:],
    }

    def _tracker_for(sender_id: Text) -> DialogueStateTracker:
        return DialogueStateTracker.from_events(sender_id, events_by_sender[sender_id])

    fake_store = Mock()
    fake_store.keys.return_value = conversation_ids
    fake_store.retrieve.side_effect = _tracker_for

    monkeypatch.setattr(export, "_get_tracker_store", lambda _: fake_store)

    return events, namespace
def test_markers_cli_results_save_correctly(
    marker_sqlite_tracker: Tuple[SQLTrackerStore, Text], tmp_path: Path
):
    """Markers evaluation writes CSV results and statistics that parse cleanly.

    Runs ``rasa.cli.evaluate._run_markers`` against the SQLite tracker from the
    fixture, then reads every row of the results file and of the session and
    overall statistics files with ``csv.DictReader``.
    """
    _, db_path = marker_sqlite_tracker
    # backslashes (e.g. in Windows paths) are doubled before the path is
    # written into the YAML endpoint config
    escaped_db_path = db_path.replace("\\", "\\\\")
    endpoints_path = write_endpoint_config_to_yaml(
        tmp_path,
        {"tracker_store": {"type": "sql", "db": escaped_db_path}},
    )
    markers_path = write_markers_config_to_yaml(
        tmp_path,
        {"marker1": {"slot_was_set": "2"}, "marker2": {"slot_was_set": "7"}},
    )

    results_path = tmp_path / "results.csv"
    stats_file_prefix = tmp_path / "statistics"

    rasa.cli.evaluate._run_markers(
        seed=None,
        count=10,
        endpoint_config=endpoints_path,
        strategy="first_n",
        domain_path=None,
        config=markers_path,
        output_filename=results_path,
        stats_file_prefix=stats_file_prefix,
    )

    expected_outputs = (
        results_path,
        tmp_path / ("statistics" + STATS_SESSION_SUFFIX),
        tmp_path / ("statistics" + STATS_OVERALL_SUFFIX),
    )
    for output_path in expected_outputs:
        with output_path.open(mode="r") as handle:
            # consume every row so any malformed content raises here
            list(csv.DictReader(handle))
def test_get_event_broker_and_tracker_store_from_endpoint_config(tmp_path: Path):
    """A config with both components lets both of them be fetched.

    Writes an endpoints config with a SQL event broker and a SQL tracker
    store, then checks that ``export._get_event_broker`` and
    ``export._get_tracker_store`` each return a truthy value.
    """
    valid_config = {"event_broker": {"type": "sql"}, "tracker_store": {"type": "sql"}}
    config_file = write_endpoint_config_to_yaml(tmp_path, valid_config)
    endpoints = rasa_core_utils.read_endpoints_from_path(config_file)

    # fetching the event broker succeeds
    # noinspection PyProtectedMember
    assert export._get_event_broker(endpoints)
    # fetching the tracker store succeeds
    # noinspection PyProtectedMember
    assert export._get_tracker_store(endpoints)
def test_read_endpoints_from_path(tmp_path: Path):
    """Only the configured endpoints are populated when reading the config.

    The config defines just a pika event broker and a SQL tracker store, so
    those two must be set and every other endpoint (lock store, NLG, action,
    model, NLU) must be unset.
    """
    # write valid config to file
    endpoints_path = write_endpoint_config_to_yaml(
        tmp_path, {"event_broker": {"type": "pika"}, "tracker_store": {"type": "sql"}}
    )

    # same reader the other tests in this module use
    available_endpoints = rasa_core_utils.read_endpoints_from_path(endpoints_path)

    # event broker and tracker store are configured ...
    assert available_endpoints.tracker_store and available_endpoints.event_broker
    # ... and none of the other endpoints are. `not any` is required here:
    # `not all` would pass as soon as a single one of them was unset.
    assert not any(
        (
            available_endpoints.lock_store,
            available_endpoints.nlg,
            available_endpoints.action,
            available_endpoints.model,
            available_endpoints.nlu,
        )
    )