Beispiel #1
0
def test_clean_domain_deprecated_templates():
    """Clean the deprecated-templates test domain and compare it to the expected one.

    Both the cleaned domain and the expected dictionary are passed through
    `Domain.from_dict` and compared via their `as_dict()` representation.
    """
    loaded_domain = Domain.load(
        "data/test_domains/default_deprecated_templates.yml"
    )
    cleaned_dict = loaded_domain.cleaned_domain()

    expected_intents = [
        {"greet": {USE_ENTITIES_KEY: ["name"]}},
        {"default": {IGNORE_ENTITIES_KEY: ["unrelated_recognized_entity"]}},
        {"goodbye": {USE_ENTITIES_KEY: []}},
        {"thank": {USE_ENTITIES_KEY: []}},
        "ask",
        {"why": {USE_ENTITIES_KEY: []}},
        "pure_intent",
    ]
    expected_responses = {
        "utter_greet": [{"text": "hey there!"}],
        "utter_goodbye": [{"text": "goodbye :("}],
        "utter_default": [{"text": "default message"}],
    }
    expected_dict = {
        "intents": expected_intents,
        "entities": ["name", "unrelated_recognized_entity", "other"],
        "responses": expected_responses,
        "actions": ["utter_default", "utter_greet", "utter_goodbye"],
    }

    expected_domain = Domain.from_dict(expected_dict)
    actual_domain = Domain.from_dict(cleaned_dict)

    assert actual_domain.as_dict() == expected_domain.as_dict()
Beispiel #2
0
def test_domain_from_dict_does_not_change_input():
    """`Domain.from_dict` must leave the dictionary it receives unmodified.

    A deep copy of the reference dictionary is handed to `Domain.from_dict`
    and afterwards compared against the untouched reference.
    """
    intents = [
        {"greet": {USE_ENTITIES_KEY: ["name"]}},
        {"default": {IGNORE_ENTITIES_KEY: ["unrelated_recognized_entity"]}},
        {"goodbye": {USE_ENTITIES_KEY: None}},
        {"thank": {USE_ENTITIES_KEY: False}},
        {"ask": {USE_ENTITIES_KEY: True}},
        {"why": {USE_ENTITIES_KEY: []}},
        "pure_intent",
    ]
    responses = {
        "utter_greet": [{"text": "hey there {name}!"}],
        "utter_goodbye": [{"text": "goodbye 😢"}, {"text": "bye bye 😢"}],
        "utter_default": [{"text": "default message"}],
    }
    reference = {
        "intents": intents,
        "entities": ["name", "unrelated_recognized_entity", "other"],
        "slots": {"name": {"type": "text"}},
        "responses": responses,
    }

    # Only the copy is exposed to `from_dict`; any in-place mutation would
    # make it differ from the reference.
    passed_in = copy.deepcopy(reference)
    Domain.from_dict(passed_in)

    assert passed_in == reference
def test_clean_domain():
    """Clean the unfeaturized-entities test domain and compare it by hash.

    The cleaned domain and the expected dictionary are both turned into
    `Domain` objects whose hashes must match.
    """
    loaded_domain = Domain.load(
        "data/test_domains/default_unfeaturized_entities.yml"
    )
    cleaned_dict = loaded_domain.cleaned_domain()

    expected_intents = [
        {"greet": {"use_entities": ["name"]}},
        {"default": {"ignore_entities": ["unrelated_recognized_entity"]}},
        {"goodbye": {"use_entities": []}},
        {"thank": {"use_entities": []}},
        "ask",
        {"why": {"use_entities": []}},
        "pure_intent",
    ]
    expected_templates = {
        "utter_greet": [{"text": "hey there!"}],
        "utter_goodbye": [{"text": "goodbye :("}],
        "utter_default": [{"text": "default message"}],
    }
    expected_dict = {
        "intents": expected_intents,
        "entities": ["name", "other", "unrelated_recognized_entity"],
        "templates": expected_templates,
        "actions": ["utter_default", "utter_goodbye", "utter_greet"],
    }

    expected_domain = Domain.from_dict(expected_dict)
    actual_domain = Domain.from_dict(cleaned_dict)

    assert hash(actual_domain) == hash(expected_domain)
