예제 #1
0
def test_validation_with_nlu_target_used_by_other_node():
    """Checks that validation rejects an NLU target which feeds another node.

    Node "B" lists node "A" in its `needs` while "A" is also declared as the
    `nlu_target`; `validation.validate` is expected to raise a
    `GraphSchemaValidationException`.
    """

    class NLUTargetConsumer(TestComponentWithoutRun):
        def run(self, nlu_target_output: List[Message]) -> List[Message]:
            pass

    target_node = SchemaNode(
        needs={},
        uses=TestNLUTarget,
        eager=True,
        constructor_name="create",
        fn="run",
        config={},
    )
    consumer_node = SchemaNode(
        needs={"nlu_target_output": "A"},
        uses=NLUTargetConsumer,
        eager=True,
        constructor_name="create",
        fn="run",
        config={},
    )
    predict_schema = GraphSchema({"A": target_node, "B": consumer_node})

    expected_error = "uses the NLU target 'A' as input"
    with pytest.raises(GraphSchemaValidationException, match=expected_error):
        model_configuration = GraphModelConfiguration(
            train_schema=GraphSchema({}),
            predict_schema=predict_schema,
            training_type=TrainingType.BOTH,
            language=None,
            core_target=None,
            nlu_target="A",
        )
        validation.validate(model_configuration)
예제 #2
0
def test_serialize_graph_schema(tmp_path: Path):
    """Round-trips a `GraphSchema` through `as_dict`, a YAML file and `from_dict`.

    Args:
        tmp_path: Temporary directory in which the YAML file is written.
    """
    train_node = SchemaNode(
        needs={},
        uses=PersistableTestComponent,
        fn="train",
        constructor_name="create",
        config={"some_config": 123455, "some more config": [{"nested": "hi"}]},
    )
    load_node = SchemaNode(
        needs={"resource": "train"},
        uses=PersistableTestComponent,
        fn="run_inference",
        constructor_name="load",
        config={},
        is_target=True,
        resource=Resource("test resource"),
    )
    graph_schema = GraphSchema({"train": train_node, "load": load_node})

    as_dict = graph_schema.as_dict()

    # Writing to disk proves that the dictionary is really serializable.
    yaml_path = tmp_path / "my_graph.yml"
    rasa.shared.utils.io.write_yaml(as_dict, yaml_path)

    loaded_dict = rasa.shared.utils.io.read_yaml_file(yaml_path)
    restored_schema = GraphSchema.from_dict(loaded_dict)

    assert restored_schema == graph_schema
예제 #3
0
def test_num_threads_interpolation():
    """Checks that the `num_threads` CLI parameter ends up in the train schema.

    The expected schemas are read from YAML files; `num_threads` is set to 20
    on the expected `train_*` nodes of the listed classifier/extractor types
    before comparing against the schemas produced by the recipe.
    """
    train_schema_path = (
        "data/graph_schemas/config_pretrained_embeddings_mitie_train_schema.yml"
    )
    predict_schema_path = (
        "data/graph_schemas/config_pretrained_embeddings_mitie_predict_schema.yml"
    )

    expected_train_schema = GraphSchema.from_dict(
        rasa.shared.utils.io.read_yaml_file(train_schema_path)
    )
    expected_predict_schema = GraphSchema.from_dict(
        rasa.shared.utils.io.read_yaml_file(predict_schema_path)
    )

    threaded_component_types = (
        SklearnIntentClassifier,
        MitieEntityExtractor,
        MitieIntentClassifier,
    )
    for name, schema_node in expected_train_schema.nodes.items():
        if not issubclass(schema_node.uses, threaded_component_types):
            continue
        if name.startswith("train_"):
            schema_node.config["num_threads"] = 20

    config = rasa.shared.utils.io.read_yaml_file(
        "data/test_config/config_pretrained_embeddings_mitie.yml"
    )

    recipe = Recipe.recipe_for_name(DefaultV1Recipe.name)
    model_config = recipe.graph_config_for_recipe(config, {"num_threads": 20})

    # Compare node by node first so a mismatch points at the offending node.
    actual_train_schema = model_config.train_schema
    for name, expected_node in expected_train_schema.nodes.items():
        assert actual_train_schema.nodes[name] == expected_node
    assert actual_train_schema == expected_train_schema

    actual_predict_schema = model_config.predict_schema
    for name, expected_node in expected_predict_schema.nodes.items():
        assert actual_predict_schema.nodes[name] == expected_node
    assert actual_predict_schema == expected_predict_schema
