Example #1
0
def test_core_warn_if_no_rule_policy(monkeypatch: MonkeyPatch,
                                     policy_types: List[Type[Policy]],
                                     should_warn: bool):
    """A schema without a `RulePolicy` must trigger (only) the expected warning."""
    schema_nodes = {
        str(index): SchemaNode({}, policy, "", "", {})
        for index, policy in enumerate(policy_types)
    }
    importer = DummyImporter()
    validator = DefaultV1RecipeValidator(
        graph_schema=GraphSchema(schema_nodes))
    # Silence unrelated checks so only the rule-policy warning can fire.
    monkeypatch.setattr(
        validator,
        "_raise_if_a_rule_policy_is_incompatible_with_domain",
        lambda *args, **kwargs: None,
    )
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )

    if should_warn:
        expected = (f"'{RulePolicy.__name__}' is not "
                    "included in the model's ")
        with pytest.warns(UserWarning, match=expected) as records:
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
        assert len(records) == 0
Example #2
0
def _test_validation_warnings_with_default_configs(
    training_data: TrainingData,
    component_types: List[Type],
    warnings: Optional[List[Text]] = None,
):
    """Validate a schema of `component_types` (default configs) against data.

    Args:
        training_data: data handed to the validator via a `DummyImporter`.
        component_types: one schema node is created per component type.
        warnings: expected warning patterns; falsy means "expect no warnings".
    """
    dummy_importer = DummyImporter(training_data=training_data)
    graph_schema = GraphSchema({
        f"{idx}": SchemaNode(
            needs={},
            uses=component_type,
            constructor_name="",
            fn="",
            config=component_type.get_default_config(),
        )
        for idx, component_type in enumerate(component_types)
    })
    validator = DefaultV1RecipeValidator(graph_schema)
    if not warnings:
        with pytest.warns(None) as records:
            validator.validate(dummy_importer)
            assert len(records) == 0, [
                warning.message for warning in records.list
            ]
    else:
        with pytest.warns(None) as records:
            validator.validate(dummy_importer)
        assert len(records) == len(warnings), ", ".join(warning.message.args[0]
                                                        for warning in records)
        # Fix: this used to be `assert [re.match(...) ...]`, which asserts a
        # non-empty list and hence passes even when every match is `None`.
        # NOTE(review): `re.match(pattern, string)` is called with the raised
        # message as the pattern and the expectation as the string — confirm
        # this ordering is intended.
        assert all(
            re.match(warning.message.args[0], expected_warning)
            for warning, expected_warning in zip(records, warnings)
        )
Example #3
0
def test_core_raise_if_a_rule_policy_is_incompatible_with_domain(
    monkeypatch: MonkeyPatch, ):
    """Every `RulePolicy` node must be checked for domain compatibility."""
    from unittest.mock import call

    domain = Domain.empty()

    num_instances = 2
    nodes = {}
    configs_for_rule_policies = []
    for feature_type in POLICY_CLASSSES:
        for idx in range(num_instances):
            unique_name = f"{feature_type.__name__}-{idx}"
            unique_config = {unique_name: None}
            nodes[unique_name] = SchemaNode({}, feature_type, "", "",
                                            unique_config)
            # Fix: collect the config of *every* RulePolicy instance. This
            # append used to sit outside the inner loop, so only the last
            # instance's config was recorded.
            if feature_type == RulePolicy:
                configs_for_rule_policies.append(unique_config)

    mock = Mock()
    monkeypatch.setattr(RulePolicy, "raise_if_incompatible_with_domain", mock)

    validator = DefaultV1RecipeValidator(graph_schema=GraphSchema(nodes))
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )
    importer = DummyImporter()
    validator.validate(importer)

    # Note: this works because we validate nodes in insertion order.
    # Fix: the original line compared `mock.all_args_list` (a nonexistent
    # Mock attribute, auto-created on access) without an `assert`, so it
    # verified nothing.
    # NOTE(review): assumes the validator invokes the hook with keyword
    # arguments — confirm against `DefaultV1RecipeValidator`.
    assert mock.call_args_list == [
        call(config=config, domain=domain)
        for config in configs_for_rule_policies
    ]
Example #4
0
def test_nlu_raise_if_more_than_one_tokenizer(nodes: Dict[Text, SchemaNode]):
    """A schema containing more than one tokenizer must be rejected."""
    validator = DefaultV1RecipeValidator(GraphSchema(nodes))
    expected_error = ".* more than one tokenizer"
    with pytest.raises(InvalidConfigException, match=expected_error):
        validator.validate(DummyImporter())