Beispiel #4
0
def test_extract_requested_slot_from_entity(
    mapping_not_intent: Optional[Text],
    mapping_intent: Optional[Text],
    mapping_role: Optional[Text],
    mapping_group: Optional[Text],
    entities: List[Dict[Text, Any]],
    intent: Text,
    expected_slot_values: Dict[Text, Text],
):
    """Check requested-slot extraction from an entity under the given restrictions.

    A `from_entity` mapping is built from the role / group / intent /
    not_intent parameters and registered for "some_slot"; the values
    extracted for the user message must equal `expected_slot_values`.
    """
    name_of_form = "some form"
    form = FormAction(name_of_form, None)

    entity_mapping = form.from_entity(
        entity="some_entity",
        role=mapping_role,
        group=mapping_group,
        intent=mapping_intent,
        not_intent=mapping_not_intent,
    )
    slot_mappings = {"some_slot": [entity_mapping]}
    domain = Domain.from_dict({"forms": [{name_of_form: slot_mappings}]})

    # The form is currently asking for "some_slot" and the user just replied.
    requested_slot_event = SlotSet(REQUESTED_SLOT, "some_slot")
    user_message = UserUttered(
        "bla", intent={"name": intent, "confidence": 1.0}, entities=entities
    )
    tracker = DialogueStateTracker.from_events(
        "default", [requested_slot_event, user_message]
    )

    extracted = form.extract_requested_slot(tracker, domain)
    assert extracted == expected_slot_values
Beispiel #5
0
def test_form_without_form_policy(policy_config: Dict[Text, List[Text]]):
    """Creating an agent with a form in the domain must raise `InvalidDomain`.

    The error message has to mention the missing FormPolicy.
    """
    with pytest.raises(InvalidDomain) as raised:
        # Everything is built inside the context so that an `InvalidDomain`
        # from any of the three calls is captured.
        form_domain = Domain.from_dict({"forms": ["restaurant_form"]})
        ensemble = PolicyEnsemble.from_dict(policy_config)
        Agent(domain=form_domain, policies=ensemble)

    error_message = str(raised.value)
    assert "haven't added the FormPolicy" in error_message
Beispiel #6
0
def test_trigger_without_mapping_policy(domain, policy_config):
    """Creating an agent from the given domain and policies must raise `InvalidDomain`.

    The error message has to mention the missing MappingPolicy.
    """
    with pytest.raises(InvalidDomain) as raised:
        # Everything is built inside the context so that an `InvalidDomain`
        # from any of the three calls is captured.
        parsed_domain = Domain.from_dict(domain)
        ensemble = PolicyEnsemble.from_dict(policy_config)
        Agent(domain=parsed_domain, policies=ensemble)

    error_message = str(raised.value)
    assert "haven't added the MappingPolicy" in error_message
    async def get_story(request, story_id):
        """Return a single story as JSON or as a graphviz visualization.

        Args:
            request: Incoming request; its "Accept" header selects the
                response format and it is passed to the story / domain
                service factories.
            story_id: Id of the story to fetch.

        Returns:
            - a 404 "StoryNotFound" error if no story exists for `story_id`,
            - for "Accept: text/vnd.graphviz": the visualization as text, or
              a 406 "VisualizationNotAvailable" error if none was produced,
            - otherwise the story as a JSON response.
        """
        from rasa.core.domain import Domain

        story = _story_service(request).fetch_story(story_id)

        # Reject unknown stories up front. Previously only the JSON branch
        # checked this, so a graphviz request for a missing story tried to
        # visualize `[None]` instead of reporting "not found".
        if not story:
            return rasa_x_utils.error(
                HTTPStatus.NOT_FOUND,
                "StoryNotFound",
                f"Story for id {story_id} could not be found",
            )

        content_type = request.headers.get("Accept")
        if content_type != "text/vnd.graphviz":
            return response.json(story)

        project_id = rasa_x_utils.default_arg(request, "project_id",
                                              config.project_name)
        domain_dict = _domain_service(request).get_or_create_domain(
            project_id)
        domain = Domain.from_dict(domain_dict)

        visualization = await _story_service(request).visualize_stories(
            [story], domain)

        if visualization:
            return response.text(visualization)

        return rasa_x_utils.error(
            HTTPStatus.NOT_ACCEPTABLE,
            "VisualizationNotAvailable",
            "Cannot produce a visualization for the requested story",
        )
