Exemple #1
0
async def test_adding_e2e_actions_to_domain(project: Text):
    """Check that action texts found in stories show up as domain actions.

    The stories of the loaded importer are replaced with a hand-built graph
    whose `ActionExecuted` events carry an `action_text`. The domain returned
    by an `E2EImporter` wrapping that importer must then list every one of
    those texts in `action_names`.

    Args:
        project: Base directory onto which the default config, domain and
            data paths are joined.
    """
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = os.path.join(project, DEFAULT_DOMAIN_PATH)
    default_data_path = os.path.join(project, DEFAULT_DATA_PATH)
    existing = TrainingDataImporter.load_from_dict(
        {}, config_path, domain_path, [default_data_path]
    )

    additional_actions = ["Hi Joey.", "it's sunny outside."]
    stories = StoryGraph([
        StoryStep(events=[
            UserUttered("greet_from_stories", {"name": "greet_from_stories"}),
            ActionExecuted("utter_greet_from_stories"),
        ]),
        StoryStep(events=[
            UserUttered("how are you doing?", {"name": "greet_from_stories"}),
            ActionExecuted(additional_actions[0],
                           action_text=additional_actions[0]),
            # The second text is used twice on purpose: a repeated action
            # must still be found in the domain.
            ActionExecuted(additional_actions[1],
                           action_text=additional_actions[1]),
            ActionExecuted(additional_actions[1],
                           action_text=additional_actions[1]),
        ]),
    ])

    # Patch to return our test stories. A native coroutine function is used
    # because `asyncio.coroutine` was removed in Python 3.11.
    async def mocked_stories(*args):
        return stories

    existing.get_stories = mocked_stories

    importer = E2EImporter(existing)
    domain = await importer.get_domain()

    assert all(action_name in domain.action_names
               for action_name in additional_actions)
Exemple #2
0
async def test_without_additional_e2e_examples(tmp_path: Path):
    """Check the NLU data produced when stories hold no action texts.

    An importer is built from an empty domain file, an empty config file and
    no data paths; its stories are replaced with a single step that contains
    no `action_text`. The NLU data from the wrapping `E2EImporter` must have
    training examples, report `is_empty()`, and have no training examples
    left after `without_empty_e2e_examples()`.

    Args:
        tmp_path: Temporary directory in which the domain and config files
            are created.
    """
    domain_path = tmp_path / "domain.yml"
    domain_path.write_text(Domain.empty().as_yaml())

    config_path = tmp_path / "config.yml"
    config_path.touch()

    existing = TrainingDataImporter.load_from_dict(
        {}, str(config_path), str(domain_path), []
    )

    stories = StoryGraph([
        StoryStep(events=[
            UserUttered("greet_from_stories", {"name": "greet_from_stories"}),
            ActionExecuted("utter_greet_from_stories"),
        ])
    ])

    # Patch to return our test stories. A native coroutine function is used
    # because `asyncio.coroutine` was removed in Python 3.11.
    async def mocked_stories(*args):
        return stories

    existing.get_stories = mocked_stories

    importer = E2EImporter(existing)

    training_data = await importer.get_nlu_data()

    assert training_data.training_examples
    assert training_data.is_empty()
    assert not training_data.without_empty_e2e_examples().training_examples
Exemple #3
0
async def test_import_nlu_training_data_with_default_actions(project: Text):
    """Check that every default action name appears in the NLU training data.

    The importer loaded from the project defaults must be an `E2EImporter`,
    must yield more training examples than the importer it wraps, and must
    contain one message with an empty action text per default action name.

    Args:
        project: Base directory onto which the default config, domain and
            data paths are joined.
    """
    e2e_importer = TrainingDataImporter.load_from_dict(
        {},
        os.path.join(project, DEFAULT_CONFIG_PATH),
        os.path.join(project, DEFAULT_DOMAIN_PATH),
        [os.path.join(project, DEFAULT_DATA_PATH)],
    )

    assert isinstance(e2e_importer, E2EImporter)
    wrapped_importer = e2e_importer.importer

    # Check additional NLU training data from domain was added
    data_with_e2e = await e2e_importer.get_nlu_data()
    data_without_e2e = await wrapped_importer.get_nlu_data()

    assert len(data_with_e2e.training_examples) > len(
        data_without_e2e.training_examples)

    from rasa.core.actions import action

    reloaded_data = await e2e_importer.get_nlu_data()
    for default_name in action.default_action_names():
        expected_message = Message(data={
            ACTION_NAME: default_name,
            ACTION_TEXT: ""
        })
        assert expected_message in reloaded_data.training_examples
Exemple #4
0
async def test_rasa_file_importer_with_invalid_domain(tmp_path: Path):
    """Check that loading with `None` as the domain path gives an empty domain.

    Args:
        tmp_path: Temporary directory in which an empty config file is
            written.
    """
    empty_config = tmp_path / "config.yml"
    empty_config.write_text("")

    importer = TrainingDataImporter.load_from_dict(
        {}, str(empty_config), None, []
    )

    loaded_domain = (await importer.get_domain()).as_dict()
    assert loaded_domain == Domain.empty().as_dict()