Example #5
0
def test_nlu_warn_of_competition_with_regex_extractor(
    monkeypatch: MonkeyPatch,
    component_types: List[Dict[Text, Text]],
    data_path: Text,
    should_warn: bool,
):
    """Overlapping regex and statistical extractors must warn (and only then)."""
    importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[data_path])
    # there are no domain files for the above examples, so:
    monkeypatch.setattr(Domain, "check_missing_responses",
                        lambda *args, **kwargs: None)

    schema_nodes = {}
    for index, component in enumerate(component_types):
        schema_nodes[str(index)] = SchemaNode({}, component, "", "", {})
    validator = DefaultV1RecipeValidator(GraphSchema(schema_nodes))
    # Silence the unrelated unused-training-data check.
    monkeypatch.setattr(validator, "_warn_if_some_training_data_is_unused",
                        lambda *args, **kwargs: None)

    if should_warn:
        expected = (
            f"You have an overlap between the "
            f"'{RegexEntityExtractor.__name__}' and the statistical")
        with pytest.warns(UserWarning, match=expected):
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
        assert len(records) == 0
Example #6
0
def test_nlu_warn_if_entity_synonyms_unused(components: List[GraphComponent],
                                            warns: bool):
    """Synonyms without an `EntitySynonymMapper` in the pipeline must warn."""
    training_data = TrainingData(
        training_examples=[Message({TEXT: "hi"}),
                           Message({TEXT: "hi hi"})],
        entity_synonyms={"cat": "dog"},
    )
    # Sanity check: the fixture data really carries synonyms.
    assert training_data.entity_synonyms is not None
    importer = DummyImporter(training_data=training_data)

    schema_nodes = {
        str(index): SchemaNode({}, component, "", "", {})
        for index, component in enumerate(components)
    }
    validator = DefaultV1RecipeValidator(GraphSchema(schema_nodes))

    if warns:
        expected = (f"You have defined synonyms in your training data, but "
                    f"your NLU configuration does not include an "
                    f"'{EntitySynonymMapper.__name__}'. ")
        with pytest.warns(UserWarning, match=expected):
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
            assert len(records) == 0
Example #7
0
def test_nlu_warn_if_lookup_table_and_crf_extractor_pattern_feature_mismatch(
        nodes: List[SchemaNode], warns: bool):
    """Lookup tables with a pattern-less `CRFEntityExtractor` must warn."""
    training_data = TrainingData(
        training_examples=[Message({TEXT: "hi"}),
                           Message({TEXT: "hi hi"})],
        lookup_tables=[{
            "elements": "this-is-no-file-and-that-does-not-matter"
        }],
    )
    # Sanity check: the fixture data really carries a lookup table.
    assert training_data.lookup_tables is not None
    importer = DummyImporter(training_data=training_data)

    schema_nodes = {str(index): node for index, node in enumerate(nodes)}
    validator = DefaultV1RecipeValidator(GraphSchema(schema_nodes))

    if warns:
        expected = (
            f"You have defined training data consisting of lookup tables, "
            f"but your NLU configuration's "
            f"'{CRFEntityExtractor.__name__}' does not include the "
            f"'{CRFEntityExtractorOptions.PATTERN}' feature")
        with pytest.warns(UserWarning, match=expected):
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
            assert len(records) == 0
Example #8
0
def test_core_raise_if_domain_contains_form_names_but_no_rule_policy_given(
        monkeypatch: MonkeyPatch, policy_types: List[Type[Policy]],
        should_raise: bool):
    """A domain with forms but no `RulePolicy` in the schema must raise."""
    domain_with_form = Domain.from_dict(
        {KEY_FORMS: {
            "some-form": {
                "required_slots": []
            }
        }})
    importer = DummyImporter(domain=domain_with_form)
    # Fix: the comprehension used the constant key "policy", so every policy
    # type except the last was overwritten and dropped from the schema. Use
    # unique per-index keys like the sibling tests.
    graph_schema = GraphSchema({
        f"{idx}": SchemaNode({}, policy_type, "", "", {})
        for idx, policy_type in enumerate(policy_types)
    })
    validator = DefaultV1RecipeValidator(graph_schema)
    # Silence unrelated checks so only the form/rule-policy check can fire.
    monkeypatch.setattr(validator, "_validate_nlu",
                        lambda *args, **kwargs: None)
    monkeypatch.setattr(validator, "_warn_if_no_rule_policy_is_contained",
                        lambda *args, **kwargs: None)
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )
    if should_raise:
        with pytest.raises(
                InvalidDomain,
                match="You have defined a form action, but have not added the",
        ):
            validator.validate(importer)
    else:
        validator.validate(importer)