Beispiel #8
0
def test_extract_requested_slot_mapping_does_not_apply(slot_mapping: Dict):
    """No value is extracted for the requested slot with the given mapping.

    The user message has the intent "greet" and carries an entity named
    like the slot; the test expects the extraction result to be empty.
    """
    form_name = "some_form"
    entity_name = "some_slot"
    form = FormAction(form_name, None)

    form_spec = {form_name: {entity_name: [slot_mapping]}}
    domain = Domain.from_dict({"forms": [form_spec]})

    user_message = UserUttered(
        "bla",
        intent={"name": "greet", "confidence": 1.0},
        entities=[{"entity": entity_name, "value": "some_value"}],
    )
    events = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        user_message,
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    tracker = DialogueStateTracker.from_events("default", events)

    extracted = form.extract_requested_slot(tracker, domain)
    # check that the value was not extracted for incorrect intent
    assert extracted == {}
Beispiel #9
0
async def _write_domain_to_file(
    domain_path: Text,
    evts: List[Dict[Text, Any]],
    endpoint: EndpointConfig
) -> None:
    """Merge what was collected from the events into the domain and persist it.

    The current domain is retrieved from `endpoint`, merged with a domain
    built from the intents, entities and non-default action names found in
    `evts`, and the cleaned result is written to `domain_path`.
    """
    retrieved = await retrieve_domain(endpoint)
    current_domain = Domain.from_dict(retrieved)

    user_messages = _collect_messages(evts)
    action_events = _collect_actions(evts)

    # TODO for now there is no way to distinguish between action and form
    intents = _intents_from_messages(user_messages)
    intent_properties = Domain.collect_intent_properties(intents)

    # De-duplicate action names and drop the default actions.
    unique_action_names = set()
    for action_event in action_events:
        if action_event["name"] not in default_action_names():
            unique_action_names.add(action_event["name"])
    collected_actions = list(unique_action_names)

    collected_domain = Domain(
        intent_properties=intent_properties,
        entities=_entities_from_messages(user_messages),
        slots=[],
        templates={},
        action_names=collected_actions,
        form_names=[],
    )

    merged_domain = current_domain.merge(collected_domain)
    merged_domain.persist_clean(domain_path)
Beispiel #10
0
    async def get_story_steps(
            story_string: Text,
            domain: Dict[Text, Any] = None) -> List[StoryStep]:
        """Given story md string reads the contained stories.

        Also checks if the intents in the stories are in the provided domain. For each
        intent not present in the domain, a UserWarning is issued.

        Args:
            story_string: Stories in markdown format; split on newlines
                before being handed to the reader.
            domain: Optional domain as a dictionary. A falsy value is
                replaced by an empty domain. The dictionary is deep-copied
                before use, so the caller's object is not modified here.

        Returns:
            The list of StorySteps produced by the story reader.

        Raises:
            StoryParseError: If the reader raises an `AttributeError` or a
                `ValueError` while processing the lines. The original
                exception is attached as the cause.
        """
        # domain is not needed in `StoryFileReader` when parsing stories,
        # but if none is provided there will be a UserWarning for each intent
        if not domain:
            domain = {}
        else:
            # Make a copy because RasaDomain.from_dict(domain) might change its input
            # domain. This will be fixed with Rasa 2.0.
            domain = copy.deepcopy(domain)
        domain = RasaDomain.from_dict(domain)
        reader = StoryFileReader(interpreter=RegexInterpreter(), domain=domain)
        try:
            # just split on newlines
            lines = story_string.split("\n")
            return await reader.process_lines(lines)
        except (AttributeError, ValueError) as e:
            # Chain explicitly so the traceback marks the parser error as the
            # direct cause of the StoryParseError.
            raise StoryParseError("Invalid story format. Failed to parse "
                                  "'{}'\nError: {}".format(story_string, e)) from e