예제 #4
0
def test_graph_trainer_train_logging_with_cached_components(
    tmp_path: Path,
    temp_cache: TrainingCache,
    train_with_schema: Callable,
    caplog: LogCaptureFixture,
):
    """Checks the messages logged when a second training run uses the cache.

    Args:
        tmp_path: Temporary directory holding the input file.
        temp_cache: Cache shared by both training runs.
        train_with_schema: Fixture which trains the given schema.
        caplog: Fixture capturing the log records.
    """
    (tmp_path / "input_file.txt").write_text("3")

    input_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
    )
    subtract_node = SchemaNode(
        needs={"i": "input"},
        uses=SubtractByX,
        fn="subtract_x",
        constructor_name="create",
        config={"x": 1},
        is_target=True,
        is_input=False,
    )
    cacheable_node = SchemaNode(
        needs={"suffix": "input"},
        uses=CacheableComponent,
        fn="run",
        constructor_name="create",
        config={},
        is_target=True,
        is_input=False,
    )
    train_schema = GraphSchema(
        {
            "input": input_node,
            "subtract": subtract_node,
            "cache_able_node": cacheable_node,
        }
    )

    # First run fills the cache.
    train_with_schema(train_schema, temp_cache)

    # Second run: only its log output is inspected.
    expected_messages = {
        "Starting to train component 'SubtractByX'.",
        "Finished training component 'SubtractByX'.",
        "Restored component 'CacheableComponent' from cache.",
    }
    with caplog.at_level(logging.INFO, logger="rasa.engine.training.hooks"):
        train_with_schema(train_schema, temp_cache)

        assert set(caplog.messages) == expected_messages
예제 #5
0
def test_training_hook_does_not_cache_cached_component(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """Checks that no output fingerprint is cached for `PrecomputedValueProvider`.

    Args:
        default_model_storage: Model storage handed to the node and the hook.
        temp_cache: Cache which must stay empty for the computed key.
    """
    # The hook needs an execution context to look up the class of the graph
    # component.
    schema = GraphSchema(
        {
            "hello": SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=PrecomputedValueProvider,
            )
        }
    )
    execution_context = ExecutionContext(schema, "1")

    node = GraphNode(
        node_name="hello",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": CacheableText("hi")},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[
            TrainingHook(
                cache=temp_cache,
                model_storage=default_model_storage,
                pruned_schema=execution_context.graph_schema,
            )
        ],
    )

    node(("input_node", "Joe"))

    # Same key as the one the hook generates.
    key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config={"output": CacheableText("hi")},
        inputs={},
    )

    # Output of a `PrecomputedValueProvider` must not have been cached.
    cached_fingerprint = temp_cache.get_cached_output_fingerprint(key)
    assert not cached_fingerprint
예제 #6
0
def test_core_warn_if_data_but_no_policy(
    monkeypatch: MonkeyPatch, policy_type: Optional[Type[Policy]]
):
    """Checks the warning raised when core data exists but no policy is configured.

    Args:
        monkeypatch: Fixture used to disable unrelated validator checks.
        policy_type: Policy class to add to the schema, or `None` for no policy.
    """
    importer = TrainingDataImporter.load_from_dict(
        domain_path="data/test_e2ebot/domain.yml",
        training_data_paths=[
            "data/test_e2ebot/data/nlu.yml",
            "data/test_e2ebot/data/stories.yml",
        ],
    )

    nodes = {
        "tokenizer": SchemaNode({}, WhitespaceTokenizer, "", "", {}),
        "nlu-component": SchemaNode({}, DIETClassifier, "", "", {}),
    }
    if policy_type is not None:
        nodes["some-policy"] = SchemaNode({}, policy_type, "", "", {})

    validator = DefaultV1RecipeValidator(GraphSchema(nodes))

    # Replace the checks that are not under test with no-ops.
    monkeypatch.setattr(
        validator,
        "_raise_if_a_rule_policy_is_incompatible_with_domain",
        lambda *args, **kwargs: None,
    )
    monkeypatch.setattr(
        validator, "_warn_if_no_rule_policy_is_contained", lambda: None
    )
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )

    if policy_type is None:
        no_policy_message = "Found data for training policies but no policy"
        with pytest.warns(UserWarning, match=no_policy_message) as records:
            validator.validate(importer)
        assert len(records) == 1
        return

    auto_fill_message = "Slot auto-fill has been removed in 3.0"
    with pytest.warns(UserWarning, match=auto_fill_message) as records:
        validator.validate(importer)
    only_auto_fill_warnings = [
        warn.message.args[0].startswith("Slot auto-fill has been removed")
        for warn in records.list
    ]
    assert all(only_auto_fill_warnings)