def test_no_warnings_with_default_project(tmp_path: Path):
    """The bundled starter project must validate with only the known
    'Slot auto-fill' removal warning being raised.

    NOTE(review): a function with the same name appears again later in this
    file but expects *zero* warnings; if both live in the same module the
    later definition shadows this one — confirm the duplication is intended.
    """
    # Work on a copy of the shipped initial project so nothing is mutated.
    rasa.utils.common.copy_directory(Path("rasa/cli/initial_project"),
                                     tmp_path)

    importer = TrainingDataImporter.load_from_config(
        config_path=str(tmp_path / "config.yml"),
        domain_path=str(tmp_path / "domain.yml"),
        training_data_paths=[str(tmp_path / "data")],
    )

    # Auto-configure the default recipe for end-to-end training and build
    # the validator from the resulting train schema.
    config, _missing_keys, _configured_keys = DefaultV1Recipe.auto_configure(
        importer.get_config_file_for_auto_config(),
        importer.get_config(),
        TrainingType.END_TO_END,
    )
    graph_config = DefaultV1Recipe().graph_config_for_recipe(
        config, cli_parameters={}, training_type=TrainingType.END_TO_END)
    validator = DefaultV1RecipeValidator(graph_config.train_schema)

    # At least one slot auto-fill warning is expected ...
    with pytest.warns(
            UserWarning,
            match="Slot auto-fill has been removed in 3.0") as records:
        validator.validate(importer)
    # ... and every recorded warning must be of that kind.
    assert all([
        warn.message.args[0].startswith("Slot auto-fill has been removed")
        for warn in records.list
    ])
Example #10
0
def test_nlu_training_data_validation():
    """An NLU example with an empty intent label must trigger a warning."""
    message_with_empty_intent = Message({TEXT: "some text", INTENT: ""})
    importer = DummyImporter(
        training_data=TrainingData([message_with_empty_intent]))
    nlu_validator = DefaultV1RecipeValidator(GraphSchema({}))

    with pytest.warns(UserWarning, match="Found empty intent"):
        nlu_validator.validate(importer)
Example #11
0
def test_nlu_do_not_raise_if_trainable_tokenizer():
    """A config with a trainable tokenizer must pass validation."""
    config = rasa.shared.utils.io.read_yaml_file(
        "data/test_config/config_pretrained_embeddings_mitie_zh.yml")
    graph_config = DefaultV1Recipe().graph_config_for_recipe(config,
                                                             cli_parameters={})
    validator = DefaultV1RecipeValidator(graph_config.train_schema)

    # Does not raise
    validator.validate(DummyImporter())
Example #12
0
def test_nlu_do_not_raise_if_two_tokenizers_with_end_to_end():
    """Two tokenizers are acceptable when training end-to-end."""
    config = rasa.shared.utils.io.read_yaml_file(
        "rasa/engine/recipes/config_files/default_config.yml")
    graph_config = DefaultV1Recipe().graph_config_for_recipe(
        config, cli_parameters={}, training_type=TrainingType.END_TO_END)
    validator = DefaultV1RecipeValidator(graph_config.train_schema)

    # Does not raise
    validator.validate(DummyImporter())
Example #13
0
def test_core_warn_if_data_but_no_policy(monkeypatch: MonkeyPatch,
                                         policy_type: Optional[Type[Policy]]):
    """Story data without any policy in the schema must produce a warning."""
    importer = TrainingDataImporter.load_from_dict(
        domain_path="data/test_e2ebot/domain.yml",
        training_data_paths=[
            "data/test_e2ebot/data/nlu.yml",
            "data/test_e2ebot/data/stories.yml",
        ],
    )

    nodes = {
        "tokenizer": SchemaNode({}, WhitespaceTokenizer, "", "", {}),
        "nlu-component": SchemaNode({}, DIETClassifier, "", "", {}),
    }
    if policy_type is not None:
        nodes["some-policy"] = SchemaNode({}, policy_type, "", "", {})
    graph_schema = GraphSchema(nodes)

    validator = DefaultV1RecipeValidator(graph_schema)
    # Silence unrelated checks so only the data-but-no-policy warning fires.
    monkeypatch.setattr(
        validator,
        "_raise_if_a_rule_policy_is_incompatible_with_domain",
        lambda *args, **kwargs: None,
    )
    # Fix: this stub was `lambda: None`, which raises TypeError if the hook
    # is ever invoked with arguments; accept anything like the other stubs.
    monkeypatch.setattr(validator, "_warn_if_no_rule_policy_is_contained",
                        lambda *args, **kwargs: None)
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )

    if policy_type is None:
        with pytest.warns(
                UserWarning,
                match="Found data for training policies but no policy"
        ) as records:
            validator.validate(importer)
        assert len(records) == 1
    else:
        with pytest.warns(
                UserWarning,
                match="Slot auto-fill has been removed in 3.0") as records:
            validator.validate(importer)
        # With a policy present, only the slot auto-fill removal warnings
        # may remain.
        assert all([
            warn.message.args[0].startswith("Slot auto-fill has been removed")
            for warn in records.list
        ])