Beispiel #11
0
def test_two_stage_fallback_without_deny_suggestion(domain, policy_config):
    """Creating an agent from the given domain and policies must raise `InvalidDomain`.

    The error message has to state that the intent 'out_of_scope' must be
    present.
    """
    with pytest.raises(InvalidDomain) as raised:
        # Everything is built inside the context so that an `InvalidDomain`
        # from any of the three calls is captured.
        parsed_domain = Domain.from_dict(domain)
        ensemble = PolicyEnsemble.from_dict(policy_config)
        Agent(domain=parsed_domain, policies=ensemble)

    error_message = str(raised.value)
    assert "The intent 'out_of_scope' must be present" in error_message
Beispiel #12
0
def test_form_without_form_policy(domain: Dict[Text, Any],
                                  policy_config: Dict[Text, Any]):
    """Creating an agent from the given domain and policies must raise `InvalidDomain`.

    The error message has to mention the missing FormPolicy.
    """
    with pytest.raises(InvalidDomain) as raised:
        # Everything is built inside the context so that an `InvalidDomain`
        # from any of the three calls is captured.
        parsed_domain = Domain.from_dict(domain)
        ensemble = PolicyEnsemble.from_dict(policy_config)
        Agent(domain=parsed_domain, policies=ensemble)

    error_message = str(raised.value)
    assert "haven't added the FormPolicy" in error_message
Beispiel #13
0
def test_domain_as_dict_with_session_config():
    """A session config set on a domain survives an `as_dict` / `from_dict` round trip."""
    original_config = SessionConfig(123, False)

    empty_domain = Domain.empty()
    empty_domain.session_config = original_config

    restored_domain = Domain.from_dict(empty_domain.as_dict())

    assert restored_domain.session_config == original_config
Beispiel #14
0
    async def get_domain_warnings(
        self,
        project_id: Text = config.project_name
    ) -> Optional[Tuple[Dict[Text, Dict[Text, List[Text]]], int]]:
        """Collect the warnings for the domain of a project.

        Intents and entities are taken from the NLU training data, actions
        from the response names, and all of them (plus slots) are extended
        with the domain items found in the stories before being checked
        against the domain.

        Args:
            project_id: The project id of the domain.

        Returns:
            Dict of domain warnings and the total count of elements, or
            `None` if no domain is available for the project.
        """
        stored_domain = self._get_domain(project_id)
        if not stored_domain:
            return None

        from rasax.community.services.data_service import DataService
        from rasax.community.services.nlg_service import NlgService
        from rasax.community.services.story_service import StoryService

        rasa_domain = RasaDomain.from_dict(stored_domain.as_dict())

        data_service = DataService(self.session)
        nlu_data = data_service.get_nlu_training_data_object(
            project_id=project_id)

        # actions are response names and story bot actions
        action_names = NlgService(self.session).fetch_all_response_names()

        # intents are training data intents and story intents
        intent_names = nlu_data.intents

        # entities are training data entities without `extractor` attribute
        entity_names = self._get_entities_from_training_data(
            nlu_data.entity_examples)

        # slots are simply story slots
        slot_names = set()

        story_service = StoryService(self.session)
        story_items = await story_service.fetch_domain_items_from_stories()

        if story_items:
            # Positional layout used here: actions, intents, slots, entities.
            action_names.update(story_items[0])
            intent_names.update(story_items[1])
            slot_names.update(story_items[2])
            entity_names.update(story_items[3])

        # exclude unfeaturized slots from warnings
        slot_names = self._remove_unfeaturized_slots(slot_names, rasa_domain)

        warnings = self._domain_warnings_as_list(
            rasa_domain, intent_names, entity_names, action_names, slot_names)
        total = self._count_total_warnings(warnings)

        return warnings, total
