示例#1
0
def test_validate_validates_required_components(
    test_case: RequiredComponentsTestCase,
    is_train_graph: bool,
    test_subclass: bool,
):
    """Validation fails exactly when component requirements are left unmet.

    A schema is generated from ``test_case`` and placed in either the train or
    the predict slot of the configuration, depending on ``is_train_graph``. The
    other slot keeps its default (an empty train schema, or
    ``DEFAULT_PREDICT_SCHEMA``).

    Args:
        test_case: a single requirements scenario. The body reads
            ``node_needs_requires_tuples``, ``targets`` and
            ``num_unmet_requirements`` from it, so it is one test case, not a
            list.
        is_train_graph: whether the generated schema is the train graph.
        test_subclass: forwarded to the schema builder as ``use_subclass``.
    """
    train_schema = GraphSchema({})
    predict_schema = DEFAULT_PREDICT_SCHEMA
    graph_schema = _create_graph_schema_from_requirements(
        node_needs_requires=test_case.node_needs_requires_tuples,
        targets=test_case.targets,
        use_subclass=test_subclass,
    )

    if is_train_graph:
        train_schema = graph_schema
    else:
        predict_schema = graph_schema

    graph_config = GraphModelConfiguration(
        train_schema=train_schema,
        predict_schema=predict_schema,
        training_type=TrainingType.BOTH,
        language=None,
        core_target=None,
        nlu_target="nlu_target",
    )

    num_unmet = test_case.num_unmet_requirements
    if num_unmet == 0:
        validation.validate(graph_config)
    else:
        message = f"{num_unmet} components are missing"
        with pytest.raises(GraphSchemaValidationException, match=message):
            validation.validate(graph_config)
示例#2
0
def test_validation_with_core_target_wrong_type():
    """Using a node with an unsuitable return type as ``core_target`` fails.

    Node ``"A"`` (a ``TestNLUTarget``) is named as both the NLU and the Core
    target. Validation is expected to reject the Core target's return type.
    """
    predict_schema = GraphSchema(
        {
            "A": SchemaNode(
                needs={},
                uses=TestNLUTarget,
                eager=True,
                constructor_name="create",
                fn="run",
                config={},
            ),
        }
    )

    with pytest.raises(
        GraphSchemaValidationException, match="Core model's .* invalid return type"
    ):
        configuration = GraphModelConfiguration(
            train_schema=GraphSchema({}),
            predict_schema=predict_schema,
            training_type=TrainingType.BOTH,
            language=None,
            core_target="A",
            nlu_target="A",
        )
        validation.validate(configuration)
示例#3
0
def test_too_many_supplied_params():
    """Feeding an input the component does not declare is rejected."""
    extra_inputs = {"some_param": "parent"}
    graph_config = create_test_schema(uses=TestComponentWithRun, needs=extra_inputs)

    expected_error = "does not accept a parameter"
    with pytest.raises(GraphSchemaValidationException, match=expected_error):
        validation.validate(graph_config)
示例#4
0
def test_parent_supplying_wrong_type_to_constructor():
    """A constructor parameter fed a value of the wrong type is rejected.

    ``MyComponent`` is constructed lazily through ``load``, which declares
    ``some_param: TrainingData``. That input is wired to ``"parent"``, built
    from ``MyUnreliableParent``, whose ``run`` is annotated to return
    ``Domain``.
    """

    class MyUnreliableParent(TestComponentWithoutRun):
        def run(self) -> Domain:
            pass

    class MyComponent(TestComponentWithRun):
        @classmethod
        def load(
            cls,
            config: Dict[Text, Any],
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
            some_param: TrainingData,
        ) -> GraphComponent:
            pass

    model_config = create_test_schema(
        uses=MyComponent,
        parent=MyUnreliableParent,
        needs={"some_param": "parent"},
        eager=False,
        constructor_name="load",
    )

    with pytest.raises(
        GraphSchemaValidationException, match="expects an input of type"
    ):
        validation.validate(model_config)
