def test_clean_domain_deprecated_templates():
    """Clean the `default_deprecated_templates.yml` test domain and compare it
    (via `as_dict`) against the expected cleaned domain."""
    cleaned_dict = Domain.load(
        "data/test_domains/default_deprecated_templates.yml"
    ).cleaned_domain()

    expected_intents = [
        {"greet": {USE_ENTITIES_KEY: ["name"]}},
        {"default": {IGNORE_ENTITIES_KEY: ["unrelated_recognized_entity"]}},
        {"goodbye": {USE_ENTITIES_KEY: []}},
        {"thank": {USE_ENTITIES_KEY: []}},
        "ask",
        {"why": {USE_ENTITIES_KEY: []}},
        "pure_intent",
    ]
    expected_responses = {
        "utter_greet": [{"text": "hey there!"}],
        "utter_goodbye": [{"text": "goodbye :("}],
        "utter_default": [{"text": "default message"}],
    }
    expected_dict = {
        "intents": expected_intents,
        "entities": ["name", "unrelated_recognized_entity", "other"],
        "responses": expected_responses,
        "actions": ["utter_default", "utter_greet", "utter_goodbye"],
    }

    # Round-trip both sides through `Domain` so they are normalised identically.
    expected_domain = Domain.from_dict(expected_dict)
    actual_domain = Domain.from_dict(cleaned_dict)

    assert actual_domain.as_dict() == expected_domain.as_dict()
def test_domain_from_dict_does_not_change_input():
    """`Domain.from_dict` must leave the dictionary handed to it unmodified."""
    intents = [
        {"greet": {USE_ENTITIES_KEY: ["name"]}},
        {"default": {IGNORE_ENTITIES_KEY: ["unrelated_recognized_entity"]}},
        {"goodbye": {USE_ENTITIES_KEY: None}},
        {"thank": {USE_ENTITIES_KEY: False}},
        {"ask": {USE_ENTITIES_KEY: True}},
        {"why": {USE_ENTITIES_KEY: []}},
        "pure_intent",
    ]
    responses = {
        "utter_greet": [{"text": "hey there {name}!"}],
        "utter_goodbye": [{"text": "goodbye 😢"}, {"text": "bye bye 😢"}],
        "utter_default": [{"text": "default message"}],
    }
    input_before = {
        "intents": intents,
        "entities": ["name", "unrelated_recognized_entity", "other"],
        "slots": {"name": {"type": "text"}},
        "responses": responses,
    }

    # Hand a deep copy to `from_dict`, then compare it with the pristine original.
    input_after = copy.deepcopy(input_before)
    Domain.from_dict(input_after)

    assert input_after == input_before
def test_clean_domain():
    """Clean the `default_unfeaturized_entities.yml` test domain and check that
    it hashes identically to the expected cleaned domain."""
    cleaned_dict = Domain.load(
        "data/test_domains/default_unfeaturized_entities.yml"
    ).cleaned_domain()

    expected_intents = [
        {"greet": {"use_entities": ["name"]}},
        {"default": {"ignore_entities": ["unrelated_recognized_entity"]}},
        {"goodbye": {"use_entities": []}},
        {"thank": {"use_entities": []}},
        "ask",
        {"why": {"use_entities": []}},
        "pure_intent",
    ]
    expected_templates = {
        "utter_greet": [{"text": "hey there!"}],
        "utter_goodbye": [{"text": "goodbye :("}],
        "utter_default": [{"text": "default message"}],
    }
    expected_dict = {
        "intents": expected_intents,
        "entities": ["name", "other", "unrelated_recognized_entity"],
        "templates": expected_templates,
        "actions": ["utter_default", "utter_goodbye", "utter_greet"],
    }

    expected_domain = Domain.from_dict(expected_dict)
    actual_domain = Domain.from_dict(cleaned_dict)

    assert hash(actual_domain) == hash(expected_domain)