예제 #7
0
def test_validation_with_core_target_wrong_type():
    """Checks that a core target with an unsuitable return type is rejected.

    Node "A" uses `TestNLUTarget` and is declared as both the core and the NLU
    target; `validation.validate` is expected to raise a
    `GraphSchemaValidationException`.
    """
    node = SchemaNode(
        needs={},
        uses=TestNLUTarget,
        eager=True,
        constructor_name="create",
        fn="run",
        config={},
    )
    predict_schema = GraphSchema({"A": node})

    expected_error = "Core model's .* invalid return type"
    with pytest.raises(GraphSchemaValidationException, match=expected_error):
        model_configuration = GraphModelConfiguration(
            train_schema=GraphSchema({}),
            predict_schema=predict_schema,
            training_type=TrainingType.BOTH,
            language=None,
            core_target="A",
            nlu_target="A",
        )
        validation.validate(model_configuration)
예제 #8
0
    def from_dict(cls, serialized: Dict[Text, Any]) -> ModelMetadata:
        """Rebuilds `ModelMetadata` from the output of `metadata.as_dict()`.

        Args:
            serialized: Dictionary form of the `ModelMetadata` (e.g. as read
                from disk).

        Returns:
            The `ModelMetadata` built from the dictionary's values.
        """
        # Imported inside the function, as in the original code.
        from rasa.engine.graph import GraphSchema

        trained_at = datetime.fromisoformat(serialized["trained_at"])
        rasa_version = serialized["rasa_open_source_version"]
        model_id = serialized["model_id"]
        domain = Domain.from_dict(serialized["domain"])
        train_schema = GraphSchema.from_dict(serialized["train_schema"])
        predict_schema = GraphSchema.from_dict(serialized["predict_schema"])
        training_type = TrainingType(serialized["training_type"])
        project_fingerprint = serialized["project_fingerprint"]
        core_target = serialized["core_target"]
        nlu_target = serialized["nlu_target"]
        language = serialized["language"]

        return ModelMetadata(
            trained_at=trained_at,
            rasa_open_source_version=rasa_version,
            model_id=model_id,
            domain=domain,
            train_schema=train_schema,
            predict_schema=predict_schema,
            training_type=training_type,
            project_fingerprint=project_fingerprint,
            core_target=core_target,
            nlu_target=nlu_target,
            language=language,
        )
예제 #9
0
def trained_ted(
    tmp_path_factory: TempPathFactory, moodbot_domain_path: Path,
) -> TEDPolicyGraphComponent:
    """Creates a `TEDPolicyGraphComponent` trained for one epoch on moodbot stories.

    Args:
        tmp_path_factory: Factory providing the directory for the model storage.
        moodbot_domain_path: Path of the domain file to load.

    Returns:
        The policy after `train` has been called on it.
    """
    stories_path = "data/test_moodbot/data/stories.yml"
    domain = Domain.load(moodbot_domain_path)
    trackers = training.load_data(stories_path, domain)

    # Default config with the number of epochs overridden to 1.
    config = {**TEDPolicyGraphComponent.get_default_config(), EPOCHS: 1}
    model_storage = LocalModelStorage.create(tmp_path_factory.mktemp("storage"))
    policy = TEDPolicyGraphComponent.create(
        config, model_storage, Resource("ted"), ExecutionContext(GraphSchema({})),
    )
    policy.train(trackers, domain)

    return policy