Example #14
0
def test_nlu_warn_of_competing_extractors(
        component_types: List[Type[GraphComponent]], should_warn: bool):
    """Multiple entity extractors in one schema must warn (and only then)."""
    schema_nodes = {
        str(index): SchemaNode({}, extractor_type, "", "", {})
        for index, extractor_type in enumerate(component_types)
    }
    nlu_validator = DefaultV1RecipeValidator(GraphSchema(schema_nodes))
    importer = DummyImporter()
    if should_warn:
        with pytest.warns(UserWarning,
                          match=".*defined multiple entity extractors"):
            nlu_validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            nlu_validator.validate(importer)
        assert len(records) == 0
Example #15
0
def test_nlu_raise_if_featurizers_are_not_compatible(
    component_types_and_configs: List[Tuple[Type[GraphComponent],
                                            Dict[Text, Any], Text]],
    should_raise: bool,
):
    """Incompatible featurizer combinations must be rejected."""
    schema_nodes = {}
    for node_name, component_type, config, fn in component_types_and_configs:
        schema_nodes[f"{node_name}"] = SchemaNode({}, component_type, "", fn,
                                                  config)
    validator = DefaultV1RecipeValidator(GraphSchema(schema_nodes))
    importer = DummyImporter()
    if should_raise:
        with pytest.raises(InvalidConfigException):
            validator.validate(importer)
    else:
        validator.validate(importer)
Example #16
0
def test_core_warn_if_policy_priorities_are_not_unique(
    monkeypatch: MonkeyPatch,
    policy_types: Set[Type[Policy]],
    num_duplicates: int,  # annotation fixed: used as an int count, not a bool
    priority: int,
):
    """Policies sharing the same priority value must trigger a warning.

    Builds a schema where node ``i`` initially has priority ``i``, then
    forces ``num_duplicates`` nodes onto ``priority`` and checks the
    duplicate-priority warning (or its absence).
    """
    # Guard that the parametrization leaves enough nodes to duplicate.
    assert (
        len(policy_types) >= priority + num_duplicates
    ), f"This tests needs at least {priority+num_duplicates} many types."

    # start with a schema where node i has priority i
    # NOTE(review): the first SchemaNode argument is "" here but {} in the
    # sibling tests — confirm which form `SchemaNode.needs` expects.
    nodes = {
        f"{idx}": SchemaNode("", policy_type, "", "", {"priority": idx})
        for idx, policy_type in enumerate(policy_types)
    }

    # give nodes p+1, .., p+num_duplicates-1 priority "priority"
    for idx in range(num_duplicates):
        nodes[f"{priority+idx+1}"].config["priority"] = priority

    validator = DefaultV1RecipeValidator(graph_schema=GraphSchema(nodes))
    # Silence the unrelated rule-data check.
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )

    importer = DummyImporter()

    if num_duplicates > 0:
        # Nodes `priority` .. `priority + num_duplicates` now share the same
        # priority value (the original node plus the forced duplicates).
        duplicates = [
            node.uses for idx_str, node in nodes.items()
            if priority <= int(idx_str) <= priority + num_duplicates
        ]
        expected_message = (
            f"Found policies {_types_to_str(duplicates)} with same priority {priority} "
        )
        # Escape since the type list contains regex metacharacters.
        expected_message = re.escape(expected_message)
        with pytest.warns(UserWarning, match=expected_message):
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
        assert len(records) == 0