def test_extract_requested_slot_from_entity(
    mapping_not_intent: Optional[Text],
    mapping_intent: Optional[Text],
    mapping_role: Optional[Text],
    mapping_group: Optional[Text],
    entities: List[Dict[Text, Any]],
    intent: Text,
    expected_slot_values: Dict[Text, Text],
):
    """Check requested-slot extraction from an entity under the parametrised
    intent / not_intent / role / group restrictions."""
    form_name = "some form"
    form = FormAction(form_name, None)

    # Single `from_entity` mapping carrying all parametrised restrictions.
    entity_mapping = form.from_entity(
        entity="some_entity",
        role=mapping_role,
        group=mapping_group,
        intent=mapping_intent,
        not_intent=mapping_not_intent,
    )
    forms_config = [{form_name: {"some_slot": [entity_mapping]}}]
    domain = Domain.from_dict({"forms": forms_config})

    user_message = UserUttered(
        "bla", intent={"name": intent, "confidence": 1.0}, entities=entities
    )
    tracker = DialogueStateTracker.from_events(
        "default", [SlotSet(REQUESTED_SLOT, "some_slot"), user_message]
    )

    extracted = form.extract_requested_slot(tracker, domain)
    assert extracted == expected_slot_values
def test_form_without_form_policy(policy_config: Dict[Text, List[Text]]):
    """Creating an `Agent` whose domain has a form must raise `InvalidDomain`
    mentioning the missing FormPolicy for the given policy config."""
    forms_only_domain = {"forms": ["restaurant_form"]}

    with pytest.raises(InvalidDomain) as raised:
        # Both objects are built inside the context, as any of the calls may raise.
        agent_kwargs = {
            "domain": Domain.from_dict(forms_only_domain),
            "policies": PolicyEnsemble.from_dict(policy_config),
        }
        Agent(**agent_kwargs)

    message = str(raised.value)
    assert "haven't added the FormPolicy" in message
def test_trigger_without_mapping_policy(domain, policy_config):
    """Creating an `Agent` from the given domain and policies must raise
    `InvalidDomain` mentioning the missing MappingPolicy."""
    with pytest.raises(InvalidDomain) as raised:
        # Both objects are built inside the context, as any of the calls may raise.
        agent_kwargs = {
            "domain": Domain.from_dict(domain),
            "policies": PolicyEnsemble.from_dict(policy_config),
        }
        Agent(**agent_kwargs)

    message = str(raised.value)
    assert "haven't added the MappingPolicy" in message
async def get_story(request, story_id):
    """Return a single story as JSON or as a graphviz visualization.

    Args:
        request: Incoming HTTP request; its `Accept` header selects the format
            and an optional `project_id` argument selects the domain.
        story_id: Identifier of the story to fetch.

    Returns:
        A text response with the visualization when `Accept` is
        `text/vnd.graphviz`, otherwise a JSON response with the story.
        A `StoryNotFound` error (404) if the story does not exist, and a
        `VisualizationNotAvailable` error (406) if no visualization could be
        produced.
    """
    from rasa.core.domain import Domain

    story = _story_service(request).fetch_story(story_id)

    # Check existence first: previously a missing story was handed to the
    # visualizer as `[None]` when a graphviz representation was requested.
    if not story:
        return rasa_x_utils.error(
            HTTPStatus.NOT_FOUND,
            "StoryNotFound",
            f"Story for id {story_id} could not be found",
        )

    accept_header = request.headers.get("Accept")
    if accept_header != "text/vnd.graphviz":
        return response.json(story)

    project_id = rasa_x_utils.default_arg(
        request, "project_id", config.project_name
    )
    domain_dict = _domain_service(request).get_or_create_domain(project_id)
    domain = Domain.from_dict(domain_dict)

    visualization = await _story_service(request).visualize_stories(
        [story], domain
    )
    if visualization:
        return response.text(visualization)

    return rasa_x_utils.error(
        HTTPStatus.NOT_ACCEPTABLE,
        "VisualizationNotAvailable",
        "Cannot produce a visualization for the requested story",
    )