Beispiel #15
0
    def dump_cleaned_domain_yaml(domain: Dict[Text, Any]) -> Optional[Text]:
        """Clean a domain dictionary and serialize the result as yaml.

        Args:
            domain: Domain as a dictionary.

        Returns:
            The result of dumping the cleaned domain with `dump_yaml`.
        """
        rasa_domain = RasaDomain.from_dict(domain)
        cleaned = rasa_domain.cleaned_domain()
        return dump_yaml(cleaned)
Beispiel #16
0
async def model_fingerprint(
        file_importer: "TrainingDataImporter") -> Fingerprint:
    """Build the fingerprint of a model from its configuration and training data.

    Args:
        file_importer: File importer which provides the training data and model config.

    Returns:
        A dictionary with hashes of the Core config, the per-language NLU
        configs, the domain without responses, the responses, the
        per-language NLU data and the stories, together with the current
        timestamp and the Rasa version.

    """
    from rasa.core.domain import Domain

    import rasa
    import time

    # bf mod: `get_config()` and `get_stories()` are not called here; the
    # stories hash and the separate NLU / Core configs are used instead.
    domain = await file_importer.get_domain()
    stories_hash = await file_importer.get_stories_hash()
    nlu_data = await file_importer.get_nlu_data()

    nlu_config = await file_importer.get_nlu_config()
    core_config = await file_importer.get_core_config()

    # Split the responses off so they are fingerprinted separately from the
    # rest of the domain.
    domain_as_dict = domain.as_dict()
    responses = domain_as_dict.pop("responses")
    domain_without_nlg = Domain.from_dict(domain_as_dict)

    fingerprint = {}
    fingerprint[FINGERPRINT_CONFIG_KEY] = _get_hash_of_config(
        core_config, exclude_keys=CONFIG_MANDATORY_KEYS)
    fingerprint[FINGERPRINT_CONFIG_CORE_KEY] = _get_hash_of_config(
        core_config, include_keys=CONFIG_MANDATORY_KEYS_CORE)
    fingerprint[FINGERPRINT_CONFIG_NLU_KEY] = {
        lang: _get_hash_of_config(lang_config,
                                  include_keys=CONFIG_MANDATORY_KEYS_NLU)
        for (lang, lang_config) in nlu_config.items()
    }
    fingerprint[FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY] = hash(domain_without_nlg)
    fingerprint[FINGERPRINT_NLG_KEY] = get_dict_hash(responses)
    fingerprint[FINGERPRINT_NLU_DATA_KEY] = {
        lang: hash(nlu_data[lang]) for lang in nlu_data
    }
    fingerprint[FINGERPRINT_STORIES_KEY] = stories_hash
    fingerprint[FINGERPRINT_TRAINED_AT_KEY] = time.time()
    fingerprint[FINGERPRINT_RASA_VERSION_KEY] = rasa.__version__

    return fingerprint
Beispiel #17
0
    async def train(self):
        """Train the NLU and Core parts of the engine for the configured skill.

        Nothing is trained if the directory "data/<skill-id>" already exists.
        Otherwise stories are generated, an NLU model is trained and
        persisted, and a policy ensemble is trained on the generated
        trackers and persisted together with the domain.
        """
        nltk.download('punkt')
        lang = self.config['language']
        skill_id = self.config['skill-id']
        skill_dir = 'data/' + skill_id

        if os.path.exists(skill_dir):
            return

        _LOGGER.info("Starting Skill training.")
        _LOGGER.info("Generating stories.")
        data, domain_data, stories = await GenerateStories.run(
            skill_id, lang, self.asm)

        # Rasa NLU
        nlu_examples = TrainingData(training_examples=data)
        nlu_config = RasaNLUModelConfig({
            "language": lang,
            "pipeline": self.config['pipeline'],
            "data": None
        })

        nlu_trainer = Trainer(nlu_config, None, True)
        _LOGGER.info("Training Arcus NLU")
        nlu_trainer.train(nlu_examples)
        nlu_trainer.persist(skill_dir, None, 'nlu')

        # Rasa core
        domain = Domain.from_dict(domain_data)

        story_reader = StoryFileReader(domain, RegexInterpreter(), None, False)
        story_graph = StoryGraph(await story_reader.process_lines(stories))

        generator = TrainingDataGenerator(
            story_graph,
            domain,
            remove_duplicates=True,
            unique_last_num_states=None,
            augmentation_factor=20,
            tracker_limit=None,
            use_story_concatenation=True,
            debug_plots=False,
        )
        training_trackers = generator.generate()

        policies = SimplePolicyEnsemble.from_dict(
            {"policies": self.config['policies']})
        ensemble = SimplePolicyEnsemble(policies)

        _LOGGER.info("Training Arcus Core")
        ensemble.train(training_trackers, domain)

        core_dir = skill_dir + "/core"
        ensemble.persist(core_dir, False)
        domain.persist(skill_dir + "/core/model")
        domain.persist_specification(core_dir)