示例#5
0
def test_validation_with_missing_nlu_target():
    """A configuration whose ``nlu_target`` is ``None`` is rejected."""
    target_node = SchemaNode(
        needs={},
        uses=TestNLUTarget,
        eager=True,
        constructor_name="create",
        fn="run",
        config={},
    )
    predict_schema = GraphSchema({"A": target_node})

    with pytest.raises(
        GraphSchemaValidationException, match="no target for the 'nlu_target'"
    ):
        configuration = GraphModelConfiguration(
            train_schema=GraphSchema({}),
            predict_schema=predict_schema,
            training_type=TrainingType.BOTH,
            language=None,
            core_target=None,
            nlu_target=None,
        )
        validation.validate(configuration)
示例#6
0
def test_graph_component_fn_does_not_exist():
    """Naming a run method that the component does not define is rejected.

    Uses the single-argument ``validation.validate(config)`` call found in the
    rest of this module. The train/predict choice is passed to
    ``create_test_schema`` rather than to ``validate``.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """
    graph_config = create_test_schema(
        uses=TestComponentWithRun, run_fn="some_fn", is_train_graph=True
    )

    with pytest.raises(
        GraphSchemaValidationException, match="required method 'some_fn'"
    ):
        validation.validate(graph_config)
示例#7
0
def test_graph_component_fn_does_not_exist():
    """Naming a run method that the component does not define is rejected."""
    missing_fn = "some_fn"
    graph_config = create_test_schema(uses=TestComponentWithRun, run_fn=missing_fn)

    with pytest.raises(
        GraphSchemaValidationException, match="required method 'some_fn'"
    ):
        validation.validate(graph_config)
示例#8
0
def test_run_param_satifisfied_due_to_default():
    """A ``run`` parameter with a default value needs no incoming edge.

    Uses the single-argument ``validation.validate(config)`` call found in the
    rest of this module.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self, some_param: TrainingData = TrainingData()) -> TrainingData:
            pass

    graph_config = create_test_schema(uses=MyComponent, is_train_graph=True)

    validation.validate(graph_config)
示例#9
0
def test_graph_with_fingerprintable_output_subclass():
    """A ``run`` returning ``MyTrainingData`` passes train-graph validation.

    Uses the single-argument ``validation.validate(config)`` call found in the
    rest of this module.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> MyTrainingData:
            pass

    graph_config = create_test_schema(uses=MyComponent, is_train_graph=True)

    validation.validate(graph_config)
示例#10
0
def test_run_param_satifisfied_due_to_default():
    """A ``run`` parameter with a default value needs no incoming edge."""

    class MyComponent(TestComponentWithoutRun):
        def run(self, some_param: TrainingData = TrainingData()) -> TrainingData:
            pass

    # No ``needs`` are supplied; the default must be enough to pass validation.
    validation.validate(create_test_schema(uses=MyComponent))
示例#11
0
def test_graph_with_cls_type_hint():
    """A component derived from ``TestComponentWithClsTypeHints`` validates."""

    class MyComponent(TestComponentWithClsTypeHints):
        def run(self) -> MyTrainingData:
            pass

    validation.validate(create_test_schema(uses=MyComponent))
示例#12
0
def test_predict_graph_output_is_not_fingerprintable():
    """An ``int`` output is accepted when the node is in the predict graph.

    Uses the single-argument ``validation.validate(config)`` call found in the
    rest of this module. The predict-graph choice is passed to
    ``create_test_schema``.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> int:
            pass

    graph_config = create_test_schema(uses=MyComponent, is_train_graph=False)

    validation.validate(graph_config)
示例#13
0
def test_predict_graph_output_is_not_fingerprintable():
    """An ``int`` output is accepted when the node is in the predict graph."""

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> int:
            pass

    predict_only = create_test_schema(is_train_graph=False, uses=MyComponent)
    validation.validate(predict_only)
示例#14
0
def test_graph_with_fingerprintable_output_subclass():
    """A ``run`` returning ``MyTrainingData`` passes validation."""

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> MyTrainingData:
            pass

    validation.validate(create_test_schema(uses=MyComponent))
示例#15
0
def test_run_param_not_satisfied():
    """A required ``run`` parameter with no incoming edge is rejected.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self, some_param: TrainingData) -> TrainingData:
            pass

    unsatisfied = create_test_schema(uses=MyComponent)

    with pytest.raises(
        GraphSchemaValidationException, match="needs the param"
    ):
        validation.validate(unsatisfied)
