Example #1
async def test_without_additional_e2e_examples(tmp_path: Path):
    domain_path = tmp_path / "domain.yml"
    domain_path.write_text(Domain.empty().as_yaml())

    config_path = tmp_path / "config.yml"
    config_path.touch()

    existing = TrainingDataImporter.load_from_dict(
        {}, str(config_path), str(domain_path), []
    )

    stories = StoryGraph(
        [
            StoryStep(
                events=[
                    UserUttered(None, {"name": "greet_from_stories"}),
                    ActionExecuted("utter_greet_from_stories"),
                ]
            )
        ]
    )

    # Patch to return our test stories
    existing.get_stories = asyncio.coroutine(lambda *args: stories)

    importer = E2EImporter(existing)

    training_data = await importer.get_nlu_data()

    assert training_data.training_examples
    assert not training_data.is_empty()
    assert len(training_data.nlu_examples) == 0
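Every example on this page goes through TrainingDataImporter.load_from_dict. For reference, here is a minimal sketch of the call pattern, assuming the keyword signature used in recent Rasa releases (config, config_path, domain_path, training_data_paths, all optional):

from rasa.shared.importers.importer import TrainingDataImporter

# Minimal sketch, not taken from any of the examples: every argument is
# optional, which is why some examples pass positionals and others only
# pass training_data_paths or domain_path.
importer = TrainingDataImporter.load_from_dict(
    {},              # config dict
    "config.yml",    # config_path
    "domain.yml",    # domain_path
    ["data/"],       # training_data_paths
)
domain = importer.get_domain()      # synchronous in Rasa 3.x; a coroutine in 2.x
nlu_data = importer.get_nlu_data()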
Example #2
async def test_adding_e2e_actions_to_domain(project: Text):
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = os.path.join(project, DEFAULT_DOMAIN_PATH)
    default_data_path = os.path.join(project, DEFAULT_DATA_PATH)
    existing = TrainingDataImporter.load_from_dict({}, config_path,
                                                   domain_path,
                                                   [default_data_path])

    additional_actions = ["Hi Joey.", "it's sunny outside."]
    stories = StoryGraph([
        StoryStep(events=[
            UserUttered("greet_from_stories", {"name": "greet_from_stories"}),
            ActionExecuted("utter_greet_from_stories"),
        ]),
        StoryStep(events=[
            UserUttered("how are you doing?", {"name": "greet_from_stories"}),
            ActionExecuted(additional_actions[0],
                           action_text=additional_actions[0]),
            ActionExecuted(additional_actions[1],
                           action_text=additional_actions[1]),
            ActionExecuted(additional_actions[1],
                           action_text=additional_actions[1]),
        ]),
    ])

    # Patch to return our test stories
    existing.get_stories = asyncio.coroutine(lambda *args: stories)

    importer = E2EImporter(existing)
    domain = await importer.get_domain()

    assert all(action_name in domain.action_names
               for action_name in additional_actions)
Example #3
def _create_data_generator(
    resource_name: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    use_conversation_test_files: bool = False,
) -> "TrainingDataGenerator":
    from rasa.shared.core.generator import TrainingDataGenerator
    from rasa.shared.constants import DEFAULT_DOMAIN_PATH
    from rasa.model import get_model_subdirectories

    core_model = None
    if agent.model_directory:
        core_model, _ = get_model_subdirectories(agent.model_directory)

    if core_model and os.path.exists(
            os.path.join(core_model, DEFAULT_DOMAIN_PATH)):
        domain_path = os.path.join(core_model, DEFAULT_DOMAIN_PATH)
    else:
        domain_path = None

    test_data_importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[resource_name], domain_path=domain_path)
    if use_conversation_test_files:
        story_graph = test_data_importer.get_conversation_tests()
    else:
        story_graph = test_data_importer.get_stories()

    return TrainingDataGenerator(
        story_graph,
        agent.domain,
        use_story_concatenation=False,
        augmentation_factor=0,
        tracker_limit=max_stories,
    )
Example #4
File: test.py Project: zoovu/rasa
def _create_data_generator(
    resource_name: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    use_conversation_test_files: bool = False,
) -> "TrainingDataGenerator":
    from rasa.shared.core.generator import TrainingDataGenerator

    tmp_domain_path = Path(tempfile.mkdtemp()) / "domain.yaml"
    agent.domain.persist(tmp_domain_path)
    test_data_importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[resource_name], domain_path=str(tmp_domain_path)
    )
    if use_conversation_test_files:
        story_graph = test_data_importer.get_conversation_tests()
    else:
        story_graph = test_data_importer.get_stories()

    return TrainingDataGenerator(
        story_graph,
        agent.domain,
        use_story_concatenation=False,
        augmentation_factor=0,
        tracker_limit=max_stories,
    )
Example #5
def test_nlu_warn_of_competition_with_regex_extractor(
    monkeypatch: MonkeyPatch,
    component_types: List[Dict[Text, Text]],
    data_path: Text,
    should_warn: bool,
):
    importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[data_path])
    # there are no domain files for the above examples, so:
    monkeypatch.setattr(Domain, "check_missing_responses",
                        lambda *args, **kwargs: None)

    graph_schema = GraphSchema({
        f"{idx}": SchemaNode({}, component_type, "", "", {})
        for idx, component_type in enumerate(component_types)
    })
    validator = DefaultV1RecipeValidator(graph_schema)
    monkeypatch.setattr(validator, "_warn_if_some_training_data_is_unused",
                        lambda *args, **kwargs: None)

    if should_warn:
        with pytest.warns(
                UserWarning,
                match=(
                    f"You have an overlap between the "
                    f"'{RegexEntityExtractor.__name__}' and the statistical"),
        ):
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
        assert len(records) == 0
Example #6
def default_importer(project: Text) -> TrainingDataImporter:
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = os.path.join(project, DEFAULT_DOMAIN_PATH)
    default_data_path = os.path.join(project, DEFAULT_DATA_PATH)

    return TrainingDataImporter.load_from_dict({}, config_path, domain_path,
                                               [default_data_path])
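A fixture like default_importer above is typically registered with pytest and injected into tests by name; a hypothetical usage sketch (the test and its assertion are illustrative, not from the original file):

def test_default_importer_loads_domain(default_importer: TrainingDataImporter):
    # hypothetical check: the default project domain should define intents
    domain = default_importer.get_domain()
    assert domain.intents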
Example #7
    def inner(
        train_schema: GraphSchema,
        cache: Optional[TrainingCache] = None,
        model_storage: Optional[ModelStorage] = None,
        path: Optional[Path] = None,
        force_retraining: bool = False,
    ) -> Path:
        if not path:
            path = tmp_path_factory.mktemp("model_storage_path")
        if not model_storage:
            model_storage = LocalModelStorage.create(path)
        if not cache:
            cache = local_cache_creator(path)

        graph_trainer = GraphTrainer(
            model_storage=model_storage, cache=cache, graph_runner_class=DaskGraphRunner
        )

        output_filename = path / "model.tar.gz"
        graph_trainer.train(
            GraphModelConfiguration(
                train_schema=train_schema,
                predict_schema=GraphSchema({}),
                language=None,
                core_target=None,
                nlu_target="nlu",
                training_type=TrainingType.BOTH,
            ),
            importer=TrainingDataImporter.load_from_dict(domain_path=str(domain_path)),
            output_filename=output_filename,
            force_retraining=force_retraining,
        )

        assert output_filename.is_file()
        return output_filename
Example #8
def test_read_conversation_tests(project: Text):
    importer = TrainingDataImporter.load_from_dict(training_data_paths=[
        str(Path(project) / DEFAULT_CONVERSATION_TEST_PATH)
    ])

    test_stories = importer.get_conversation_tests()
    assert len(test_stories.story_steps) == 7
Example #9
async def test_rasa_file_importer_with_invalid_domain(tmp_path: Path):
    config_file = tmp_path / "config.yml"
    config_file.write_text("")
    importer = TrainingDataImporter.load_from_dict({}, str(config_file), None, [])

    actual = await importer.get_domain()
    assert actual.as_dict() == Domain.empty().as_dict()
Example #10
async def test_nlu_data_domain_sync_with_retrieval_intents(project: Text):
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = "data/test_domains/default_retrieval_intents.yml"
    data_paths = [
        "data/test_nlu/default_retrieval_intents.md",
        "data/test_responses/default.md",
    ]
    base_data_importer = TrainingDataImporter.load_from_dict(
        {}, config_path, domain_path, data_paths
    )

    nlu_importer = NluDataImporter(base_data_importer)
    core_importer = CoreDataImporter(base_data_importer)

    importer = RetrievalModelsDataImporter(
        CombinedDataImporter([nlu_importer, core_importer])
    )
    domain = await importer.get_domain()
    nlu_data = await importer.get_nlu_data()

    assert domain.retrieval_intents == ["chitchat"]
    assert domain.intent_properties["chitchat"].get("is_retrieval_intent")
    assert domain.retrieval_intent_templates == nlu_data.responses
    assert domain.templates != nlu_data.responses
    assert "utter_chitchat" in domain.action_names
Example #11
def test_importer_fingerprint():
    importer = TrainingDataImporter.load_from_dict(training_data_paths=[
        "./data/test_nlu_no_responses/nlu_with_unicode.yml"
    ])

    fp1 = importer.fingerprint()
    fp2 = importer.fingerprint()
    assert fp1 != fp2
Example #12
async def test_nlu_comparison(
    tmp_path: Path, monkeypatch: MonkeyPatch, nlu_as_json_path: Text
):
    config = {
        "language": "en",
        "pipeline": [
            {"name": "WhitespaceTokenizer"},
            {"name": "KeywordIntentClassifier"},
            {"name": "RegexEntityExtractor"},
        ],
    }
    # the configs need to be at different paths, otherwise the results are
    # combined under the same dictionary key and cannot be plotted properly
    configs = [write_file_config(config).name, write_file_config(config).name]

    # mock training
    monkeypatch.setattr(Interpreter, "load", Mock(spec=RasaNLUInterpreter))
    monkeypatch.setattr(sys.modules["rasa.nlu"], "train", AsyncMock())

    monkeypatch.setattr(
        sys.modules["rasa.nlu.test"],
        "get_eval_data",
        Mock(return_value=(1, None, (None,),)),
    )
    monkeypatch.setattr(
        sys.modules["rasa.nlu.test"],
        "evaluate_intents",
        Mock(return_value={"f1_score": 1}),
    )

    output = str(tmp_path)
    test_data_importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[nlu_as_json_path]
    )
    test_data = test_data_importer.get_nlu_data()
    await compare_nlu_models(
        configs, test_data, output, runs=2, exclusion_percentages=[50, 80]
    )

    assert set(os.listdir(output)) == {
        "run_1",
        "run_2",
        "results.json",
        "nlu_model_comparison_graph.pdf",
    }

    run_1_path = os.path.join(output, "run_1")
    assert set(os.listdir(run_1_path)) == {"50%_exclusion", "80%_exclusion", "test.yml"}

    exclude_50_path = os.path.join(run_1_path, "50%_exclusion")
    modelnames = [os.path.splitext(os.path.basename(config))[0] for config in configs]

    modeloutputs = set(
        ["train"]
        + [f"{m}_report" for m in modelnames]
        + [f"{m}.tar.gz" for m in modelnames]
    )
    assert set(os.listdir(exclude_50_path)) == modeloutputs
Example #13
def test_importer_with_unicode_files():
    importer = TrainingDataImporter.load_from_dict(training_data_paths=[
        "./data/test_nlu_no_responses/nlu_with_unicode.yml"
    ])

    # None of these should raise
    nlu_data = importer.get_nlu_data()
    assert not nlu_data.is_empty()

    importer.get_stories()
    importer.get_domain()
Example #14
def test_nlu_data_domain_sync_responses(project: Text):
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = "data/test_domains/default.yml"
    data_paths = ["data/test_responses/responses_utter_rasa.yml"]

    importer = TrainingDataImporter.load_from_dict({}, config_path,
                                                   domain_path, data_paths)

    with pytest.warns(None):
        domain = importer.get_domain()

    # Responses were synced between "responses_utter_rasa.yml" and "domain.yml"
    assert "utter_rasa" in domain.responses.keys()
Example #15
def test_subintent_response_matches_with_action(project: Text):
    """Tests retrieval intent responses are matched correctly to actions."""
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = "data/test_domains/simple_retrieval_intent.yml"
    data_path = "data/test/simple_retrieval_intent_nlu.yml"
    importer = TrainingDataImporter.load_from_dict({}, config_path,
                                                   domain_path, data_path)

    domain = importer.get_domain()
    # Test that the retrieval intent response is matched correctly to its action,
    # i.e. the utter_chitchat/faq response is compatible with the action utter_chitchat
    with pytest.warns(None) as record:
        domain.check_missing_responses()
    assert not record
Example #16
def test_core_warn_if_data_but_no_policy(monkeypatch: MonkeyPatch,
                                         policy_type: Optional[Type[Policy]]):

    importer = TrainingDataImporter.load_from_dict(
        domain_path="data/test_e2ebot/domain.yml",
        training_data_paths=[
            "data/test_e2ebot/data/nlu.yml",
            "data/test_e2ebot/data/stories.yml",
        ],
    )

    nodes = {
        "tokenizer": SchemaNode({}, WhitespaceTokenizer, "", "", {}),
        "nlu-component": SchemaNode({}, DIETClassifier, "", "", {}),
    }
    if policy_type is not None:
        nodes["some-policy"] = SchemaNode({}, policy_type, "", "", {})
    graph_schema = GraphSchema(nodes)

    validator = DefaultV1RecipeValidator(graph_schema)
    monkeypatch.setattr(
        validator,
        "_raise_if_a_rule_policy_is_incompatible_with_domain",
        lambda *args, **kwargs: None,
    )
    monkeypatch.setattr(validator, "_warn_if_no_rule_policy_is_contained",
                        lambda: None)
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )

    if policy_type is None:
        with pytest.warns(
                UserWarning,
                match="Found data for training policies but no policy"
        ) as records:
            validator.validate(importer)
        assert len(records) == 1
    else:
        with pytest.warns(
                UserWarning,
                match="Slot auto-fill has been removed in 3.0") as records:
            validator.validate(importer)
        assert all([
            warn.message.args[0].startswith("Slot auto-fill has been removed")
            for warn in records.list
        ])
Example #17
def test_load_from_dict(
    config: Dict, expected: List[Type["TrainingDataImporter"]], project: Text
):
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = os.path.join(project, DEFAULT_DOMAIN_PATH)
    default_data_path = os.path.join(project, DEFAULT_DATA_PATH)
    actual = TrainingDataImporter.load_from_dict(
        config, config_path, domain_path, [default_data_path]
    )

    assert isinstance(actual, E2EImporter)
    assert isinstance(actual.importer, ResponsesSyncImporter)

    actual_importers = [i.__class__ for i in actual.importer._importer._importers]
    assert actual_importers == expected
Example #18
def test_spacy_preprocessor_process_training_data(
        spacy_nlp_component: SpacyNLP, spacy_model: SpacyModel):
    training_data = TrainingDataImporter.load_from_dict(training_data_paths=[
        "data/test_e2ebot/data/nlu.yml",
        "data/test_e2ebot/data/stories.yml",
    ]).get_nlu_data()

    spacy_nlp_component.process_training_data(training_data, spacy_model)

    for message in training_data.training_examples:
        for attr in DENSE_FEATURIZABLE_ATTRIBUTES:
            attr_text = message.data.get(attr)
            if attr_text:
                doc = message.data[SPACY_DOCS[attr]]
                assert isinstance(doc, spacy.tokens.doc.Doc)
                assert doc.text == attr_text.lower()
Example #19
def test_response_missing(project: Text):
    """Tests warning when response is missing."""
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = "data/test_domains/missing_chitchat_response.yml"
    data_path = "data/test/simple_retrieval_intent_nlu.yml"
    importer = TrainingDataImporter.load_from_dict({}, config_path,
                                                   domain_path, data_path)

    domain = importer.get_domain()
    with pytest.warns(UserWarning) as record:
        domain.check_missing_responses()

    assert (
        "Action 'utter_chitchat' is listed as a response action in the domain "
        "file, but there is no matching response defined. Please check your "
        "domain.") in record[0].message.args[0]
Example #20
File: test.py Project: tanbui/rasa
async def _create_data_generator(
    resource_name: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    use_e2e: bool = False,
) -> "TrainingDataGenerator":
    from rasa.shared.core.generator import TrainingDataGenerator

    test_data_importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[resource_name])
    story_graph = await test_data_importer.get_stories(use_e2e=use_e2e)
    return TrainingDataGenerator(
        story_graph,
        agent.domain,
        use_story_concatenation=False,
        augmentation_factor=0,
        tracker_limit=max_stories,
    )
Example #21
def test_nlu_data_domain_sync_with_retrieval_intents(project: Text):
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = "data/test_domains/default_retrieval_intents.yml"
    data_paths = [
        "data/test/stories_default_retrieval_intents.yml",
        "data/test_responses/default.yml",
    ]
    importer = TrainingDataImporter.load_from_dict({}, config_path,
                                                   domain_path, data_paths)

    domain = importer.get_domain()
    nlu_data = importer.get_nlu_data()

    assert domain.retrieval_intents == ["chitchat"]
    assert domain.intent_properties["chitchat"].get("is_retrieval_intent")
    assert domain.retrieval_intent_responses == nlu_data.responses
    assert domain.responses != nlu_data.responses
    assert "utter_chitchat" in domain.action_names_or_texts
Example #22
async def test_nlu_data_domain_sync_responses(project: Text):
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = "data/test_domains/default.yml"
    data_paths = ["data/test_nlg/test_responses.yml"]

    base_data_importer = TrainingDataImporter.load_from_dict({}, config_path,
                                                             domain_path,
                                                             data_paths)

    nlu_importer = NluDataImporter(base_data_importer)
    core_importer = CoreDataImporter(base_data_importer)

    importer = ResponsesSyncImporter(
        CombinedDataImporter([nlu_importer, core_importer]))
    with pytest.warns(None):
        domain = await importer.get_domain()

    # Responses were synced between "test_responses.yml" and "domain.yml"
    assert "utter_rasa" in domain.templates.keys()
Example #23
def test_number_of_examples_per_intent_with_yaml(tmp_path: Path):
    domain_path = tmp_path / "domain.yml"
    domain_path.write_text(Domain.empty().as_yaml())

    config_path = tmp_path / "config.yml"
    config_path.touch()

    importer = TrainingDataImporter.load_from_dict(
        {},
        str(config_path),
        str(domain_path),
        [
            "data/test_number_nlu_examples/nlu.yml",
            "data/test_number_nlu_examples/stories.yml",
            "data/test_number_nlu_examples/rules.yml",
        ],
    )

    training_data = importer.get_nlu_data()
    assert training_data.intents == {"greet", "ask_weather"}
    assert training_data.number_of_examples_per_intent["greet"] == 2
    assert training_data.number_of_examples_per_intent["ask_weather"] == 3
Example #24
def test_core_warn_if_rule_data_unused(
    policy_type_not_consuming_rule_data: Type[Policy], ):

    importer = TrainingDataImporter.load_from_dict(
        domain_path="data/test_moodbot/domain.yml",
        training_data_paths=[
            "data/test_moodbot/data/nlu.yml",
            "data/test_moodbot/data/rules.yml",
        ],
    )

    graph_schema = GraphSchema({
        "policy":
        SchemaNode({}, policy_type_not_consuming_rule_data, "", "", {})
    })
    validator = DefaultV1RecipeValidator(graph_schema)

    with pytest.warns(
            UserWarning,
            match=("Found rule-based training data but no policy "
                   "supporting rule-based data."),
    ):
        validator.validate(importer)
Example #25
def test_core_warn_if_rule_data_missing(
        policy_type_consuming_rule_data: Type[Policy]):

    importer = TrainingDataImporter.load_from_dict(
        domain_path="data/test_e2ebot/domain.yml",
        training_data_paths=[
            "data/test_e2ebot/data/nlu.yml",
            "data/test_e2ebot/data/stories.yml",
        ],
    )

    graph_schema = GraphSchema({
        "policy":
        SchemaNode({}, policy_type_consuming_rule_data, "", "", {})
    })
    validator = DefaultV1RecipeValidator(graph_schema)

    with pytest.warns(
            UserWarning,
            match=("Found a rule-based policy in your configuration "
                   "but no rule-based training data."),
    ):
        validator.validate(importer)
Example #26
async def test_import_nlu_training_data_with_default_actions(project: Text):
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = os.path.join(project, DEFAULT_DOMAIN_PATH)
    default_data_path = os.path.join(project, DEFAULT_DATA_PATH)
    importer = TrainingDataImporter.load_from_dict({}, config_path,
                                                   domain_path,
                                                   [default_data_path])

    assert isinstance(importer, E2EImporter)
    importer_without_e2e = importer.importer

    # Check that additional NLU training data from the domain was added
    nlu_data = await importer.get_nlu_data()

    assert len(nlu_data.training_examples) > len(
        (await importer_without_e2e.get_nlu_data()).training_examples)

    extended_training_data = await importer.get_nlu_data()
    assert all(
        Message(data={
            ACTION_NAME: action_name,
            ACTION_TEXT: ""
        }) in extended_training_data.training_examples
        for action_name in rasa.shared.core.constants.DEFAULT_ACTION_NAMES)
Example #27
async def test_import_nlu_training_data_from_e2e_stories(project: Text):
    config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
    domain_path = os.path.join(project, DEFAULT_DOMAIN_PATH)
    default_data_path = os.path.join(project, DEFAULT_DATA_PATH)
    importer = TrainingDataImporter.load_from_dict({}, config_path,
                                                   domain_path,
                                                   [default_data_path])

    # The `E2EImporter` correctly wraps the underlying `CombinedDataImporter`
    assert isinstance(importer, E2EImporter)
    importer_without_e2e = importer.importer

    stories = StoryGraph([
        StoryStep(events=[
            SlotSet("some slot", "doesn't matter"),
            UserUttered("greet_from_stories", {"name": "greet_from_stories"}),
            ActionExecuted("utter_greet_from_stories"),
        ]),
        StoryStep(events=[
            UserUttered("how are you doing?"),
            ActionExecuted("utter_greet_from_stories", action_text="Hi Joey."),
        ]),
    ])

    # Patch to return our test stories
    importer_without_e2e.get_stories = asyncio.coroutine(lambda *args: stories)

    # The wrapping `E2EImporter` simply forwards these method calls
    assert (await importer_without_e2e.get_stories()).as_story_string() == (
        await importer.get_stories()).as_story_string()
    assert (await importer_without_e2e.get_config()) == (await
                                                         importer.get_config())

    # Check that additional NLU training data from the stories was added
    nlu_data = await importer.get_nlu_data()

    # The `E2EImporter` adds NLU training data based on our training stories
    assert len(nlu_data.training_examples) > len(
        (await importer_without_e2e.get_nlu_data()).training_examples)

    # Check if the NLU training data was added correctly from the story training data
    expected_additional_messages = [
        Message(data={
            TEXT: "greet_from_stories",
            INTENT_NAME: "greet_from_stories"
        }),
        Message(data={
            ACTION_NAME: "utter_greet_from_stories",
            ACTION_TEXT: ""
        }),
        Message(data={
            TEXT: "how are you doing?",
            INTENT_NAME: None
        }),
        Message(data={
            ACTION_NAME: "utter_greet_from_stories",
            ACTION_TEXT: "Hi Joey."
        }),
    ]

    assert all(m in nlu_data.training_examples
               for m in expected_additional_messages)
Example #28
def test_loader_loads_graph_runner(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
    tmp_path: Path,
    tmp_path_factory: TempPathFactory,
    domain_path: Path,
):
    graph_trainer = GraphTrainer(
        model_storage=default_model_storage,
        cache=temp_cache,
        graph_runner_class=DaskGraphRunner,
    )

    test_value = "test_value"

    train_schema = GraphSchema(
        {
            "train": SchemaNode(
                needs={},
                uses=PersistableTestComponent,
                fn="train",
                constructor_name="create",
                config={"test_value": test_value},
                is_target=True,
            ),
            "load": SchemaNode(
                needs={"resource": "train"},
                uses=PersistableTestComponent,
                fn="run_inference",
                constructor_name="load",
                config={},
            ),
        }
    )
    predict_schema = GraphSchema(
        {
            "load": SchemaNode(
                needs={},
                uses=PersistableTestComponent,
                fn="run_inference",
                constructor_name="load",
                config={},
                is_target=True,
                resource=Resource("train"),
            )
        }
    )

    output_filename = tmp_path / "model.tar.gz"

    importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[], domain_path=str(domain_path)
    )

    trained_at = datetime.utcnow()
    with freezegun.freeze_time(trained_at):
        model_metadata = graph_trainer.train(
            GraphModelConfiguration(
                train_schema=train_schema,
                predict_schema=predict_schema,
                training_type=TrainingType.BOTH,
                language=None,
                core_target=None,
                nlu_target=None,
            ),
            importer=importer,
            output_filename=output_filename,
        )

    assert isinstance(model_metadata, ModelMetadata)
    assert output_filename.is_file()

    loaded_model_storage_path = tmp_path_factory.mktemp("loaded model storage")

    model_metadata, loaded_predict_graph_runner = loader.load_predict_graph_runner(
        storage_path=loaded_model_storage_path,
        model_archive_path=output_filename,
        model_storage_class=LocalModelStorage,
        graph_runner_class=DaskGraphRunner,
    )

    assert loaded_predict_graph_runner.run() == {"load": test_value}

    assert model_metadata.predict_schema == predict_schema
    assert model_metadata.train_schema == train_schema
    assert model_metadata.model_id
    assert model_metadata.domain.as_dict() == Domain.from_path(domain_path).as_dict()
    assert model_metadata.rasa_open_source_version == rasa.__version__
    assert model_metadata.trained_at == trained_at
Example #29
def test_graph_trainer_returns_model_metadata(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
    tmp_path: Path,
    domain_path: Path,
):
    graph_trainer = GraphTrainer(
        model_storage=default_model_storage,
        cache=temp_cache,
        graph_runner_class=DaskGraphRunner,
    )

    test_value = "test_value"

    train_schema = GraphSchema(
        {
            "train": SchemaNode(
                needs={},
                uses=PersistableTestComponent,
                fn="train",
                constructor_name="create",
                config={"test_value": test_value},
                is_target=True,
            ),
            "load": SchemaNode(
                needs={"resource": "train"},
                uses=PersistableTestComponent,
                fn="run_inference",
                constructor_name="load",
                config={},
            ),
        }
    )
    predict_schema = GraphSchema(
        {
            "load": SchemaNode(
                needs={},
                uses=PersistableTestComponent,
                fn="run_inference",
                constructor_name="load",
                config={},
                is_target=True,
                resource=Resource("train"),
            )
        }
    )

    output_filename = tmp_path / "model.tar.gz"
    model_metadata = graph_trainer.train(
        GraphModelConfiguration(
            train_schema=train_schema,
            predict_schema=predict_schema,
            language=None,
            core_target=None,
            nlu_target="nlu",
            training_type=TrainingType.BOTH,
        ),
        importer=TrainingDataImporter.load_from_dict(domain_path=str(domain_path)),
        output_filename=output_filename,
    )
    assert model_metadata.model_id
    assert model_metadata.domain.as_dict() == Domain.from_path(domain_path).as_dict()
    assert model_metadata.train_schema == train_schema
    assert model_metadata.predict_schema == predict_schema
Example #30
async def run_nlu_test_async(
    config: Optional[Union[Text, List[Text]]],
    data_path: Text,
    models_path: Text,
    output_dir: Text,
    cross_validation: bool,
    percentages: List[int],
    runs: int,
    no_errors: bool,
    all_args: Dict[Text, Any],
) -> None:
    """Runs NLU tests.

    Args:
        all_args: all arguments gathered in a Dict so we can pass it as one argument
                  to other functions.
        config: it refers to the model configuration file. It can be a single file or
                a list of multiple files or a folder with multiple config files inside.
        data_path: path for the nlu data.
        models_path: path to a trained Rasa model.
        output_dir: output path for any files created during the evaluation.
        cross_validation: indicates if it should test the model using cross validation
                          or not.
        percentages: defines the exclusion percentage of the training data.
        runs: number of comparison runs to make.
        no_errors: indicates if incorrect predictions should be written to a file
                   or not.
    """
    from rasa.model_testing import (
        compare_nlu_models,
        perform_nlu_cross_validation,
        test_nlu,
    )

    data_path = rasa.cli.utils.get_validated_path(data_path, "nlu",
                                                  DEFAULT_DATA_PATH)
    test_data_importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[data_path])
    nlu_data = await test_data_importer.get_nlu_data()

    output = output_dir or DEFAULT_RESULTS_PATH
    all_args["errors"] = not no_errors
    rasa.shared.utils.io.create_directory(output)

    if config is not None and len(config) == 1:
        config = os.path.abspath(config[0])
        if os.path.isdir(config):
            config = rasa.shared.utils.io.list_files(config)

    if isinstance(config, list):
        logger.info(
            "Multiple configuration files specified, running nlu comparison mode."
        )

        config_files = []
        for file in config:
            try:
                validation_utils.validate_yaml_schema(
                    rasa.shared.utils.io.read_file(file),
                    CONFIG_SCHEMA_FILE,
                )
                config_files.append(file)
            except YamlException:
                rasa.shared.utils.io.raise_warning(
                    f"Ignoring file '{file}' as it is not a valid config file."
                )
                continue
        await compare_nlu_models(
            configs=config_files,
            test_data=nlu_data,
            output=output,
            runs=runs,
            exclusion_percentages=percentages,
        )
    elif cross_validation:
        logger.info("Test model using cross validation.")
        config = rasa.cli.utils.get_validated_path(config, "config",
                                                   DEFAULT_CONFIG_PATH)
        perform_nlu_cross_validation(config, nlu_data, output, all_args)
    else:
        model_path = rasa.cli.utils.get_validated_path(models_path, "model",
                                                       DEFAULT_MODELS_PATH)

        await test_nlu(model_path, data_path, output, all_args)
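For completeness, a hypothetical top-level invocation of run_nlu_test_async; the paths are placeholders, not from the original project:

import asyncio

asyncio.run(
    run_nlu_test_async(
        config=["config.yml"],
        data_path="data/nlu.yml",
        models_path="models/",
        output_dir="results/",
        cross_validation=False,
        percentages=[50, 80],
        runs=3,
        no_errors=True,      # do not write incorrect predictions to a file
        all_args={},
    )
)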