Beispiel #18
0
def test_invalid_slot_mapping():
    """Extracting the requested slot with a mapping of type "invalid" raises `ValueError`."""
    form_name = "my_form"
    slot_name = "test"
    form = FormAction(form_name, None)

    requested_slot_event = SlotSet(REQUESTED_SLOT, slot_name)
    tracker = DialogueStateTracker.from_events("sender", [requested_slot_event])

    invalid_mapping = {"type": "invalid"}
    form_spec = {form_name: {slot_name: [invalid_mapping]}}
    domain = Domain.from_dict({"forms": [form_spec]})

    with pytest.raises(ValueError):
        form.extract_requested_slot(tracker, domain)
Beispiel #19
0
    def dump_domain(self,
                    filename: Optional[Text] = None,
                    project_id: Text = config.project_name):
        """Write the cleaned domain of a project to a yml file.

        Does nothing if the project has no domain.

        Args:
            filename: Target file name relative to the project directory.
                If falsy, the domain's "path" entry is used, falling back
                to `config.default_domain_path`.
            project_id: Id of the project whose domain is dumped.
        """
        stored_domain = self.get_domain(project_id)
        if not stored_domain:
            return

        # Precedence: explicit filename, then the domain's own path, then
        # the configured default.
        target_name = (filename
                       or stored_domain.get("path")
                       or config.default_domain_path)

        cleaned = RasaDomain.from_dict(stored_domain).cleaned_domain()
        target_path = utils.get_project_directory() / target_name
        dump_yaml_to_file(target_path, cleaned)
    async def get_stories(request):
        """Return the stories matching the request in the format asked for.

        The query arguments "q", "id" and "distinct" plus the requested
        fields select the stories. The "Accept" header picks the format:
        "text/markdown" returns a markdown attachment, "text/vnd.graphviz"
        returns a visualization (or a 406 error if none could be produced),
        and anything else returns JSON with an "X-Total-Count" header.
        """
        from rasa.core.domain import Domain

        text_query = rasa_x_utils.default_arg(request, "q", None)
        fields = rasa_x_utils.fields_arg(request,
                                         {"name", "annotation.user", "id"})
        id_query = rasa_x_utils.list_arg(request, "id")
        distinct = rasa_x_utils.bool_arg(request, "distinct", True)

        story_service = _story_service(request)
        stories = story_service.fetch_stories(text_query,
                                              fields,
                                              id_query=id_query,
                                              distinct=distinct)

        accept = request.headers.get("Accept")

        if accept == "text/markdown":
            markdown = _story_service(request).get_stories_markdown(stories)
            attachment_headers = {
                "Content-Disposition": "attachment;filename=stories.md"
            }
            return response.text(
                markdown,
                content_type="text/markdown",
                headers=attachment_headers,
            )

        if accept != "text/vnd.graphviz":
            return response.json(stories,
                                 headers={"X-Total-Count": len(stories)})

        # Graphviz output needs the project's domain for the visualization.
        project_id = rasa_x_utils.default_arg(request, "project_id",
                                              config.project_name)
        domain_dict = _domain_service(request).get_or_create_domain(
            project_id)
        domain = Domain.from_dict(domain_dict)

        visualization = await _story_service(request).visualize_stories(
            stories, domain)

        if not visualization:
            return rasa_x_utils.error(
                HTTPStatus.NOT_ACCEPTABLE,
                "VisualizationNotAvailable",
                "Cannot produce a visualization for the requested stories",
            )

        return response.text(visualization)