示例#16
0
def test_graph_component_is_no_graph_component():
    """A plain class that does not implement the component interface is rejected.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent:
        def other(self) -> TrainingData:
            pass

    not_a_component = create_test_schema(uses=MyComponent)

    with pytest.raises(
        GraphSchemaValidationException, match="implement .+ interface"
    ):
        validation.validate(not_a_component)
示例#17
0
def test_graph_output_is_not_fingerprintable_int():
    """In the train graph, a ``run`` annotated to return ``int`` is rejected.

    Uses the single-argument ``validation.validate(config)`` call found in the
    rest of this module. The train-graph choice is passed to
    ``create_test_schema``.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> int:
            pass

    graph_config = create_test_schema(uses=MyComponent, is_train_graph=True)

    with pytest.raises(GraphSchemaValidationException, match="fingerprintable"):
        validation.validate(graph_config)
示例#18
0
def test_graph_output_is_not_fingerprintable_None():
    """A ``run`` annotated to return ``None`` is rejected as not fingerprintable."""

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> None:
            pass

    none_output = create_test_schema(uses=MyComponent)

    with pytest.raises(
        GraphSchemaValidationException, match="fingerprintable"
    ):
        validation.validate(none_output)
示例#19
0
def test_graph_component_is_no_graph_component():
    """A plain class that does not implement the component interface is rejected.

    Uses the single-argument ``validation.validate(config)`` call found in the
    rest of this module.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent:
        def other(self) -> TrainingData:
            pass

    graph_config = create_test_schema(uses=MyComponent, is_train_graph=True)

    with pytest.raises(GraphSchemaValidationException, match="implement .+ interface"):
        validation.validate(graph_config)
示例#20
0
def test_parent_is_missing():
    """An input wired to a node name that is not in the graph is rejected."""
    dangling_needs = {"training_data": "not existing parent"}
    graph_config = create_test_schema(
        uses=TestComponentWithRunAndParam, needs=dangling_needs
    )

    with pytest.raises(
        GraphSchemaValidationException, match="this component is not part"
    ):
        validation.validate(graph_config)
示例#21
0
def test_run_param_not_satisfied():
    """A required ``run`` parameter with no incoming edge is rejected.

    Uses the single-argument ``validation.validate(config)`` call found in the
    rest of this module.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self, some_param: TrainingData) -> TrainingData:
            pass

    graph_config = create_test_schema(uses=MyComponent, is_train_graph=True)

    with pytest.raises(GraphSchemaValidationException, match="needs the param"):
        validation.validate(graph_config)
示例#22
0
def test_too_many_supplied_params_but_kwargs():
    """An undeclared input is accepted when ``run`` takes ``**kwargs``.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self, **kwargs: Any) -> TrainingData:
            pass

    graph_config = create_test_schema(
        uses=MyComponent,
        parent=TestComponentWithRun,
        needs={"some_param": "parent"},
    )
    validation.validate(graph_config)
示例#23
0
def test_run_fn_with_variable_length_positional_param():
    """``*args`` plus a keyword-only parameter fed via ``needs`` validates.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self, *args: Any, some_param: TrainingData) -> TrainingData:
            pass

    graph_config = create_test_schema(
        uses=MyComponent,
        parent=TestComponentWithRun,
        needs={"some_param": "parent"},
    )
    validation.validate(graph_config)
