Example #1
def test_validation_with_core_target_wrong_type():
    graph_config = GraphSchema(
        {
            "A": SchemaNode(
                needs={},
                uses=TestNLUTarget,
                eager=True,
                constructor_name="create",
                fn="run",
                config={},
            ),
        }
    )

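    # "A" is an NLU component but is also configured as the Core target below;
    # validation rejects it because its return type is invalid for a Core target.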
    with pytest.raises(
            GraphSchemaValidationException,
            match="Core model's .* invalid return type",
    ):
        validation.validate(
            GraphModelConfiguration(
                train_schema=GraphSchema({}),
                predict_schema=graph_config,
                training_type=TrainingType.BOTH,
                language=None,
                core_target="A",
                nlu_target="A",
            ))
Example #2
def test_validation_with_missing_nlu_target():
    graph_config = GraphSchema(
        {
            "A": SchemaNode(
                needs={},
                uses=TestNLUTarget,
                eager=True,
                constructor_name="create",
                fn="run",
                config={},
            )
        }
    )

    with pytest.raises(
        GraphSchemaValidationException, match="no target for the 'nlu_target'"
    ):
        validation.validate(
            GraphModelConfiguration(
                train_schema=GraphSchema({}),
                predict_schema=graph_config,
                training_type=TrainingType.BOTH,
                language=None,
                core_target=None,
                nlu_target=None,
            )
        )
Example #3
def test_minimal_graph_schema():
    def test_schema_node(needs: Dict[Text, Text],
                         target: bool = False) -> SchemaNode:
        return SchemaNode(
            needs=needs,
            uses=None,
            fn="",
            constructor_name="",
            config={},
            is_target=target,
        )

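    # Targets are "1" and "8"; `minimal_graph_schema()` keeps only targets and
    # their ancestors, pruning the unused nodes "2", "5", "6", and "7".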
    assert GraphSchema({
        "1": test_schema_node({"i": "3"}, True),
        "2": test_schema_node({"i": "3"}),
        "3": test_schema_node({"i": "4"}),
        "4": test_schema_node({}),
        "5": test_schema_node({"i": "6"}),
        "6": test_schema_node({}),
        "7": test_schema_node({}),
        "8": test_schema_node({"i": "9"}, True),
        "9": test_schema_node({}),
    }).minimal_graph_schema() == GraphSchema({
        "1": test_schema_node({"i": "3"}, True),
        "3": test_schema_node({"i": "4"}),
        "4": test_schema_node({}),
        "8": test_schema_node({"i": "9"}, True),
        "9": test_schema_node({}),
    })
Example #4
def test_graph_trainer_returns_prediction_runner(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
    tmp_path: Path,
    domain_path: Path,
):
    graph_trainer = GraphTrainer(
        model_storage=default_model_storage,
        cache=temp_cache,
        graph_runner_class=DaskGraphRunner,
    )

    test_value = "test_value"

    train_schema = GraphSchema({
        "train": SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="train",
            constructor_name="create",
            config={"test_value": test_value},
            is_target=True,
        ),
        "load": SchemaNode(
            needs={"resource": "train"},
            uses=PersistableTestComponent,
            fn="run_inference",
            constructor_name="load",
            config={},
        ),
    })
    predict_schema = GraphSchema({
        "load": SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="run_inference",
            constructor_name="load",
            config={},
            is_target=True,
            resource=Resource("train"),
        ),
    })

    output_filename = tmp_path / "model.tar.gz"
    predict_graph_runner = graph_trainer.train(
        train_schema=train_schema,
        predict_schema=predict_schema,
        domain_path=domain_path,
        output_filename=output_filename,
    )
    assert isinstance(predict_graph_runner, DaskGraphRunner)
    assert output_filename.is_file()
    assert predict_graph_runner.run() == {"load": test_value}
Example #5
def test_create_model_package_with_non_existing_dir(
        tmp_path: Path, default_model_storage: ModelStorage):
    path = tmp_path / "some_dir" / "another" / "model.tar.gz"
    default_model_storage.create_model_package(
        path,
        GraphModelConfiguration(GraphSchema({}), GraphSchema({}),
                                TrainingType.BOTH, None, None, "nlu"),
        Domain.empty(),
    )

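    # `create_model_package` is expected to create the missing parent
    # directories ("some_dir/another") along the way.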
    assert path.exists()
Example #6
def test_cycle(is_train_graph: bool):
    class MyTestComponent(TestComponentWithoutRun):
        def run(self, training_data: TrainingData) -> TrainingData:
            pass

    train_schema = GraphSchema({})
    predict_schema = DEFAULT_PREDICT_SCHEMA

    schema = GraphSchema({
        "A": SchemaNode(
            needs={"training_data": "B"},
            uses=MyTestComponent,
            eager=True,
            constructor_name="create",
            fn="run",
            is_target=True,
            config={},
        ),
        "B": SchemaNode(
            needs={"training_data": "C"},
            uses=MyTestComponent,
            eager=True,
            constructor_name="create",
            fn="run",
            config={},
        ),
        "C": SchemaNode(
            needs={"training_data": "A"},
            uses=MyTestComponent,
            eager=True,
            constructor_name="create",
            fn="run",
            config={},
        ),
    })

    if is_train_graph:
        train_schema = schema
    else:
        predict_schema = schema

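    # "A" needs "B", "B" needs "C", and "C" needs "A": a cycle the validator
    # must reject.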
    with pytest.raises(GraphSchemaValidationException, match="Cycles"):
        validation.validate(
            GraphModelConfiguration(
                train_schema=train_schema,
                predict_schema=predict_schema,
                training_type=TrainingType.BOTH,
                language=None,
                core_target=None,
                nlu_target="nlu_target",
            ))
Example #7
def create_test_schema(
    uses: Type,  # The unspecified type is on purpose to enable testing of invalid cases
    constructor_name: Text = "create",
    run_fn: Text = "run",
    needs: Optional[Dict[Text, Text]] = None,
    eager: bool = True,
    parent: Optional[Type[GraphComponent]] = None,
    language: Optional[Text] = None,
    is_train_graph: bool = True,
) -> GraphModelConfiguration:

    parent_node = {}
    if parent:
        parent_node = {
            "parent": SchemaNode(
                needs={}, uses=parent, constructor_name="create", fn="run", config={}
            )
        }

    train_schema = GraphSchema({})
    predict_schema = DEFAULT_PREDICT_SCHEMA
    # noinspection PyTypeChecker
    schema = GraphSchema(
        {
            "my_node": SchemaNode(
                needs=needs or {},
                uses=uses,
                eager=eager,
                constructor_name=constructor_name,
                fn=run_fn,
                config={},
            ),
            **DEFAULT_PREDICT_SCHEMA.nodes,
            **parent_node,
        }
    )

    if is_train_graph:
        train_schema = schema
    else:
        predict_schema = schema

    return GraphModelConfiguration(
        train_schema=train_schema,
        predict_schema=predict_schema,
        training_type=TrainingType.BOTH,
        core_target=None,
        nlu_target="nlu_target",
        language=language,
    )
Example #8
def test_create_package_with_non_existing_parent(tmp_path: Path):
    storage = LocalModelStorage.create(tmp_path)
    model_file = tmp_path / "new" / "sub" / "dir" / "file.tar.gz"

    storage.create_model_package(
        model_file,
        GraphModelConfiguration(GraphSchema({}), GraphSchema({}),
                                TrainingType.BOTH, None, None, "nlu"),
        Domain.empty(),
    )

    assert model_file.is_file()
Example #9
    def graph_config_for_recipe(
        self,
        config: Dict,
        cli_parameters: Dict[Text, Any],
        training_type: TrainingType = TrainingType.BOTH,
        is_finetuning: bool = False,
    ) -> GraphModelConfiguration:
        """Converts the default config to graphs (see interface for full docstring)."""
        self._use_core = (
            bool(config.get("policies")) and not training_type == TrainingType.NLU
        )
        self._use_nlu = (
            bool(config.get("pipeline")) and not training_type == TrainingType.CORE
        )

        if not self._use_nlu and training_type == TrainingType.NLU:
            raise InvalidConfigException(
                "Can't train an NLU model without a specified pipeline. Please make "
                "sure to specify a valid pipeline in your configuration."
            )

        if not self._use_core and training_type == TrainingType.CORE:
            raise InvalidConfigException(
                "Can't train an Core model without policies. Please make "
                "sure to specify a valid policy in your configuration."
            )

        self._use_end_to_end = (
            self._use_nlu
            and self._use_core
            and training_type == TrainingType.END_TO_END
        )

        self._is_finetuning = is_finetuning

        train_nodes, preprocessors = self._create_train_nodes(config, cli_parameters)
        predict_nodes = self._create_predict_nodes(config, preprocessors, train_nodes)

        core_target = "select_prediction" if self._use_core else None

        from rasa.nlu.classifiers.regex_message_handler import RegexMessageHandler

        return GraphModelConfiguration(
            train_schema=GraphSchema(train_nodes),
            predict_schema=GraphSchema(predict_nodes),
            training_type=training_type,
            language=config.get("language"),
            core_target=core_target,
            nlu_target=f"run_{RegexMessageHandler.__name__}",
        )
Example #10
def test_validation_with_core_target_used_by_other_node():
    class CoreTargetConsumer(TestComponentWithoutRun):
        def run(self,
                core_target_output: PolicyPrediction) -> PolicyPrediction:
            pass

    graph_config = GraphSchema(
        {
            "A": SchemaNode(
                needs={},
                uses=TestNLUTarget,
                eager=True,
                constructor_name="create",
                fn="run",
                config={},
            ),
            "B": SchemaNode(
                needs={},
                uses=TestCoreTarget,
                eager=True,
                constructor_name="create",
                fn="run",
                config={},
            ),
            "C": SchemaNode(
                needs={"core_target_output": "B"},
                uses=CoreTargetConsumer,
                eager=True,
                constructor_name="create",
                fn="run",
                config={},
            ),
        }
    )

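    # "C" consumes the Core target's output; validation forbids other nodes
    # from using the Core target "B" as input.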
    with pytest.raises(GraphSchemaValidationException,
                       match="uses the Core target 'B' as input"):
        validation.validate(
            GraphModelConfiguration(
                train_schema=GraphSchema({}),
                predict_schema=graph_config,
                training_type=TrainingType.BOTH,
                language=None,
                core_target="B",
                nlu_target="A",
            ))
Example #11
def test_validate_validates_required_components(
    test_case: List[RequiredComponentsTestCase],
    is_train_graph: bool,
    test_subclass: bool,
):
    train_schema = GraphSchema({})
    predict_schema = DEFAULT_PREDICT_SCHEMA
    graph_schema = _create_graph_schema_from_requirements(
        node_needs_requires=test_case.node_needs_requires_tuples,
        targets=test_case.targets,
        use_subclass=test_subclass,
    )

    if is_train_graph:
        train_schema = graph_schema
    else:
        predict_schema = graph_schema
    graph_config = GraphModelConfiguration(
        train_schema, predict_schema, TrainingType.BOTH, None, None, "nlu_target"
    )

    num_unmet = test_case.num_unmet_requirements
    if num_unmet == 0:
        validation.validate(graph_config)
    else:
        message = f"{num_unmet} components are missing"
        with pytest.raises(GraphSchemaValidationException, match=message):
            validation.validate(graph_config)
Example #12
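    # `inner` is the closure returned by a pytest fixture: `tmp_path_factory`,
    # `local_cache_creator`, and `domain_path` come from the enclosing fixture's scope.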
    def inner(
        train_schema: GraphSchema,
        cache: Optional[TrainingCache] = None,
        model_storage: Optional[ModelStorage] = None,
        path: Optional[Path] = None,
        force_retraining: bool = False,
    ) -> Path:
        if not path:
            path = tmp_path_factory.mktemp("model_storage_path")
        if not model_storage:
            model_storage = LocalModelStorage.create(path)
        if not cache:
            cache = local_cache_creator(path)

        graph_trainer = GraphTrainer(
            model_storage=model_storage, cache=cache, graph_runner_class=DaskGraphRunner
        )

        output_filename = path / "model.tar.gz"
        graph_trainer.train(
            GraphModelConfiguration(
                train_schema=train_schema,
                predict_schema=GraphSchema({}),
                language=None,
                core_target=None,
                nlu_target="nlu",
                training_type=TrainingType.BOTH,
            ),
            importer=TrainingDataImporter.load_from_dict(domain_path=str(domain_path)),
            output_filename=output_filename,
            force_retraining=force_retraining,
        )

        assert output_filename.is_file()
        return output_filename
Example #13
    def inner(
        train_schema: GraphSchema,
        cache: Optional[TrainingCache] = None,
        model_storage: Optional[ModelStorage] = None,
        path: Optional[Path] = None,
    ) -> Path:
        if not path:
            path = tmp_path_factory.mktemp("model_storage_path")
        if not model_storage:
            model_storage = LocalModelStorage.create(path)
        if not cache:
            cache = local_cache_creator(path)

        graph_trainer = GraphTrainer(
            model_storage=model_storage,
            cache=cache,
            graph_runner_class=DaskGraphRunner,
        )

        output_filename = path / "model.tar.gz"
        graph_trainer.train(
            train_schema=train_schema,
            predict_schema=GraphSchema({}),
            domain_path=domain_path,
            output_filename=output_filename,
        )

        assert output_filename.is_file()
        return output_filename
Example #14
def test_core_warn_if_no_rule_policy(monkeypatch: MonkeyPatch,
                                     policy_types: List[Type[Policy]],
                                     should_warn: bool):
    graph_schema = GraphSchema({
        f"{idx}": SchemaNode({}, policy_type, "", "", {})
        for idx, policy_type in enumerate(policy_types)
    })
    importer = DummyImporter()
    validator = DefaultV1RecipeValidator(graph_schema=graph_schema)
    monkeypatch.setattr(
        validator,
        "_raise_if_a_rule_policy_is_incompatible_with_domain",
        lambda *args, **kwargs: None,
    )
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )

    if should_warn:
        with pytest.warns(
                UserWarning,
                match=(f"'{RulePolicy.__name__}' is not "
                       "included in the model's "),
        ) as records:
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
        assert len(records) == 0
Example #15
def test_loading_from_resource_eager(default_model_storage: ModelStorage):
    previous_resource = Resource("previous resource")
    test_value = {"test": "test value"}

    # Pretend resource persisted itself before
    with default_model_storage.write_to(previous_resource) as directory:
        rasa.shared.utils.io.dump_obj_as_json_to_file(directory / "test.json",
                                                      test_value)

    node_name = "some_name"
    node = GraphNode(
        node_name=node_name,
        component_class=PersistableTestComponent,
        constructor_name="load",
        component_config={},
        fn_name="run_inference",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        # The `GraphComponent` should load from this resource
        resource=previous_resource,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    actual_node_name, value = node()

    assert actual_node_name == node_name
    assert value == test_value
Example #16
def test_target_override(eager: bool, default_model_storage: ModelStorage):
    graph_schema = GraphSchema(
        {
            "add": SchemaNode(
                needs={"i1": "first_input", "i2": "second_input"},
                uses=AddInputs,
                fn="add",
                constructor_name="create",
                config={},
                eager=eager,
            ),
            "subtract_2": SchemaNode(
                needs={"i": "add"},
                uses=SubtractByX,
                fn="subtract_x",
                constructor_name="create",
                config={"x": 3},
                eager=eager,
                is_target=True,
            ),
        }
    )

    execution_context = ExecutionContext(graph_schema=graph_schema, model_id="1")

    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=execution_context,
    )
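    # Passing `targets=["add"]` overrides the schema's own target
    # ("subtract_2"), so only the "add" node and its inputs run.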
    results = runner.run(inputs={"first_input": 3, "second_input": 4}, targets=["add"])
    assert results == {"add": 7}
Example #17
def test_nlu_warn_if_entity_synonyms_unused(components: List[GraphComponent],
                                            warns: bool):
    training_data = TrainingData(
        training_examples=[Message({TEXT: "hi"}),
                           Message({TEXT: "hi hi"})],
        entity_synonyms={"cat": "dog"},
    )
    assert training_data.entity_synonyms is not None
    importer = DummyImporter(training_data=training_data)

    graph_schema = GraphSchema({
        f"{idx}": SchemaNode({}, component, "", "", {})
        for idx, component in enumerate(components)
    })
    validator = DefaultV1RecipeValidator(graph_schema)

    if warns:
        match = (f"You have defined synonyms in your training data, but "
                 f"your NLU configuration does not include an "
                 f"'{EntitySynonymMapper.__name__}'. ")

        with pytest.warns(UserWarning, match=match):
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
            assert len(records) == 0
Example #18
def test_unused_node(default_model_storage: ModelStorage):
    graph_schema = GraphSchema({
        "provide": SchemaNode(
            needs={},
            uses=ProvideX,
            fn="provide",
            constructor_name="create",
            config={},
            is_target=True,
        ),
        # This node will not fail as it will be pruned because it is not a target
        # or a target's ancestor.
        "assert_false": SchemaNode(
            needs={"i": "input"},
            uses=AssertComponent,
            fn="run_assert",
            constructor_name="create",
            config={"value_to_assert": "some_value"},
        ),
    })
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=ExecutionContext(graph_schema=graph_schema,
                                           model_id="1"),
    )
    results = runner.run(inputs={"input": "some_other_value"})
    assert results == {"provide": 1}
Example #19
def test_nlu_warn_if_lookup_table_and_crf_extractor_pattern_feature_mismatch(
        nodes: List[SchemaNode], warns: bool):
    training_data = TrainingData(
        training_examples=[Message({TEXT: "hi"}),
                           Message({TEXT: "hi hi"})],
        lookup_tables=[{
            "elements": "this-is-no-file-and-that-does-not-matter"
        }],
    )
    assert training_data.lookup_tables is not None
    importer = DummyImporter(training_data=training_data)

    graph_schema = GraphSchema({f"{idx}": node for idx, node in enumerate(nodes)})
    validator = DefaultV1RecipeValidator(graph_schema)

    if warns:
        match = (
            f"You have defined training data consisting of lookup tables, "
            f"but your NLU configuration's "
            f"'{CRFEntityExtractor.__name__}' does not include the "
            f"'{CRFEntityExtractorOptions.PATTERN}' feature")

        with pytest.warns(UserWarning, match=match):
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
            assert len(records) == 0
Example #20
def test_serialize_graph_schema(tmp_path: Path):
    graph_schema = GraphSchema(
        {
            "train": SchemaNode(
                needs={},
                uses=PersistableTestComponent,
                fn="train",
                constructor_name="create",
                config={"some_config": 123455, "some more config": [{"nested": "hi"}]},
            ),
            "load": SchemaNode(
                needs={"resource": "train"},
                uses=PersistableTestComponent,
                fn="run_inference",
                constructor_name="load",
                config={},
                is_target=True,
                resource=Resource("test resource"),
            ),
        }
    )

    serialized = graph_schema.as_dict()

    # Dump it to make sure it's actually serializable
    file_path = tmp_path / "my_graph.yml"
    rasa.shared.utils.io.write_yaml(serialized, file_path)

    serialized_graph_schema_from_file = rasa.shared.utils.io.read_yaml_file(file_path)
    graph_schema_from_file = GraphSchema.from_dict(serialized_graph_schema_from_file)

    assert graph_schema_from_file == graph_schema
Example #21
def test_core_raise_if_domain_contains_form_names_but_no_rule_policy_given(
        monkeypatch: MonkeyPatch, policy_types: List[Type[Policy]],
        should_raise: bool):
    domain_with_form = Domain.from_dict(
        {KEY_FORMS: {"some-form": {"required_slots": []}}}
    )
    importer = DummyImporter(domain=domain_with_form)
    # Use unique keys so that every policy in `policy_types` is kept (a fixed
    # key would silently retain only the last one).
    graph_schema = GraphSchema({
        f"{idx}": SchemaNode({}, policy_type, "", "", {})
        for idx, policy_type in enumerate(policy_types)
    })
    validator = DefaultV1RecipeValidator(graph_schema)
    monkeypatch.setattr(validator, "_validate_nlu",
                        lambda *args, **kwargs: None)
    monkeypatch.setattr(validator, "_warn_if_no_rule_policy_is_contained",
                        lambda *args, **kwargs: None)
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )
    if should_raise:
        with pytest.raises(
                InvalidDomain,
                match="You have defined a form action, but have not added the",
        ):
            validator.validate(importer)
    else:
        validator.validate(importer)
Example #22
def test_core_raise_if_a_rule_policy_is_incompatible_with_domain(
    monkeypatch: MonkeyPatch,
):

    domain = Domain.empty()

    num_instances = 2
    nodes = {}
    configs_for_rule_policies = []
    for feature_type in POLICY_CLASSSES:
        for idx in range(num_instances):
            unique_name = f"{feature_type.__name__}-{idx}"
            unique_config = {unique_name: None}
            nodes[unique_name] = SchemaNode({}, feature_type, "", "", unique_config)
            # Collect every RulePolicy config (one per instance), not just the
            # last one left over after the loop.
            if feature_type == RulePolicy:
                configs_for_rule_policies.append(unique_config)

    mock = Mock()
    monkeypatch.setattr(RulePolicy, "raise_if_incompatible_with_domain", mock)

    validator = DefaultV1RecipeValidator(graph_schema=GraphSchema(nodes))
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )
    importer = DummyImporter()
    validator.validate(importer)

    # Note: this works because we validate nodes in insertion order.
    # The original snippet compared `mock.all_args_list` without asserting,
    # which checks nothing; asserting on `call_args_list` with keyword-style
    # calls is the intended (assumed) form.
    assert mock.call_args_list == [
        call(config=config, domain=domain)
        for config in configs_for_rule_policies
    ]
Example #23
def create_test_schema(
    uses: Type,  # The unspecified type is on purpose to enable testing of invalid cases
    constructor_name: Text = "create",
    run_fn: Text = "run",
    needs: Optional[Dict[Text, Text]] = None,
    eager: bool = True,
    parent: Optional[Type[GraphComponent]] = None,
) -> GraphSchema:
    parent_node = {}
    if parent:
        parent_node = {
            "parent": SchemaNode(
                needs={}, uses=parent, constructor_name="create", fn="run", config={}
            )
        }
    # noinspection PyTypeChecker
    return GraphSchema(
        {
            "my_node": SchemaNode(
                needs=needs or {},
                uses=uses,
                eager=eager,
                constructor_name=constructor_name,
                fn=run_fn,
                config={},
            ),
            **parent_node,
        }
    )
Example #24
def test_nlu_warn_of_competition_with_regex_extractor(
    monkeypatch: MonkeyPatch,
    component_types: List[Dict[Text, Text]],
    data_path: Text,
    should_warn: bool,
):
    importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[data_path])
    # there are no domain files for the above examples, so:
    monkeypatch.setattr(Domain, "check_missing_responses",
                        lambda *args, **kwargs: None)

    graph_schema = GraphSchema({
        f"{idx}": SchemaNode({}, component_type, "", "", {})
        for idx, component_type in enumerate(component_types)
    })
    validator = DefaultV1RecipeValidator(graph_schema)
    monkeypatch.setattr(validator, "_warn_if_some_training_data_is_unused",
                        lambda *args, **kwargs: None)

    if should_warn:
        with pytest.warns(
                UserWarning,
                match=(
                    f"You have an overlap between the "
                    f"'{RegexEntityExtractor.__name__}' and the statistical"),
        ):
            validator.validate(importer)
    else:
        with pytest.warns(None) as records:
            validator.validate(importer)
        assert len(records) == 0
Example #25
def test_nlu_raise_if_more_than_one_tokenizer(nodes: Dict[Text, SchemaNode]):
    graph_schema = GraphSchema(nodes)
    importer = DummyImporter()
    validator = DefaultV1RecipeValidator(graph_schema)
    with pytest.raises(InvalidConfigException,
                       match=".* more than one tokenizer"):
        validator.validate(importer)
Example #26
def test_fn_exception(default_model_storage: ModelStorage):
    class BadFn(GraphComponent):
        @classmethod
        def create(
            cls,
            config: Dict[Text, Any],
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> BadFn:
            return cls()

        def run(self) -> None:
            raise ValueError("Oh no!")

    node = GraphNode(
        node_name="bad_fn",
        component_class=BadFn,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "some_id"),
    )

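    # `GraphNode` wraps exceptions raised by the component's `run` method in a
    # `GraphComponentException`.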
    with pytest.raises(GraphComponentException):
        node()
Example #27
def test_writing_to_resource_during_training(
        default_model_storage: ModelStorage):
    node_name = "some_name"

    test_value_for_sub_directory = {"test": "test value sub dir"}
    test_value = {"test dir": "test value dir"}

    node = GraphNode(
        node_name=node_name,
        component_class=PersistableTestComponent,
        constructor_name="create",
        component_config={
            "test_value": test_value,
            "test_value_for_sub_directory": test_value_for_sub_directory,
        },
        fn_name="train",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    _, resource = node()

    assert resource == Resource(node_name)

    with default_model_storage.read_from(resource) as directory:
        assert (
            rasa.shared.utils.io.read_json_file(directory / "test.json") == test_value
        )
        assert (
            rasa.shared.utils.io.read_json_file(directory / "sub_dir" / "test.json")
            == test_value_for_sub_directory
        )
Example #28
def _test_validation_warnings_with_default_configs(
    training_data: TrainingData,
    component_types: List[Type],
    warnings: Optional[List[Text]] = None,
):
    dummy_importer = DummyImporter(training_data=training_data)
    graph_schema = GraphSchema({
        f"{idx}": SchemaNode(
            needs={},
            uses=component_type,
            constructor_name="",
            fn="",
            config=component_type.get_default_config(),
        )
        for idx, component_type in enumerate(component_types)
    })
    validator = DefaultV1RecipeValidator(graph_schema)
    if not warnings:
        with pytest.warns(None) as records:
            validator.validate(dummy_importer)
            assert len(records) == 0, [
                warning.message for warning in records.list
            ]
    else:
        with pytest.warns(None) as records:
            validator.validate(dummy_importer)
        assert len(records) == len(warnings), ", ".join(warning.message.args[0]
                                                        for warning in records)
        assert [
            re.match(warning.message.args[0], expected_warning)
            for warning, expected_warning in zip(records, warnings)
        ]
Example #29
def test_invalid_module_error_when_deserializing_schemas(tmp_path: Path):
    graph_schema = GraphSchema(
        {
            "train": SchemaNode(
                needs={},
                uses=PersistableTestComponent,
                fn="train",
                constructor_name="create",
                config={"some_config": 123455, "some more config": [{"nested": "hi"}]},
            )
        }
    )

    serialized = graph_schema.as_dict()

    # Pretend module is for some reason invalid
    serialized["nodes"]["train"]["uses"] = "invalid.class"

    # Dump it to make sure it's actually serializable
    file_path = tmp_path / "my_graph.yml"
    rasa.shared.utils.io.write_yaml(serialized, file_path)

    serialized_graph_schema_from_file = rasa.shared.utils.io.read_yaml_file(file_path)

    with pytest.raises(GraphSchemaException):
        _ = GraphSchema.from_dict(serialized_graph_schema_from_file)
Example #30
def _get_example_schema(num_epochs: int = 5,
                        other_parameter: int = 10) -> GraphSchema:
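    # Builds a three-node schema whose configs differ in whether "epochs" is
    # defined, so callers can vary `num_epochs` and `other_parameter` per test.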
    example_configs = [
        {
            "epochs": num_epochs,
            "other-parameter": other_parameter,
            "some-parameter": "bla",
        },
        {"epochs": num_epochs, "yet-other-parameter": 344},
        {"no-epochs-defined-here": None},
    ]
    return GraphSchema(
        nodes={
            f"node-{idx}": SchemaNode(
                needs={}, uses=GraphComponent, constructor_name="", fn="", config=config
            )
            for idx, config in enumerate(example_configs)
        }
    )