def test_validate_validates_required_components(
    test_case: RequiredComponentsTestCase,
    is_train_graph: bool,
    test_subclass: bool,
):
    """Check validation against the test case's expected unmet requirements.

    A schema is built from the test case's requirement tuples and targets and
    is used as the train schema (``is_train_graph=True``) or the predict schema
    (otherwise); the other schema keeps its default value.

    Args:
        test_case: A single test case; its ``node_needs_requires_tuples``,
            ``targets`` and ``num_unmet_requirements`` attributes are read.
        is_train_graph: Whether the built schema replaces the train schema.
        test_subclass: Forwarded as ``use_subclass`` to the schema builder.
    """
    graph_schema = _create_graph_schema_from_requirements(
        node_needs_requires=test_case.node_needs_requires_tuples,
        targets=test_case.targets,
        use_subclass=test_subclass,
    )

    # Only one of the two schemas is swapped for the schema under test.
    train_schema = GraphSchema({})
    predict_schema = DEFAULT_PREDICT_SCHEMA
    if is_train_graph:
        train_schema = graph_schema
    else:
        predict_schema = graph_schema

    # Keyword arguments keep this call consistent with the other tests in this
    # file and make the meaning of the two ``None`` values explicit.
    graph_config = GraphModelConfiguration(
        train_schema=train_schema,
        predict_schema=predict_schema,
        training_type=TrainingType.BOTH,
        language=None,
        core_target=None,
        nlu_target="nlu_target",
    )

    num_unmet = test_case.num_unmet_requirements
    if num_unmet == 0:
        # Nothing is missing, so validation has to succeed.
        validation.validate(graph_config)
        return

    message = f"{num_unmet} components are missing"
    with pytest.raises(GraphSchemaValidationException, match=message):
        validation.validate(graph_config)
def test_validation_with_core_target_wrong_type():
    """Expect an "invalid return type" error when node "A" is the core target.

    Node "A" uses ``TestNLUTarget`` and is configured as both the core and the
    NLU target of the predict schema.
    """
    target_node = SchemaNode(
        needs={},
        uses=TestNLUTarget,
        eager=True,
        constructor_name="create",
        fn="run",
        config={},
    )
    predict_schema = GraphSchema({"A": target_node})

    expected_failure = pytest.raises(
        GraphSchemaValidationException,
        match="Core model's .* invalid return type",
    )
    with expected_failure:
        validation.validate(
            GraphModelConfiguration(
                train_schema=GraphSchema({}),
                predict_schema=predict_schema,
                training_type=TrainingType.BOTH,
                language=None,
                core_target="A",
                nlu_target="A",
            )
        )
def test_too_many_supplied_params():
    """Expect a "does not accept a parameter" error.

    The schema's ``needs`` supplies ``some_param`` to ``TestComponentWithRun``.
    """
    supplied_inputs = {"some_param": "parent"}
    schema = create_test_schema(uses=TestComponentWithRun, needs=supplied_inputs)

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="does not accept a parameter"
    )
    with expected_failure:
        validation.validate(schema)
def test_parent_supplying_wrong_type_to_constructor():
    """Expect an "expects an input of type" error for a mismatched constructor input.

    The parent's ``run`` is annotated to return ``Domain`` while the component's
    ``load`` constructor annotates ``some_param`` as ``TrainingData``.
    """

    class MyUnreliableParent(TestComponentWithoutRun):
        def run(self) -> Domain:
            ...

    class MyComponent(TestComponentWithRun):
        @classmethod
        def load(
            cls,
            config: Dict[Text, Any],
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
            some_param: TrainingData,
        ) -> GraphComponent:
            ...

    schema = create_test_schema(
        uses=MyComponent,
        eager=False,
        constructor_name="load",
        parent=MyUnreliableParent,
        needs={"some_param": "parent"},
    )

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="expects an input of type"
    )
    with expected_failure:
        validation.validate(schema)