def test_extract_requested_slot_mapping_does_not_apply(slot_mapping: Dict):
    """No value may be extracted for the requested slot when the parametrised
    slot mapping does not apply to the user message."""
    form_name = "some_form"
    entity_name = "some_slot"
    form = FormAction(form_name, None)

    forms_config = [{form_name: {entity_name: [slot_mapping]}}]
    domain = Domain.from_dict({"forms": forms_config})

    user_message = UserUttered(
        "bla",
        intent={"name": "greet", "confidence": 1.0},
        entities=[{"entity": entity_name, "value": "some_value"}],
    )
    events = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        user_message,
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    tracker = DialogueStateTracker.from_events("default", events)

    extracted = form.extract_requested_slot(tracker, domain)

    # check that the value was not extracted for incorrect intent
    assert extracted == {}
async def _write_domain_to_file(
    domain_path: Text, evts: List[Dict[Text, Any]], endpoint: EndpointConfig
) -> None:
    """Write an updated domain file to the file path.

    The current domain is retrieved from `endpoint`, merged with a domain
    built from the intents, entities and actions found in `evts`, and the
    result is persisted in cleaned form.

    Args:
        domain_path: Path the merged domain is written to.
        evts: Serialised events to collect messages and actions from.
        endpoint: Endpoint the current domain is retrieved from.
    """
    domain = await retrieve_domain(endpoint)
    old_domain = Domain.from_dict(domain)

    messages = _collect_messages(evts)
    actions = _collect_actions(evts)

    intent_properties = Domain.collect_intent_properties(
        _intents_from_messages(messages)
    )

    # Look the default action names up once instead of once per event.
    default_names = set(default_action_names())
    # TODO for now there is no way to distinguish between action and form
    collected_actions = list(
        {e["name"] for e in actions if e["name"] not in default_names}
    )

    new_domain = Domain(
        intent_properties=intent_properties,
        entities=_entities_from_messages(messages),
        slots=[],
        templates={},
        action_names=collected_actions,
        form_names=[],
    )

    old_domain.merge(new_domain).persist_clean(domain_path)
async def get_story_steps(
    story_string: Text, domain: Dict[Text, Any] = None
) -> List[StoryStep]:
    """Given story md string reads the contained stories.

    Also checks if the intents in the stories are in the provided domain. For
    each intent not present in the domain, a UserWarning is issued.

    Args:
        story_string: Stories in markdown format.
        domain: Optional domain as a dictionary; an empty domain is used when
            none (or an empty one) is given.

    Returns:
        The list of StorySteps contained in the story.

    Raises:
        StoryParseError: If the story cannot be parsed; the underlying
            `AttributeError` / `ValueError` is attached as the cause.
    """
    # domain is not needed in `StoryFileReader` when parsing stories,
    # but if none is provided there will be a UserWarning for each intent
    if not domain:
        domain = {}
    else:
        # Make a copy because RasaDomain.from_dict(domain) might change its input
        # domain. This will be fixed with Rasa 2.0.
        domain = copy.deepcopy(domain)

    domain = RasaDomain.from_dict(domain)
    reader = StoryFileReader(interpreter=RegexInterpreter(), domain=domain)

    try:
        # just split on newlines
        lines = story_string.split("\n")
        return await reader.process_lines(lines)
    except (AttributeError, ValueError) as e:
        # Chain explicitly so the original parsing failure stays visible as
        # the direct cause in tracebacks.
        raise StoryParseError(
            "Invalid story format. Failed to parse "
            "'{}'\nError: {}".format(story_string, e)
        ) from e
def test_two_stage_fallback_without_deny_suggestion(domain, policy_config):
    """Creating an `Agent` from the given domain and policies must raise
    `InvalidDomain` complaining about the missing 'out_of_scope' intent."""
    with pytest.raises(InvalidDomain) as raised:
        # Both objects are built inside the context, as any of the calls may raise.
        agent_kwargs = {
            "domain": Domain.from_dict(domain),
            "policies": PolicyEnsemble.from_dict(policy_config),
        }
        Agent(**agent_kwargs)

    message = str(raised.value)
    assert "The intent 'out_of_scope' must be present" in message
def test_form_without_form_policy(domain: Dict[Text, Any],
                                  policy_config: Dict[Text, Any]):
    """Creating an `Agent` from the given domain and policies must raise
    `InvalidDomain` mentioning the missing FormPolicy."""
    with pytest.raises(InvalidDomain) as raised:
        # Both objects are built inside the context, as any of the calls may raise.
        agent_kwargs = {
            "domain": Domain.from_dict(domain),
            "policies": PolicyEnsemble.from_dict(policy_config),
        }
        Agent(**agent_kwargs)

    message = str(raised.value)
    assert "haven't added the FormPolicy" in message