Beispiel #21
0
async def test_fingerprinting_additional_action(project: Text):
    """Adding a response to the domain changes both domain-related fingerprints.

    The fingerprint is computed once for the project as is and once with
    the importer returning a domain that has the extra response
    "utter_new"; both the domain-without-NLG and the NLG fingerprint
    entries are expected to differ.
    """
    importer = _project_files(project)

    old_fingerprint = await model_fingerprint(importer)
    old_domain = await importer.get_domain()

    domain_with_new_action = old_domain.as_dict()
    domain_with_new_action[KEY_RESPONSES]["utter_new"] = [{"text": "hi"}]
    domain_with_new_action = Domain.from_dict(domain_with_new_action)

    # `asyncio.coroutine` is deprecated since Python 3.8 and was removed in
    # 3.11, so the replacement for `get_domain` is a native coroutine.
    async def _get_modified_domain():
        return domain_with_new_action

    importer.get_domain = _get_modified_domain

    new_fingerprint = await model_fingerprint(importer)

    assert (old_fingerprint[FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY] !=
            new_fingerprint[FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY])
    assert old_fingerprint[FINGERPRINT_NLG_KEY] != new_fingerprint[
        FINGERPRINT_NLG_KEY]
Beispiel #22
0
async def test_ask_for_slot(domain: Dict, expected_action: Text,
                            monkeypatch: MonkeyPatch):
    """`_ask_for_slot` looks up `expected_action` exactly once.

    `action.action_from_name` is replaced by a mock returning an
    `ActionListen`, and the call arguments of that mock are checked.
    """
    slot_name = "sun"

    # Patch the factory by its own name so the lookup can be observed.
    factory_mock = Mock(return_value=action.ActionListen())
    patched_attribute = action.action_from_name.__name__
    monkeypatch.setattr(action, patched_attribute, factory_mock)

    form = FormAction("my_form", None)
    parsed_domain = Domain.from_dict(domain)
    empty_tracker = DialogueStateTracker.from_events("dasd", [])

    await form._ask_for_slot(parsed_domain, None, None, slot_name, empty_tracker)

    factory_mock.assert_called_once_with(expected_action, None, ANY)
Beispiel #23
0
async def test_trigger_slot_mapping_applies(
    trigger_slot_mapping: Dict, expected_value: Text
):
    """The given mapping fills "other_slot" with `expected_value`.

    The form defines a `from_entity` mapping (restricted to "some_intent")
    for the requested slot and `trigger_slot_mapping` for "other_slot";
    `extract_other_slots` must return only the latter.
    """
    form_name = "some_form"
    entity_name = "some_slot"
    slot_filled_by_trigger_mapping = "other_slot"
    form = FormAction(form_name, None)

    requested_slot_mapping = {
        "type": "from_entity",
        "entity": entity_name,
        "intent": "some_intent",
    }
    form_spec = {
        form_name: {
            entity_name: [requested_slot_mapping],
            slot_filled_by_trigger_mapping: [trigger_slot_mapping],
        }
    }
    domain = Domain.from_dict({"forms": [form_spec]})

    user_message = UserUttered(
        "bla",
        intent={"name": "greet", "confidence": 1.0},
        entities=[{"entity": entity_name, "value": "some_value"}],
    )
    events = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        user_message,
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    tracker = DialogueStateTracker.from_events("default", events)

    extracted = form.extract_other_slots(tracker, domain)
    assert extracted == {slot_filled_by_trigger_mapping: expected_value}