예제 #10
0
def test_validation_with_missing_nlu_target():
    """Checks that validation fails when no `nlu_target` is given.

    The configuration uses `TrainingType.BOTH` with both targets set to `None`;
    `validation.validate` is expected to raise a
    `GraphSchemaValidationException`.
    """
    node = SchemaNode(
        needs={},
        uses=TestNLUTarget,
        eager=True,
        constructor_name="create",
        fn="run",
        config={},
    )
    predict_schema = GraphSchema({"A": node})

    expected_error = "no target for the 'nlu_target'"
    with pytest.raises(GraphSchemaValidationException, match=expected_error):
        model_configuration = GraphModelConfiguration(
            train_schema=GraphSchema({}),
            predict_schema=predict_schema,
            training_type=TrainingType.BOTH,
            language=None,
            core_target=None,
            nlu_target=None,
        )
        validation.validate(model_configuration)
예제 #11
0
def test_graph_trainer_train_logging(
    tmp_path: Path,
    temp_cache: TrainingCache,
    train_with_schema: Callable,
    caplog: LogCaptureFixture,
):
    """Checks which training messages are logged for a single training run.

    Args:
        tmp_path: Temporary directory holding the input file.
        temp_cache: Cache passed to the training run.
        train_with_schema: Fixture which trains the given schema.
        caplog: Fixture capturing the log records.
    """
    (tmp_path / "input_file.txt").write_text("3")

    input_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
    )
    # Marked as an input node (`is_input=True`) as well as a target.
    input_target_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
        is_target=True,
        is_input=True,
    )
    subtract_node = SchemaNode(
        needs={"i": "input"},
        uses=SubtractByX,
        fn="subtract_x",
        constructor_name="create",
        config={"x": 1},
        is_target=True,
        is_input=False,
    )
    train_schema = GraphSchema(
        {
            "input": input_node,
            "subtract 2": input_target_node,
            "subtract": subtract_node,
        }
    )

    with caplog.at_level(logging.INFO, logger="rasa.engine.training.hooks"):
        train_with_schema(train_schema, temp_cache)

    expected_messages = [
        "Starting to train component 'SubtractByX'.",
        "Finished training component 'SubtractByX'.",
    ]
    assert caplog.messages == expected_messages