def test_domain_as_dict_with_session_config():
    """A session config set on a domain must survive an
    `as_dict` / `from_dict` round trip."""
    session_config = SessionConfig(123, False)

    original = Domain.empty()
    original.session_config = session_config

    restored = Domain.from_dict(original.as_dict())

    assert restored.session_config == session_config
async def get_domain_warnings(
    self, project_id: Text = config.project_name
) -> Optional[Tuple[Dict[Text, Dict[Text, List[Text]]], int]]:
    """Get domain warnings.

    Collects actions, intents, entities and slots from the NLG responses,
    the NLU training data and the stories, and compares them against the
    stored domain.

    Args:
        project_id: The project id of the domain.

    Returns:
        Dict of domain warnings and the total count of elements, or `None`
        if no domain is stored for the project.
    """
    stored_domain = self._get_domain(project_id)
    if not stored_domain:
        return None

    from rasax.community.services.data_service import DataService
    from rasax.community.services.nlg_service import NlgService
    from rasax.community.services.story_service import StoryService

    domain_object = RasaDomain.from_dict(stored_domain.as_dict())

    data_service = DataService(self.session)
    training_data = data_service.get_nlu_training_data_object(
        project_id=project_id
    )

    # actions are response names and story bot actions
    actions = NlgService(self.session).fetch_all_response_names()

    # intents are training data intents and story intents
    intents = training_data.intents

    # entities are training data entities without `extractor` attribute
    entities = self._get_entities_from_training_data(
        training_data.entity_examples
    )

    # slots are simply story slots
    slots = set()

    story_service = StoryService(self.session)
    story_items = await story_service.fetch_domain_items_from_stories()
    if story_items:
        actions.update(story_items[0])
        intents.update(story_items[1])
        slots.update(story_items[2])
        entities.update(story_items[3])

    # exclude unfeaturized slots from warnings
    slots = self._remove_unfeaturized_slots(slots, domain_object)

    warnings_by_type = self._domain_warnings_as_list(
        domain_object, intents, entities, actions, slots
    )
    total = self._count_total_warnings(warnings_by_type)
    return warnings_by_type, total
def dump_cleaned_domain_yaml(domain: Dict[Text, Any]) -> Optional[Text]:
    """Clean a domain given as a dictionary and serialise it to yaml.

    Args:
        domain: Domain as a dictionary.

    Returns:
        The cleaned domain as a yaml string.
    """
    domain_object = RasaDomain.from_dict(domain)
    cleaned = domain_object.cleaned_domain()
    return dump_yaml(cleaned)
async def model_fingerprint(
        file_importer: "TrainingDataImporter") -> Fingerprint:
    """Create a model fingerprint from its used configuration and training data.

    Args:
        file_importer: File importer which provides the training data and
            model config.

    Returns:
        The fingerprint.
    """
    from rasa.core.domain import Domain
    import rasa
    import time

    # bf mod: the full config and the stories themselves are not loaded here;
    # the importer provides a stories hash plus separate NLU / Core configs.
    domain = await file_importer.get_domain()
    stories_hash = await file_importer.get_stories_hash()
    nlu_data = await file_importer.get_nlu_data()
    nlu_config = await file_importer.get_nlu_config()
    core_config = await file_importer.get_core_config()

    # Split the responses off so they are fingerprinted separately from
    # the rest of the domain.
    domain_dict = domain.as_dict()
    templates = domain_dict.pop("responses")
    domain_without_nlg = Domain.from_dict(domain_dict)

    config_hash = _get_hash_of_config(
        core_config, exclude_keys=CONFIG_MANDATORY_KEYS)
    core_config_hash = _get_hash_of_config(
        core_config, include_keys=CONFIG_MANDATORY_KEYS_CORE)
    nlu_config_hashes = {
        lang: _get_hash_of_config(
            lang_config, include_keys=CONFIG_MANDATORY_KEYS_NLU)
        for (lang, lang_config) in nlu_config.items()
    }
    domain_hash = hash(domain_without_nlg)
    nlg_hash = get_dict_hash(templates)
    nlu_data_hashes = {lang: hash(nlu_data[lang]) for lang in nlu_data}

    return {
        FINGERPRINT_CONFIG_KEY: config_hash,
        FINGERPRINT_CONFIG_CORE_KEY: core_config_hash,
        FINGERPRINT_CONFIG_NLU_KEY: nlu_config_hashes,
        FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY: domain_hash,
        FINGERPRINT_NLG_KEY: nlg_hash,
        FINGERPRINT_NLU_DATA_KEY: nlu_data_hashes,
        FINGERPRINT_STORIES_KEY: stories_hash,
        FINGERPRINT_TRAINED_AT_KEY: time.time(),
        FINGERPRINT_RASA_VERSION_KEY: rasa.__version__,
    }