def test_validation_with_missing_nlu_target():
    """Expect a "no target for the 'nlu_target'" error when ``nlu_target`` is None."""
    only_node = SchemaNode(
        needs={},
        uses=TestNLUTarget,
        eager=True,
        constructor_name="create",
        fn="run",
        config={},
    )
    predict_schema = GraphSchema({"A": only_node})

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="no target for the 'nlu_target'"
    )
    with expected_failure:
        validation.validate(
            GraphModelConfiguration(
                train_schema=GraphSchema({}),
                predict_schema=predict_schema,
                training_type=TrainingType.BOTH,
                language=None,
                core_target=None,
                nlu_target=None,
            )
        )
def test_graph_component_fn_does_not_exist():
    """Expect a "specified method 'some_fn'" error for the run function ``some_fn``."""
    # NOTE(review): a test with this same name is defined again later in this
    # file; if both live at module level the later one shadows this one.
    graph_schema = create_test_schema(uses=TestComponentWithRun, run_fn="some_fn")

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="specified method 'some_fn'"
    )
    with expected_failure:
        validation.validate(graph_schema, language=None, is_train_graph=True)
def test_graph_component_fn_does_not_exist():
    """Expect a "required method 'some_fn'" error for the run function ``some_fn``."""
    schema = create_test_schema(uses=TestComponentWithRun, run_fn="some_fn")

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="required method 'some_fn'"
    )
    with expected_failure:
        validation.validate(schema)
def test_run_param_satifisfied_due_to_default():
    """Expect validation to accept a ``run`` whose only parameter has a default."""
    # NOTE(review): a test with this same name is defined again later in this
    # file; if both live at module level the later one shadows this one.

    class MyComponent(TestComponentWithoutRun):
        def run(self, some_param: TrainingData = TrainingData()) -> TrainingData:
            ...

    graph_schema = create_test_schema(uses=MyComponent)
    validation.validate(graph_schema, language=None, is_train_graph=True)
def test_graph_with_fingerprintable_output_subclass():
    """Expect validation to accept a ``run`` annotated to return ``MyTrainingData``."""
    # NOTE(review): a test with this same name is defined again later in this
    # file; if both live at module level the later one shadows this one.

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> MyTrainingData:
            ...

    graph_schema = create_test_schema(uses=MyComponent)
    validation.validate(graph_schema, language=None, is_train_graph=True)
def test_run_param_satifisfied_due_to_default():
    """Expect validation to accept a ``run`` whose only parameter has a default."""

    class MyComponent(TestComponentWithoutRun):
        def run(self, some_param: TrainingData = TrainingData()) -> TrainingData:
            ...

    schema = create_test_schema(uses=MyComponent)
    validation.validate(schema)
def test_graph_with_cls_type_hint():
    """Expect validation to accept a ``TestComponentWithClsTypeHints`` subclass."""

    class MyComponent(TestComponentWithClsTypeHints):
        def run(self) -> MyTrainingData:
            ...

    schema = create_test_schema(uses=MyComponent)
    validation.validate(schema)
def test_predict_graph_output_is_not_fingerprintable():
    """Expect validation with ``is_train_graph=False`` to accept an ``int`` output."""
    # NOTE(review): a test with this same name is defined again later in this
    # file; if both live at module level the later one shadows this one.

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> int:
            ...

    graph_schema = create_test_schema(uses=MyComponent)
    validation.validate(graph_schema, language=None, is_train_graph=False)
def test_predict_graph_output_is_not_fingerprintable():
    """Expect validation to accept an ``int`` output.

    The schema is created with ``is_train_graph=False``.
    """

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> int:
            ...

    schema = create_test_schema(uses=MyComponent, is_train_graph=False)
    validation.validate(schema)
def test_graph_with_fingerprintable_output_subclass():
    """Expect validation to accept a ``run`` annotated to return ``MyTrainingData``."""

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> MyTrainingData:
            ...

    schema = create_test_schema(uses=MyComponent)
    validation.validate(schema)