예제 #12
0
def test_config_with_nested_dict_override(default_model_storage: ModelStorage):
    """Checks that overriding one nested config key keeps the other default keys.

    Args:
        default_model_storage: Model storage handed to the graph node.
    """

    class ComponentWithNestedDictConfig(GraphComponent):
        @staticmethod
        def get_default_config() -> Dict[Text, Any]:
            return {"nested-dict": {"key1": "value1", "key2": "value2"}}

        @classmethod
        def create(
            cls,
            config: Dict,
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
            **kwargs: Any,
        ) -> ComponentWithNestedDictConfig:
            return cls()

        def run(self) -> None:
            return None

    override = {"nested-dict": {"key2": "override-value2"}}
    node = GraphNode(
        node_name="nested_dict_config",
        component_class=ComponentWithNestedDictConfig,
        constructor_name="create",
        component_config=override,
        fn_name="run",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    expected_config = {
        "nested-dict": {"key1": "value1", "key2": "override-value2"}
    }

    for key, expected_value in expected_config.items():
        assert key in node._component_config
        if not isinstance(expected_value, dict):
            continue
        # Only nested dictionaries are compared entry by entry.
        for nested_key, nested_value in expected_value.items():
            assert nested_key in node._component_config[key]
            assert node._component_config[key][nested_key] == nested_value
예제 #13
0
def test_can_use_alternate_constructor(default_model_storage: ModelStorage):
    """Checks that a node built via `create_with_2` returns `("provide", 2)`.

    Args:
        default_model_storage: Model storage handed to the graph node.
    """
    execution_context = ExecutionContext(GraphSchema({}), "1")
    provide_node = GraphNode(
        node_name="provide",
        component_class=ProvideX,
        constructor_name="create_with_2",
        component_config={},
        fn_name="provide",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
    )

    assert provide_node() == ("provide", 2)
예제 #14
0
파일: conftest.py 프로젝트: ChenHuaYou/rasa
 def create_component(
     component_class: Type[GraphComponent], config: Dict[Text, Any], idx: int
 ) -> GraphComponent:
     """Creates a component whose node name is `<class name>_<idx>`.

     The given `config` is merged over the component's default config.
     `default_model_storage` comes from the enclosing scope, which is not
     visible here.

     Args:
         component_class: Class of the component to create.
         config: Values overriding the component's default config.
         idx: Index appended to the class name to form the node name.

     Returns:
         The result of `component_class.create`.
     """
     name = f"{component_class.__name__}_{idx}"
     context = ExecutionContext(GraphSchema({}), node_name=name)
     component_resource = Resource(name)
     merged_config = {**component_class.get_default_config(), **config}
     return component_class.create(
         merged_config, default_model_storage, component_resource, context,
     )
예제 #15
0
def test_nlu_warn_of_competing_extractors(
    component_types: List[Type[GraphComponent]], should_warn: bool
):
    """Checks when the validator warns about multiple entity extractors.

    Args:
        component_types: Component classes to put into the graph schema.
        should_warn: Whether a `UserWarning` about multiple entity extractors
            is expected.
    """
    # Local import keeps this block self-contained.
    import warnings

    graph_schema = GraphSchema(
        {
            f"{idx}": SchemaNode({}, component_type, "", "", {})
            for idx, component_type in enumerate(component_types)
        }
    )
    importer = DummyImporter()
    nlu_validator = DefaultV1RecipeValidator(graph_schema)
    if should_warn:
        with pytest.warns(
            UserWarning, match=".*defined multiple entity extractors"
        ):
            nlu_validator.validate(importer)
    else:
        # `pytest.warns(None)` is deprecated since pytest 7 and rejected by
        # pytest 8, so the warnings are recorded with the standard library.
        with warnings.catch_warnings(record=True) as records:
            # Record every warning, including repeated ones.
            warnings.simplefilter("always")
            nlu_validator.validate(importer)
        assert len(records) == 0
예제 #16
0
def test_resources_fingerprints_are_unique_when_cached(
    temp_cache: LocalTrainingCache,
    train_with_schema: Callable,
):
    """Trains three times with changing configs to exercise cached `Resource`s.

    Args:
        temp_cache: Cache shared by all training runs.
        train_with_schema: Fixture which trains the given schema.
    """
    train_node = SchemaNode(
        needs={},
        uses=PersistableTestComponent,
        fn="train",
        constructor_name="create",
        config={"test_value": "4"},
        is_target=True,
    )
    process_node = SchemaNode(
        needs={"resource": "train"},
        uses=PersistableTestComponent,
        fn="run_inference",
        constructor_name="load",
        config={},
    )
    assert_node = SchemaNode(
        needs={"i": "process"},
        uses=AssertComponent,
        fn="run_assert",
        constructor_name="create",
        config={"value_to_assert": "4"},
        is_target=True,
    )
    train_schema = GraphSchema(
        {"train": train_node, "process": process_node, "assert_node": assert_node}
    )

    # Run 1: fills the cache with the value "4".
    train_with_schema(train_schema, temp_cache)

    # Run 2: both the trained value and the asserted value become "5".
    train_schema.nodes["train"].config["test_value"] = "5"
    train_schema.nodes["assert_node"].config["value_to_assert"] = "5"
    train_with_schema(train_schema, temp_cache)

    # Run 3: an extra config entry makes only "assert_node" run again.
    train_schema.nodes["assert_node"].config["something"] = "something"
    # If `Resource`s were fingerprinted by node name this would break: the
    # cache would hand back the `Resource` of the first run (value 4) although
    # the second one (value 5) is needed, and "assert_node" now expects 5.
    train_with_schema(train_schema, temp_cache)
예제 #17
0
def test_nlu_raise_if_featurizers_are_not_compatible(
    component_types_and_configs: List[
        Tuple[Text, Type[GraphComponent], Dict[Text, Any], Text]
    ],
    should_raise: bool,
):
    """Checks when the validator rejects a combination of components.

    Args:
        component_types_and_configs: One `(node_name, component_type, config,
            fn)` tuple per schema node. The annotation now lists all four
            elements that are unpacked below.
        should_raise: Whether an `InvalidConfigException` is expected.
    """
    graph_schema = GraphSchema(
        {
            f"{node_name}": SchemaNode({}, component_type, "", fn, config)
            for (node_name, component_type, config, fn) in (
                component_types_and_configs
            )
        }
    )
    importer = DummyImporter()
    validator = DefaultV1RecipeValidator(graph_schema)
    if should_raise:
        with pytest.raises(InvalidConfigException):
            validator.validate(importer)
    else:
        validator.validate(importer)
예제 #18
0
def test_epoch_fraction_cli_param():
    """Checks the train schema built with `finetuning_epoch_fraction` set to 0.5.

    The schema produced by the recipe in finetuning mode is compared against
    the expected schema read from a YAML file.
    """
    expected_schema_path = (
        "data/graph_schemas/default_config_finetune_epoch_fraction_schema.yml"
    )
    expected_train_schema = GraphSchema.from_dict(
        rasa.shared.utils.io.read_yaml_file(expected_schema_path)
    )

    config = rasa.shared.utils.io.read_yaml_file(
        "rasa/shared/importers/default_config.yml"
    )

    recipe = Recipe.recipe_for_name(DefaultV1Recipe.name)
    cli_parameters = {"finetuning_epoch_fraction": 0.5}
    model_config = recipe.graph_config_for_recipe(
        config, cli_parameters, is_finetuning=True
    )

    # Compare node by node first so a mismatch points at the offending node.
    actual_train_schema = model_config.train_schema
    for name, expected_node in expected_train_schema.nodes.items():
        assert actual_train_schema.nodes[name] == expected_node

    assert actual_train_schema == expected_train_schema
예제 #19
0
def test_graph_trainer_with_non_cacheable_components(
    temp_cache: TrainingCache,
    tmp_path: Path,
    train_with_schema: Callable,
    spy_on_all_components: Callable,
):
    """Checks that every node runs on both of two identical training runs.

    Args:
        temp_cache: Cache shared by both training runs.
        tmp_path: Temporary directory holding the input file.
        train_with_schema: Fixture which trains the given schema.
        spy_on_all_components: Fixture returning mocks for the schema's nodes.
    """
    (tmp_path / "input_file.txt").write_text("3")

    input_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
    )
    subtract_node = SchemaNode(
        needs={"i": "input"},
        uses=SubtractByX,
        fn="subtract_x",
        constructor_name="create",
        config={"x": 1},
        is_target=True,
    )
    train_schema = GraphSchema({"input": input_node, "subtract": subtract_node})

    # First run: every component is called.
    first_run_mocks = spy_on_all_components(train_schema)
    train_with_schema(train_schema, temp_cache)
    assert node_call_counts(first_run_mocks) == {"input": 1, "subtract": 1}

    # Second run with an unchanged schema: none of the components can cache,
    # so all of them have to run again.
    second_run_mocks = spy_on_all_components(train_schema)
    train_with_schema(train_schema, temp_cache)
    assert node_call_counts(second_run_mocks) == {"input": 1, "subtract": 1}