Example #17
0
def test_core_raise_if_policy_has_no_priority():
    """Validation must fail for a policy whose config defines no priority."""
    class PolicyWithoutPriority(Policy, GraphComponent):
        # Minimal policy stub: constructor only, no priority in its config.
        def __init__(
            self,
            config: Dict[Text, Any],
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> None:
            super().__init__(config, model_storage, resource,
                             execution_context)

    schema = GraphSchema(
        {"policy": SchemaNode("", PolicyWithoutPriority, "", "", {})})
    validator = DefaultV1RecipeValidator(schema)
    with pytest.raises(InvalidConfigException,
                       match="Every policy must have a priority value"):
        validator.validate(DummyImporter())
Example #18
0
def test_generate_graphs(
    config_path: Text,
    expected_train_schema_path: Text,
    expected_predict_schema_path: Text,
    training_type: TrainingType,
    is_finetuning: bool,
):
    """Recipe must produce the expected train/predict schemas for a config,
    and both the produced config and train schema must pass validation.
    """
    # Load the golden schemas the recipe output is compared against.
    expected_schema_as_dict = rasa.shared.utils.io.read_yaml_file(
        expected_train_schema_path)
    expected_train_schema = GraphSchema.from_dict(expected_schema_as_dict)

    expected_schema_as_dict = rasa.shared.utils.io.read_yaml_file(
        expected_predict_schema_path)
    expected_predict_schema = GraphSchema.from_dict(expected_schema_as_dict)

    config = rasa.shared.utils.io.read_yaml_file(config_path)

    recipe = Recipe.recipe_for_name(DefaultV1Recipe.name)
    model_config = recipe.graph_config_for_recipe(config, {},
                                                  training_type=training_type,
                                                  is_finetuning=is_finetuning)

    # Per-node comparison first so a mismatch pinpoints the offending node
    # before the whole-schema equality check below.
    train_schema = model_config.train_schema
    for node_name, node in expected_train_schema.nodes.items():
        assert train_schema.nodes[node_name] == node

    assert train_schema == expected_train_schema

    default_v1_validator = DefaultV1RecipeValidator(train_schema)
    importer = RasaFileImporter()
    # does not raise
    default_v1_validator.validate(importer)

    predict_schema = model_config.predict_schema
    for node_name, node in expected_predict_schema.nodes.items():
        assert predict_schema.nodes[node_name] == node

    assert predict_schema == expected_predict_schema

    # Finally, the whole model configuration must be structurally valid.
    rasa.engine.validation.validate(model_config)
Example #19
0
def test_no_warnings_with_default_project(tmp_path: Path):
    """The bundled starter project must validate without any warnings.

    NOTE(review): an earlier function in this file has the same name but
    expects a 'Slot auto-fill has been removed' warning; if both live in the
    same module, this later definition shadows the earlier one — verify
    which variant is current.
    """
    # Work on a copy of the shipped initial project so nothing is mutated.
    rasa.utils.common.copy_directory(Path("rasa/cli/initial_project"),
                                     tmp_path)

    importer = TrainingDataImporter.load_from_config(
        config_path=str(tmp_path / "config.yml"),
        domain_path=str(tmp_path / "domain.yml"),
        training_data_paths=[str(tmp_path / "data")],
    )

    # Auto-configure the default recipe for end-to-end training and build
    # the validator from the resulting train schema.
    config, _missing_keys, _configured_keys = DefaultV1Recipe.auto_configure(
        importer.get_config_file_for_auto_config(),
        importer.get_config(),
        TrainingType.END_TO_END,
    )
    graph_config = DefaultV1Recipe().graph_config_for_recipe(
        config, cli_parameters={}, training_type=TrainingType.END_TO_END)
    validator = DefaultV1RecipeValidator(graph_config.train_schema)

    with pytest.warns(None) as records:
        validator.validate(importer)
    assert len(records) == 0
Example #20
0
def test_core_warn_if_rule_data_unused(
    policy_type_not_consuming_rule_data: Type[Policy], ):
    """Rule data combined with a policy that ignores rules must warn."""
    importer = TrainingDataImporter.load_from_dict(
        domain_path="data/test_moodbot/domain.yml",
        training_data_paths=[
            "data/test_moodbot/data/nlu.yml",
            "data/test_moodbot/data/rules.yml",
        ],
    )

    schema = GraphSchema({
        "policy":
        SchemaNode({}, policy_type_not_consuming_rule_data, "", "", {})
    })
    validator = DefaultV1RecipeValidator(schema)

    expected = ("Found rule-based training data but no policy "
                "supporting rule-based data.")
    with pytest.warns(UserWarning, match=expected):
        validator.validate(importer)
Example #21
0
def test_core_warn_if_rule_data_missing(
        policy_type_consuming_rule_data: Type[Policy]):
    """A rule-based policy without any rule training data must warn."""
    importer = TrainingDataImporter.load_from_dict(
        domain_path="data/test_e2ebot/domain.yml",
        training_data_paths=[
            "data/test_e2ebot/data/nlu.yml",
            "data/test_e2ebot/data/stories.yml",
        ],
    )

    schema = GraphSchema({
        "policy":
        SchemaNode({}, policy_type_consuming_rule_data, "", "", {})
    })
    validator = DefaultV1RecipeValidator(schema)

    expected = ("Found a rule-based policy in your configuration "
                "but no rule-based training data.")
    with pytest.warns(UserWarning, match=expected):
        validator.validate(importer)