async def train(self):
    """Train the engine.

    Trains and persists the NLU model and the Core policy ensemble for the
    configured skill. Nothing is trained when `data/<skill-id>` already
    exists.
    """
    nltk.download('punkt')
    lang = self.config['language']
    skill_id = self.config['skill-id']
    skill_dir = 'data/' + skill_id

    if os.path.exists(skill_dir):
        return

    _LOGGER.info("Starting Skill training.")
    _LOGGER.info("Generating stories.")
    data, domain_data, stories = await GenerateStories.run(
        skill_id, lang, self.asm)

    # NLU
    nlu_config = RasaNLUModelConfig({
        "language": lang,
        "pipeline": self.config['pipeline'],
        "data": None,
    })
    trainer = Trainer(nlu_config, None, True)
    _LOGGER.info("Training Arcus NLU")
    trainer.train(TrainingData(training_examples=data))
    trainer.persist(skill_dir, None, 'nlu')

    # Rasa core
    domain = Domain.from_dict(domain_data)
    reader = StoryFileReader(domain, RegexInterpreter(), None, False)
    story_steps = await reader.process_lines(stories)
    generator = TrainingDataGenerator(
        StoryGraph(story_steps),
        domain,
        remove_duplicates=True,
        unique_last_num_states=None,
        augmentation_factor=20,
        tracker_limit=None,
        use_story_concatenation=True,
        debug_plots=False,
    )
    training_trackers = generator.generate()

    policy_list = SimplePolicyEnsemble.from_dict(
        {"policies": self.config['policies']})
    policy_ensemble = SimplePolicyEnsemble(policy_list)
    _LOGGER.info("Training Arcus Core")
    policy_ensemble.train(training_trackers, domain)

    core_dir = skill_dir + "/core"
    policy_ensemble.persist(core_dir, False)
    domain.persist(core_dir + "/model")
    domain.persist_specification(core_dir)
def test_invalid_slot_mapping():
    """Extracting the requested slot must raise `ValueError` when the slot
    mapping has the type "invalid"."""
    form_name = "my_form"
    form = FormAction(form_name, None)
    slot_name = "test"

    requested_slot_event = SlotSet(REQUESTED_SLOT, slot_name)
    tracker = DialogueStateTracker.from_events("sender", [requested_slot_event])

    invalid_mapping = {"type": "invalid"}
    forms_config = [{form_name: {slot_name: [invalid_mapping]}}]
    domain = Domain.from_dict({"forms": forms_config})

    with pytest.raises(ValueError):
        form.extract_requested_slot(tracker, domain)
def dump_domain(self,
                filename: Optional[Text] = None,
                project_id: Text = config.project_name):
    """Dump domain to `filename` in yml format.

    Does nothing when no domain exists for `project_id`. Without an explicit
    `filename`, the domain's own "path" entry is used, falling back to
    `config.default_domain_path`.
    """
    domain = self.get_domain(project_id)
    if not domain:
        return

    target_name = filename
    if not target_name:
        target_name = domain.get("path") or config.default_domain_path

    cleaned = RasaDomain.from_dict(domain).cleaned_domain()
    target_path = utils.get_project_directory() / target_name
    dump_yaml_to_file(target_path, cleaned)