예제 #20
0
def test_core_warn_if_policy_priorities_are_not_unique(
    monkeypatch: MonkeyPatch,
    policy_types: Set[Type[Policy]],
    num_duplicates: int,
    priority: int,
):
    """Checks when the validator warns about policies sharing a priority.

    Args:
        monkeypatch: Fixture used to disable an unrelated validator check.
        policy_types: Policy classes, one schema node is created per class.
        num_duplicates: Number of additional nodes which get the same priority
            as node `priority`. It is used as an integer below, hence the
            `int` annotation.
        priority: Priority value which is duplicated.
    """
    # Local import keeps this block self-contained.
    import warnings

    # NOTE(review): the loop below accesses node `priority + num_duplicates`,
    # so for `num_duplicates > 0` this guard looks one short - confirm against
    # the parametrization before tightening it.
    assert (
        len(policy_types) >= priority + num_duplicates
    ), f"This tests needs at least {priority+num_duplicates} many types."

    # start with a schema where node i has priority i
    nodes = {
        f"{idx}": SchemaNode("", policy_type, "", "", {"priority": idx})
        for idx, policy_type in enumerate(policy_types)
    }

    # give nodes p+1, .., p+num_duplicates priority "priority"
    for idx in range(num_duplicates):
        nodes[f"{priority+idx+1}"].config["priority"] = priority

    validator = DefaultV1RecipeValidator(graph_schema=GraphSchema(nodes))
    monkeypatch.setattr(
        validator,
        "_warn_if_rule_based_data_is_unused_or_missing",
        lambda *args, **kwargs: None,
    )

    importer = DummyImporter()

    if num_duplicates > 0:
        duplicates = [
            node.uses
            for idx_str, node in nodes.items()
            if priority <= int(idx_str) <= priority + num_duplicates
        ]
        expected_message = (
            f"Found policies {_types_to_str(duplicates)} with same priority {priority} "
        )
        expected_message = re.escape(expected_message)
        with pytest.warns(UserWarning, match=expected_message):
            validator.validate(importer)
    else:
        # `pytest.warns(None)` is deprecated since pytest 7 and rejected by
        # pytest 8, so the warnings are recorded with the standard library.
        with warnings.catch_warnings(record=True) as records:
            # Record every warning, including repeated ones.
            warnings.simplefilter("always")
            validator.validate(importer)
        assert len(records) == 0