示例#24
0
def test_too_many_supplied_params_but_kwargs():
    """An undeclared input is accepted when ``run`` takes ``**kwargs``.

    Uses the single-argument ``validation.validate(config)`` call found in the
    rest of this module.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self, **kwargs: Any) -> TrainingData:
            pass

    graph_config = create_test_schema(
        uses=MyComponent,
        needs={"some_param": "parent"},
        parent=TestComponentWithRun,
        is_train_graph=True,
    )

    validation.validate(graph_config)
示例#25
0
def test_graph_satisfied_package_requirements(required_packages: List[Text]):
    """Validation passes when the component's required packages are available.

    Presumably parametrized with packages that are installed in the test
    environment — confirm against the parametrize decorator.
    """
    # The static method below resolves this name from the enclosing function
    # scope (class bodies are skipped during lookup). Binding it to a distinct
    # local makes that explicit instead of relying on the shared name.
    packages = required_packages

    class MyComponent(TestComponentWithRun):
        @staticmethod
        def required_packages() -> List[Text]:
            """Any extra python dependencies required for this component to run."""
            return packages

    validation.validate(create_test_schema(uses=MyComponent))
示例#26
0
def test_run_fn_with_variable_length_positional_param():
    """``*args`` plus a keyword-only parameter fed via ``needs`` validates.

    Uses the single-argument ``validation.validate(config)`` call found in the
    rest of this module.

    NOTE(review): this module defines a test with this name more than once;
    pytest only collects the last definition.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self, *args: Any, some_param: TrainingData) -> TrainingData:
            pass

    graph_config = create_test_schema(
        uses=MyComponent,
        needs={"some_param": "parent"},
        parent=TestComponentWithRun,
        is_train_graph=True,
    )

    validation.validate(graph_config)
示例#27
0
def test_graph_output_missing_type_annotation():
    """A ``run`` method without a return annotation is rejected."""

    class MyComponent(TestComponentWithoutRun):
        def run(self):
            pass

    unannotated = create_test_schema(uses=MyComponent)

    expected_error = "does not have a type annotation"
    with pytest.raises(GraphSchemaValidationException, match=expected_error):
        validation.validate(unannotated)
示例#28
0
def test_graph_constructor_missing():
    """Naming a constructor that the component does not define is rejected."""

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> TrainingData:
            pass

    bad_constructor = create_test_schema(
        constructor_name="invalid", uses=MyComponent
    )

    expected_error = "required method 'invalid'"
    with pytest.raises(GraphSchemaValidationException, match=expected_error):
        validation.validate(bad_constructor)
示例#29
0
def test_cycle(is_train_graph: bool):
    """A dependency loop A -> B -> C -> A is rejected.

    The cyclic schema goes into either the train or the predict slot,
    depending on ``is_train_graph``. The other slot keeps its default (an
    empty train schema, or ``DEFAULT_PREDICT_SCHEMA``).
    """

    class MyTestComponent(TestComponentWithoutRun):
        def run(self, training_data: TrainingData) -> TrainingData:
            pass

    def node_needing(parent, **extra):
        # Every node in the loop is identical apart from the node it reads from.
        return SchemaNode(
            needs={"training_data": parent},
            uses=MyTestComponent,
            eager=True,
            constructor_name="create",
            fn="run",
            config={},
            **extra,
        )

    cyclic_schema = GraphSchema(
        {
            "A": node_needing("B", is_target=True),
            "B": node_needing("C"),
            "C": node_needing("A"),
        }
    )

    train_schema = cyclic_schema if is_train_graph else GraphSchema({})
    predict_schema = DEFAULT_PREDICT_SCHEMA if is_train_graph else cyclic_schema

    with pytest.raises(GraphSchemaValidationException, match="Cycles"):
        configuration = GraphModelConfiguration(
            train_schema=train_schema,
            predict_schema=predict_schema,
            training_type=TrainingType.BOTH,
            language=None,
            core_target=None,
            nlu_target="nlu_target",
        )
        validation.validate(configuration)
示例#30
0
def test_graph_missing_package_requirements(required_packages: List[Text]):
    """Validation fails when a required package is reported as not installed.

    Presumably parametrized with package names that are absent from the test
    environment — confirm against the parametrize decorator.
    """
    # The static method below resolves this name from the enclosing function
    # scope (class bodies are skipped during lookup). Binding it to a distinct
    # local makes that explicit instead of relying on the shared name.
    packages = required_packages

    class MyComponent(TestComponentWithRun):
        @staticmethod
        def required_packages() -> List[Text]:
            """Any extra python dependencies required for this component to run."""
            return packages

    graph_config = create_test_schema(uses=MyComponent)

    with pytest.raises(
        GraphSchemaValidationException, match="not installed"
    ):
        validation.validate(graph_config)