async def get_stories(request):
    """Return the stories matching the request's query arguments.

    The `Accept` header selects the representation: a graphviz visualization
    (`text/vnd.graphviz`), a markdown attachment (`text/markdown`), or JSON
    with an `X-Total-Count` header otherwise.
    """
    from rasa.core.domain import Domain

    text_query = rasa_x_utils.default_arg(request, "q", None)
    fields = rasa_x_utils.fields_arg(
        request, {"name", "annotation.user", "id"})
    id_query = rasa_x_utils.list_arg(request, "id")
    distinct = rasa_x_utils.bool_arg(request, "distinct", True)

    stories = _story_service(request).fetch_stories(
        text_query, fields, id_query=id_query, distinct=distinct)

    accept_header = request.headers.get("Accept")

    if accept_header == "text/vnd.graphviz":
        project_id = rasa_x_utils.default_arg(
            request, "project_id", config.project_name)
        domain_dict = _domain_service(request).get_or_create_domain(
            project_id)
        domain = Domain.from_dict(domain_dict)

        visualization = await _story_service(request).visualize_stories(
            stories, domain)
        if not visualization:
            return rasa_x_utils.error(
                HTTPStatus.NOT_ACCEPTABLE,
                "VisualizationNotAvailable",
                "Cannot produce a visualization for the requested stories",
            )
        return response.text(visualization)

    if accept_header == "text/markdown":
        markdown = _story_service(request).get_stories_markdown(stories)
        download_headers = {
            "Content-Disposition": "attachment;filename=stories.md"
        }
        return response.text(
            markdown,
            content_type="text/markdown",
            headers=download_headers,
        )

    return response.json(stories, headers={"X-Total-Count": len(stories)})
async def test_fingerprinting_additional_action(project: Text):
    """Adding a new response to the domain must change both the
    domain-without-NLG fingerprint and the NLG fingerprint."""
    importer = _project_files(project)

    old_fingerprint = await model_fingerprint(importer)
    old_domain = await importer.get_domain()

    domain_dict = old_domain.as_dict()
    domain_dict[KEY_RESPONSES]["utter_new"] = [{"text": "hi"}]
    domain_with_new_action = Domain.from_dict(domain_dict)

    # Native coroutine stub: `asyncio.coroutine` is deprecated since
    # Python 3.8 and was removed in Python 3.11.
    async def _get_modified_domain():
        return domain_with_new_action

    importer.get_domain = _get_modified_domain

    new_fingerprint = await model_fingerprint(importer)

    assert (old_fingerprint[FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY] !=
            new_fingerprint[FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY])
    assert (old_fingerprint[FINGERPRINT_NLG_KEY] !=
            new_fingerprint[FINGERPRINT_NLG_KEY])
async def test_ask_for_slot(domain: Dict, expected_action: Text,
                            monkeypatch: MonkeyPatch):
    """`FormAction._ask_for_slot` must look up the expected action by name
    exactly once for the slot "sun"."""
    slot_name = "sun"

    # Replace `action.action_from_name` with a mock returning a listen action.
    mocked_lookup = Mock(return_value=action.ActionListen())
    patched_attribute = action.action_from_name.__name__
    monkeypatch.setattr(action, patched_attribute, mocked_lookup)

    form = FormAction("my_form", None)
    empty_tracker = DialogueStateTracker.from_events("dasd", [])
    await form._ask_for_slot(
        Domain.from_dict(domain), None, None, slot_name, empty_tracker
    )

    mocked_lookup.assert_called_once_with(expected_action, None, ANY)
async def test_trigger_slot_mapping_applies(
    trigger_slot_mapping: Dict, expected_value: Text
):
    """`extract_other_slots` must return the expected value for the slot that
    carries the parametrised trigger mapping."""
    form_name = "some_form"
    entity_name = "some_slot"
    slot_filled_by_trigger_mapping = "other_slot"

    form = FormAction(form_name, None)

    requested_slot_mapping = {
        "type": "from_entity",
        "entity": entity_name,
        "intent": "some_intent",
    }
    form_slots = {
        entity_name: [requested_slot_mapping],
        slot_filled_by_trigger_mapping: [trigger_slot_mapping],
    }
    domain = Domain.from_dict({"forms": [{form_name: form_slots}]})

    user_message = UserUttered(
        "bla",
        intent={"name": "greet", "confidence": 1.0},
        entities=[{"entity": entity_name, "value": "some_value"}],
    )
    events = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        user_message,
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    tracker = DialogueStateTracker.from_events("default", events)

    extracted = form.extract_other_slots(tracker, domain)
    assert extracted == {slot_filled_by_trigger_mapping: expected_value}