def test_run_param_not_satisfied():
    """Expect a "needs the param" error when ``run`` has an unsupplied parameter."""
    # NOTE(review): a test with this same name is defined again later in this
    # file; if both live at module level the later one shadows this one.

    class MyComponent(TestComponentWithoutRun):
        def run(self, some_param: TrainingData) -> TrainingData:
            ...

    schema = create_test_schema(uses=MyComponent)

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="needs the param"
    )
    with expected_failure:
        validation.validate(schema)
def test_graph_component_is_no_graph_component():
    """Expect an "implement ... interface" error for a class with no base class."""
    # NOTE(review): a test with this same name is defined again later in this
    # file; if both live at module level the later one shadows this one.

    class MyComponent:
        def other(self) -> TrainingData:
            ...

    schema = create_test_schema(uses=MyComponent)

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="implement .+ interface"
    )
    with expected_failure:
        validation.validate(schema)
def test_graph_output_is_not_fingerprintable_int():
    """Expect a "fingerprintable" error for an ``int`` output in a train graph."""

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> int:
            ...

    graph_schema = create_test_schema(uses=MyComponent)

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="fingerprintable"
    )
    with expected_failure:
        validation.validate(graph_schema, language=None, is_train_graph=True)
def test_graph_output_is_not_fingerprintable_None():
    """Expect a "fingerprintable" error when ``run`` is annotated to return None."""

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> None:
            ...

    schema = create_test_schema(uses=MyComponent)

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="fingerprintable"
    )
    with expected_failure:
        validation.validate(schema)
def test_graph_component_is_no_graph_component():
    """Expect an "implement ... interface" error for a class with no base class."""

    class MyComponent:
        def other(self) -> TrainingData:
            ...

    graph_schema = create_test_schema(uses=MyComponent)

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="implement .+ interface"
    )
    with expected_failure:
        validation.validate(graph_schema, language=None, is_train_graph=True)
def test_parent_is_missing():
    """Expect a "this component is not part" error for an unknown parent name."""
    unknown_parent = {"training_data": "not existing parent"}
    schema = create_test_schema(
        uses=TestComponentWithRunAndParam,
        needs=unknown_parent,
    )

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="this component is not part"
    )
    with expected_failure:
        validation.validate(schema)
def test_run_param_not_satisfied():
    """Expect a "needs the param" error when ``run`` has an unsupplied parameter."""

    class MyComponent(TestComponentWithoutRun):
        def run(self, some_param: TrainingData) -> TrainingData:
            ...

    graph_schema = create_test_schema(uses=MyComponent)

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="needs the param"
    )
    with expected_failure:
        validation.validate(graph_schema, language=None, is_train_graph=True)
def test_too_many_supplied_params_but_kwargs():
    """Expect validation to accept a supplied ``some_param`` when ``run`` has ``**kwargs``."""
    # NOTE(review): a test with this same name is defined again later in this
    # file; if both live at module level the later one shadows this one.

    class MyComponent(TestComponentWithoutRun):
        def run(self, **kwargs: Any) -> TrainingData:
            ...

    supplied_inputs = {"some_param": "parent"}
    schema = create_test_schema(
        uses=MyComponent,
        needs=supplied_inputs,
        parent=TestComponentWithRun,
    )
    validation.validate(schema)
def test_run_fn_with_variable_length_positional_param():
    """Expect validation to accept a ``run`` that declares ``*args`` before ``some_param``."""
    # NOTE(review): a test with this same name is defined again later in this
    # file; if both live at module level the later one shadows this one.

    class MyComponent(TestComponentWithoutRun):
        def run(self, *args: Any, some_param: TrainingData) -> TrainingData:
            ...

    supplied_inputs = {"some_param": "parent"}
    schema = create_test_schema(
        uses=MyComponent,
        needs=supplied_inputs,
        parent=TestComponentWithRun,
    )
    validation.validate(schema)
def test_too_many_supplied_params_but_kwargs():
    """Expect validation to accept a supplied ``some_param`` when ``run`` has ``**kwargs``."""

    class MyComponent(TestComponentWithoutRun):
        def run(self, **kwargs: Any) -> TrainingData:
            ...

    supplied_inputs = {"some_param": "parent"}
    graph_schema = create_test_schema(
        uses=MyComponent,
        needs=supplied_inputs,
        parent=TestComponentWithRun,
    )
    validation.validate(graph_schema, language=None, is_train_graph=True)