예제 #21
0
def test_component_config(
    x: Optional[int], output: int, default_model_storage: ModelStorage
):
    """Checks the result of `SubtractByX` for different values of `x`.

    Args:
        x: Value put into the component config under "x"; falsy values lead to
            an empty config.
        output: Expected result for the input value 5.
        default_model_storage: Model storage handed to the graph node.
    """
    # NOTE(review): the truthiness check also maps `x == 0` to an empty config,
    # not only `x is None` - confirm this is intended.
    if x:
        component_config = {"x": x}
    else:
        component_config = {}

    subtract_node = GraphNode(
        node_name="subtract",
        component_class=SubtractByX,
        constructor_name="create",
        component_config=component_config,
        fn_name="subtract_x",
        inputs={"i": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    assert subtract_node(("input_node", 5)) == ("subtract", output)
예제 #22
0
def test_core_raise_if_policy_has_no_priority():
    """Checks that a policy node without a priority in its config is rejected.

    The schema contains a single policy node with an empty config;
    `validate` is expected to raise an `InvalidConfigException`.
    """

    class PolicyWithoutPriority(Policy, GraphComponent):
        def __init__(
            self,
            config: Dict[Text, Any],
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> None:
            super().__init__(config, model_storage, resource, execution_context)

    policy_node = SchemaNode("", PolicyWithoutPriority, "", "", {})
    validator = DefaultV1RecipeValidator(GraphSchema({"policy": policy_node}))
    importer = DummyImporter()

    expected_error = "Every policy must have a priority value"
    with pytest.raises(InvalidConfigException, match=expected_error):
        validator.validate(importer)
예제 #23
0
def test_eager_and_not_eager(eager: bool, default_model_storage: ModelStorage):
    """Checks when `create` is called depending on the node's `eager` flag.

    Args:
        eager: Value passed as the node's `eager` argument.
        default_model_storage: Model storage handed to the graph node.
    """
    create_spy = Mock()
    run_spy = Mock()

    class SpyComponent(GraphComponent):
        @classmethod
        def create(
            cls,
            config: Dict,
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> SpyComponent:
            create_spy()
            return cls()

        def run(self):
            return run_spy()

    node = GraphNode(
        node_name="spy_node",
        component_class=SpyComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={},
        eager=eager,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    # After construction `create` must have been called only for eager nodes,
    # and `run` must not have been called at all.
    assert bool(create_spy.called) == bool(eager)
    assert not run_spy.called

    node()

    # Calling the node runs it; `create` has been called exactly once overall.
    assert create_spy.call_count == 1
    assert run_spy.called
예제 #24
0
def test_no_target(default_model_storage: ModelStorage):
    """Checks that running a schema without target nodes yields no results.

    Args:
        default_model_storage: Model storage handed to the runner.
    """
    provide_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
    )
    graph_schema = GraphSchema({"provide": provide_node})
    execution_context = ExecutionContext(graph_schema=graph_schema, model_id="1")
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=execution_context,
    )

    assert not runner.run()
예제 #25
0
def test_can_use_alternate_constructor(default_model_storage: ModelStorage):
    """Checks that a target node built via `create_with_2` yields 2.

    Args:
        default_model_storage: Model storage handed to the runner.
    """
    provide_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create_with_2",
        config={},
        is_target=True,
    )
    graph_schema = GraphSchema({"provide": provide_node})
    execution_context = ExecutionContext(graph_schema=graph_schema, model_id="1")
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=execution_context,
    )

    outputs = runner.run()
    assert outputs["provide"] == 2
예제 #26
0
def test_calling_component(default_model_storage: ModelStorage):
    """Checks that an `AddInputs` node returns its name and the result 7 for 3 and 4.

    Args:
        default_model_storage: Model storage handed to the graph node.
    """
    input_mapping = {"i1": "input_node1", "i2": "input_node2"}
    add_node = GraphNode(
        node_name="add_node",
        component_class=AddInputs,
        constructor_name="create",
        component_config={},
        fn_name="add",
        inputs=input_mapping,
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    # Each positional argument is a `(node name, value)` tuple.
    output = add_node(("input_node1", 3), ("input_node2", 4))

    assert output == ("add_node", 7)
예제 #27
0
def test_execution_context(default_model_storage: ModelStorage):
    """Checks that the node's context ignores later changes to the original.

    The `model_id` of the passed context is changed after the node was created;
    the context returned by the node must still carry the initial id and the
    node's name.

    Args:
        default_model_storage: Model storage handed to the graph node.
    """
    original_context = ExecutionContext(GraphSchema({}), "some_id")
    node = GraphNode(
        node_name="execution_context_aware",
        component_class=ExecutionContextAware,
        constructor_name="create",
        component_config={},
        fn_name="get_execution_context",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=original_context,
    )

    # Mutating the original afterwards must not leak into the node.
    original_context.model_id = "a_new_id"

    _, node_context = node()[0], node()[1]
    assert node_context.model_id == "some_id"
    assert node_context.node_name == "execution_context_aware"
예제 #28
0
def test_input_value_is_node_name(default_model_storage: ModelStorage):
    """Checks that an input value equal to a node name raises `GraphRunError`.

    Args:
        default_model_storage: Model storage handed to the runner.
    """
    provide_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
        is_target=True,
    )
    graph_schema = GraphSchema({"provide": provide_node})
    execution_context = ExecutionContext(graph_schema=graph_schema, model_id="1")
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=execution_context,
    )

    # The input value "provide" is also the name of a node in the schema.
    with pytest.raises(GraphRunError):
        runner.run(inputs={"input": "provide"})