async def test_fingerprinting_changed_response_text(project: Text):
    """Changing only response texts must change the NLG fingerprint while the
    domain-without-NLG fingerprint stays the same."""
    importer = _project_files(project)

    old_fingerprint = await model_fingerprint(importer)
    old_domain = await importer.get_domain()

    # Change NLG content but keep actions the same
    domain_dict = old_domain.as_dict()
    domain_dict[KEY_RESPONSES]["utter_greet"].append({"text": "hi"})
    domain_with_changed_nlg = Domain.from_dict(domain_dict)

    # Native coroutine stub: `asyncio.coroutine` is deprecated since
    # Python 3.8 and was removed in Python 3.11.
    async def _get_modified_domain():
        return domain_with_changed_nlg

    importer.get_domain = _get_modified_domain

    new_fingerprint = await model_fingerprint(importer)

    assert (old_fingerprint[FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY] ==
            new_fingerprint[FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY])
    assert (old_fingerprint[FINGERPRINT_NLG_KEY] !=
            new_fingerprint[FINGERPRINT_NLG_KEY])
def test_extract_other_slots_with_entity(
    some_other_slot_mapping: List[Dict[Text, Any]],
    some_slot_mapping: List[Dict[Text, Any]],
    entities: List[Dict[Text, Any]],
    intent: Text,
    expected_slot_values: Dict[Text, Text],
):
    """Check that values for slots other than the requested one are extracted
    from entities according to the parametrised mappings."""
    form_name = "some_form"
    form = FormAction(form_name, None)

    form_slots = {
        "some_other_slot": some_other_slot_mapping,
        "some_slot": some_slot_mapping,
    }
    domain = Domain.from_dict({"forms": [{form_name: form_slots}]})

    user_message = UserUttered(
        "bla", intent={"name": intent, "confidence": 1.0}, entities=entities
    )
    events = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        user_message,
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    tracker = DialogueStateTracker.from_events("default", events)

    extracted = form.extract_other_slots(tracker, domain)

    # check that the value was extracted for non requested slot
    assert extracted == expected_slot_values
async def test_interactive_domain_persistence(mock_endpoint, tmpdir):
    """Exercise `interactive._write_domain_to_file`: none of the default
    actions may appear in the "actions" of the saved domain file."""
    tracker_json = rasa.utils.io.read_json_file(
        "data/test_trackers/tracker_moodbot.json"
    )
    events = tracker_json.get("events", [])

    domain_path = tmpdir.join("interactive_domain_save.yml").strpath
    url = f"{mock_endpoint.url}/domain"

    with aioresponses() as mocked:
        # The mocked endpoint serves an empty domain.
        mocked.get(url, payload={})

        serialised_domain = await interactive.retrieve_domain(mock_endpoint)
        old_domain = Domain.from_dict(serialised_domain)

        await interactive._write_domain_to_file(domain_path, events, old_domain)

    saved_actions = rasa.utils.io.read_config_file(domain_path)["actions"]

    for default_action in action.default_actions():
        assert default_action.name() not in saved_actions
def test_forms_with_suited_policy(policy_config: Dict[Text, List[Text]]):
    """Creating an `Agent` for a domain with a form must succeed for the
    given policy config."""
    domain = Domain.from_dict({"forms": ["restaurant_form"]})
    policies = PolicyEnsemble.from_dict(policy_config)

    # Doesn't raise
    Agent(domain=domain, policies=policies)
def test_add_default_intents(domain: Dict):
    """Every entry of `DEFAULT_INTENTS` must be among the intents of a domain
    created from the given dictionary."""
    loaded_domain = Domain.from_dict(domain)

    defaults_present = all(
        default_intent in loaded_domain.intents
        for default_intent in DEFAULT_INTENTS
    )
    assert defaults_present