def test_graph_satisfied_package_requirements(required_packages: List[Text]):
    """Expect validation to accept a component requiring ``required_packages``.

    Args:
        required_packages: Package names the nested component reports through
            its ``required_packages`` static method.
    """

    class MyComponent(TestComponentWithRun):
        @staticmethod
        def required_packages() -> List[Text]:
            """Any extra python dependencies required for this component to run."""
            # Resolves to the test function's argument, not to this method.
            return required_packages

    schema = create_test_schema(uses=MyComponent)
    validation.validate(schema)
def test_run_fn_with_variable_length_positional_param():
    """Expect validation to accept a ``run`` that declares ``*args`` before ``some_param``."""

    class MyComponent(TestComponentWithoutRun):
        def run(self, *args: Any, some_param: TrainingData) -> TrainingData:
            ...

    supplied_inputs = {"some_param": "parent"}
    graph_schema = create_test_schema(
        uses=MyComponent,
        needs=supplied_inputs,
        parent=TestComponentWithRun,
    )
    validation.validate(graph_schema, language=None, is_train_graph=True)
def test_graph_output_missing_type_annotation():
    """Expect a "does not have a type annotation" error for an unannotated ``run``."""

    class MyComponent(TestComponentWithoutRun):
        def run(self):
            ...

    schema = create_test_schema(uses=MyComponent)

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="does not have a type annotation"
    )
    with expected_failure:
        validation.validate(schema)
def test_graph_constructor_missing():
    """Expect a "required method 'invalid'" error for the constructor name ``invalid``."""

    class MyComponent(TestComponentWithoutRun):
        def run(self) -> TrainingData:
            ...

    schema = create_test_schema(uses=MyComponent, constructor_name="invalid")

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="required method 'invalid'"
    )
    with expected_failure:
        validation.validate(schema)
def test_cycle(is_train_graph: bool):
    """Expect a "Cycles" error for a schema whose nodes form the loop A -> B -> C -> A.

    Args:
        is_train_graph: If true the cyclic schema is used as the train schema,
            otherwise as the predict schema; the other schema keeps its default.
    """

    class MyTestComponent(TestComponentWithoutRun):
        def run(self, training_data: TrainingData) -> TrainingData:
            ...

    def build_node(parent_name, **extra):
        # Every node is identical except for the parent feeding ``training_data``
        # and any extra flags such as ``is_target``.
        return SchemaNode(
            needs={"training_data": parent_name},
            uses=MyTestComponent,
            eager=True,
            constructor_name="create",
            fn="run",
            config={},
            **extra,
        )

    empty_schema = GraphSchema({})
    cyclic_schema = GraphSchema(
        {
            "A": build_node("B", is_target=True),
            "B": build_node("C"),
            "C": build_node("A"),
        }
    )

    train_schema = cyclic_schema if is_train_graph else empty_schema
    predict_schema = DEFAULT_PREDICT_SCHEMA if is_train_graph else cyclic_schema

    with pytest.raises(GraphSchemaValidationException, match="Cycles"):
        validation.validate(
            GraphModelConfiguration(
                train_schema=train_schema,
                predict_schema=predict_schema,
                training_type=TrainingType.BOTH,
                language=None,
                core_target=None,
                nlu_target="nlu_target",
            )
        )
def test_graph_missing_package_requirements(required_packages: List[Text]):
    """Expect a "not installed" error for a component requiring ``required_packages``.

    Args:
        required_packages: Package names the nested component reports through
            its ``required_packages`` static method.
    """

    class MyComponent(TestComponentWithRun):
        @staticmethod
        def required_packages() -> List[Text]:
            """Any extra python dependencies required for this component to run."""
            # Resolves to the test function's argument, not to this method.
            return required_packages

    schema = create_test_schema(uses=MyComponent)

    expected_failure = pytest.raises(
        GraphSchemaValidationException, match="not installed"
    )
    with expected_failure:
        validation.validate(schema)