예제 #29
0
def _create_graph_schema_from_requirements(
    node_needs_requires: List[Tuple[int, List[int], List[int]]],
    targets: List[int],
    use_subclass: bool,
) -> GraphSchema:
    """Builds a graph schema from `(node, needs, required components)` triples.

    Args:
        node_needs_requires: One triple per node: the node id, the ids of the
            nodes it needs and the ids of the nodes whose component types it
            requires.
        targets: Ids of the nodes which are marked as targets.
        use_subclass: Index into the types created per node
            (`False` -> element 0, `True` -> element 1).

    Returns:
        The schema, extended by the nodes of `DEFAULT_PREDICT_SCHEMA`.
    """
    # One group of component types per node.
    component_types = {}
    for node, needs, _ in node_needs_requires:
        component_types[node] = _create_component_type_and_subtype_with_run_function(
            component_type_name=f"class_{node}", needs=needs
        )

    # Every type of a node reports the first type of each required node.
    for node, _, required_components in node_needs_requires:
        for component_type in component_types[node]:
            required_types = [
                component_types[required][0] for required in required_components
            ]
            component_type.required_components = Mock(return_value=required_types)

    # One schema node per triple.
    schema_nodes = {}
    for node, needs, _ in node_needs_requires:
        node_needs = {
            f"param{param}": f"node-{needed_node}"
            for param, needed_node in enumerate(needs)
        }
        schema_nodes[f"node-{node}"] = SchemaNode(
            needs=node_needs,
            uses=component_types[node][use_subclass],  # bool used as index
            fn="run",
            constructor_name="create",
            config={},
            is_target=node in targets,
        )

    graph_schema = GraphSchema(schema_nodes)
    graph_schema.nodes.update(DEFAULT_PREDICT_SCHEMA.nodes)
    return graph_schema
예제 #30
0
파일: test_dask.py 프로젝트: zoovu/rasa
def test_non_eager_can_use_inputs_for_constructor(default_model_storage: ModelStorage):
    """Checks that a non-eager node can receive a run input under `needs`.

    The node needs "x" from "input"; running with `{"input": 5}` must yield 5
    for the target node.

    Args:
        default_model_storage: Model storage handed to the runner.
    """
    provide_node = SchemaNode(
        needs={"x": "input"},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
        eager=False,
        is_target=True,
    )
    graph_schema = GraphSchema({"provide": provide_node})
    execution_context = ExecutionContext(graph_schema=graph_schema, model_id="1")
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=execution_context,
    )

    outputs = runner.run(inputs={"input": 5})
    assert outputs["provide"] == 5