Beispiel #24
0
async def test_fingerprinting_changed_response_text(project: Text):
    """Changing only response texts changes the NLG fingerprint but not the domain one.

    The fingerprint is computed once for the project as is and once with
    the importer returning a domain whose "utter_greet" response has an
    additional variation.
    """
    importer = _project_files(project)

    old_fingerprint = await model_fingerprint(importer)
    old_domain = await importer.get_domain()

    # Change NLG content but keep actions the same
    domain_with_changed_nlg = old_domain.as_dict()
    domain_with_changed_nlg[KEY_RESPONSES]["utter_greet"].append(
        {"text": "hi"})
    domain_with_changed_nlg = Domain.from_dict(domain_with_changed_nlg)

    # `asyncio.coroutine` is deprecated since Python 3.8 and was removed in
    # 3.11, so the replacement for `get_domain` is a native coroutine.
    async def _get_modified_domain():
        return domain_with_changed_nlg

    importer.get_domain = _get_modified_domain

    new_fingerprint = await model_fingerprint(importer)

    assert (old_fingerprint[FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY] ==
            new_fingerprint[FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY])
    assert old_fingerprint[FINGERPRINT_NLG_KEY] != new_fingerprint[
        FINGERPRINT_NLG_KEY]
Beispiel #25
0
def test_extract_other_slots_with_entity(
    some_other_slot_mapping: List[Dict[Text, Any]],
    some_slot_mapping: List[Dict[Text, Any]],
    entities: List[Dict[Text, Any]],
    intent: Text,
    expected_slot_values: Dict[Text, Text],
):
    """Check which non-requested slot values are extracted from entities.

    "some_slot" is the requested slot; the result of `extract_other_slots`
    for the user message must equal `expected_slot_values`.
    """
    form_name = "some_form"
    form = FormAction(form_name, None)

    slot_mappings = {
        "some_other_slot": some_other_slot_mapping,
        "some_slot": some_slot_mapping,
    }
    domain = Domain.from_dict({"forms": [{form_name: slot_mappings}]})

    user_message = UserUttered(
        "bla", intent={"name": intent, "confidence": 1.0}, entities=entities
    )
    events = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        user_message,
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    tracker = DialogueStateTracker.from_events("default", events)

    extracted = form.extract_other_slots(tracker, domain)
    # check that the value was extracted for non requested slot
    assert extracted == expected_slot_values
Beispiel #26
0
async def test_interactive_domain_persistence(mock_endpoint, tmpdir):
    """The domain written by `interactive._write_domain_to_file` lists no default actions.

    The events of the moodbot tracker dump are written against an empty
    domain served by the mocked endpoint, and the saved file is read back.
    """
    tracker_json = rasa.utils.io.read_json_file(
        "data/test_trackers/tracker_moodbot.json"
    )
    events = tracker_json.get("events", [])

    domain_path = tmpdir.join("interactive_domain_save.yml").strpath

    domain_url = f"{mock_endpoint.url}/domain"
    with aioresponses() as mocked:
        # The endpoint answers with an empty domain.
        mocked.get(domain_url, payload={})

        serialised_domain = await interactive.retrieve_domain(mock_endpoint)
        old_domain = Domain.from_dict(serialised_domain)

        await interactive._write_domain_to_file(domain_path, events, old_domain)

    saved_domain = rasa.utils.io.read_config_file(domain_path)

    assert all(
        default_action.name() not in saved_domain["actions"]
        for default_action in action.default_actions()
    )
Beispiel #27
0
def test_forms_with_suited_policy(policy_config: Dict[Text, List[Text]]):
    """Creating an agent with a form and the given policy config must not raise."""
    form_domain = Domain.from_dict({"forms": ["restaurant_form"]})
    ensemble = PolicyEnsemble.from_dict(policy_config)

    # Doesn't raise
    Agent(domain=form_domain, policies=ensemble)
Beispiel #28
0
def test_add_default_intents(domain: Dict):
    """Every name in `DEFAULT_INTENTS` is among the intents of the parsed domain."""
    parsed_domain = Domain.from_dict(domain)

    missing_intents = [
        intent_name
        for intent_name in DEFAULT_INTENTS
        if intent_name not in parsed_domain.intents
    ]
    assert not missing_intents