Exemple #5
0
def test_load_from_dict(
    config: Dict,
    expected: List[Type["TrainingDataImporter"]],
    project: Text,
):
    """Check which importer classes `load_from_dict` builds for a config.

    Args:
        config: Importer configuration passed to `load_from_dict`.
        expected: Importer classes, in order, that the combined importer is
            expected to hold.
        project: Base directory onto which the default config, domain and
            data paths are joined.
    """
    loaded = TrainingDataImporter.load_from_dict(
        config,
        os.path.join(project, DEFAULT_CONFIG_PATH),
        os.path.join(project, DEFAULT_DOMAIN_PATH),
        [os.path.join(project, DEFAULT_DATA_PATH)],
    )

    assert isinstance(loaded, CombinedDataImporter)

    # Compare by class so that the order of the sub-importers is checked too.
    importer_classes = [sub.__class__ for sub in loaded._importers]
    assert importer_classes == expected
Exemple #6
0
async def test_nlu_data_domain_sync_with_retrieval_intents(project: Text):
    """Check that domain and NLU data agree for retrieval intents.

    NLU and Core importers built on a shared base importer are combined and
    wrapped in a `RetrievalModelsDataImporter`. The resulting domain must
    list `chitchat` as its only retrieval intent, flag it as such in the
    intent properties, expose the same responses as the NLU data, and
    contain the `utter_chitchat` action.

    Args:
        project: Base directory onto which the default config path is joined.
    """
    retrieval_domain = "data/test_domains/default_retrieval_intents.yml"
    training_files = [
        "data/test_nlu/default_retrieval_intents.md",
        "data/test_responses/default.md",
    ]
    base_importer = TrainingDataImporter.load_from_dict(
        {},
        os.path.join(project, DEFAULT_CONFIG_PATH),
        retrieval_domain,
        training_files,
    )

    combined = CombinedDataImporter(
        [NluDataImporter(base_importer), CoreDataImporter(base_importer)]
    )
    retrieval_importer = RetrievalModelsDataImporter(combined)

    domain = await retrieval_importer.get_domain()
    nlu_data = await retrieval_importer.get_nlu_data()

    assert domain.retrieval_intents == ["chitchat"]
    chitchat_properties = domain.intent_properties["chitchat"]
    assert chitchat_properties.get("is_retrieval_intent")
    assert domain.templates == nlu_data.responses
    assert "utter_chitchat" in domain.action_names
Exemple #7
0
async def test_import_nlu_training_data_from_e2e_stories(project: Text):
    """Check that NLU training data is derived from the training stories.

    The stories of the importer wrapped by the `E2EImporter` are replaced
    with a hand-built graph. The wrapper must forward stories and config
    unchanged, yield more NLU examples than the wrapped importer, and contain
    a message for every user utterance and executed action in the stories.

    Args:
        project: Base directory onto which the default config, domain and
            data paths are joined.
    """
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = os.path.join(project, DEFAULT_DOMAIN_PATH)
    default_data_path = os.path.join(project, DEFAULT_DATA_PATH)
    importer = TrainingDataImporter.load_from_dict(
        {}, config_path, domain_path, [default_data_path]
    )

    # The `E2EImporter` correctly wraps the underlying `CombinedDataImporter`
    assert isinstance(importer, E2EImporter)
    importer_without_e2e = importer.importer

    stories = StoryGraph([
        StoryStep(events=[
            SlotSet("some slot", "doesn't matter"),
            UserUttered("greet_from_stories", {"name": "greet_from_stories"}),
            ActionExecuted("utter_greet_from_stories"),
        ]),
        StoryStep(events=[
            UserUttered("how are you doing?"),
            ActionExecuted("utter_greet_from_stories", action_text="Hi Joey."),
        ]),
    ])

    # Patch to return our test stories. A native coroutine function is used
    # because `asyncio.coroutine` was removed in Python 3.11.
    async def mocked_stories(*args):
        return stories

    importer_without_e2e.get_stories = mocked_stories

    # The wrapping `E2EImporter` simply forwards these method calls
    assert (await importer_without_e2e.get_stories()).as_story_string() == (
        await importer.get_stories()).as_story_string()
    assert (await importer_without_e2e.get_config()) == (
        await importer.get_config())

    # Check additional NLU training data from stories was added
    nlu_data = await importer.get_nlu_data()

    # The `E2EImporter` adds NLU training data based on our training stories
    assert len(nlu_data.training_examples) > len(
        (await importer_without_e2e.get_nlu_data()).training_examples)

    # Check if the NLU training data was added correctly from the story
    # training data
    expected_additional_messages = [
        Message(data={
            TEXT: "greet_from_stories",
            INTENT_NAME: "greet_from_stories"
        }),
        Message(data={
            ACTION_NAME: "utter_greet_from_stories",
            ACTION_TEXT: ""
        }),
        Message(data={
            TEXT: "how are you doing?",
            INTENT_NAME: None
        }),
        Message(data={
            ACTION_NAME: "utter_greet_from_stories",
            ACTION_TEXT: "Hi Joey."
        }),
    ]

    assert all(m in nlu_data.training_examples
               for m in expected_additional_messages)