def _validate_constructor(
    node: SchemaNode,
    create_fn_params: Dict[Text, ParameterInfo],
) -> None:
    """Validate the constructor method of a node's component.

    Checks the annotations of reserved keyword parameters, rejects required
    constructor parameters for eager nodes and, for lazy (non-eager) nodes,
    ensures every required constructor parameter is provided via `needs`.

    Args:
        node: The schema node whose constructor (`node.constructor_name`) is checked.
        create_fn_params: Parameter information of the constructor method.

    Raises:
        GraphSchemaValidationException: If any of the checks fail.
    """
    _validate_types_of_reserved_keywords(create_fn_params, node, node.constructor_name)

    # Computed once and reused for both checks below.
    required_args = _required_args(create_fn_params)
    if required_args and node.eager:
        raise GraphSchemaValidationException(
            f"Your model uses a component '{node.uses.__name__}' which has a "
            f"method '{node.constructor_name}' which has required parameters "
            f"('{', '.join(required_args)}'). "
            f"Extra parameters can only be supplied to the constructor method which is "
            f"used during training. "
            f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
        )

    # For lazy nodes every required constructor parameter must appear in `needs`.
    for param_name in required_args:
        if not node.eager and param_name not in node.needs:
            raise GraphSchemaValidationException(
                f"Your model uses a component '{node.uses.__name__}' which "
                f"needs the param '{param_name}' to be provided to its method "
                f"'{node.constructor_name}'. Please make sure that you registered "
                f"your component correctly and that your model configuration is "
                f"valid. "
                f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
            )
def _validate_target(
    target_name: Text,
    target_type: Text,
    expected_type: Type,
    schema: GraphSchema,
) -> None:
    """Validate a prediction target node of the graph schema.

    Args:
        target_name: Name of the node used as target.
        target_type: Label of the target (e.g. "NLU" or "Core"), used in messages.
        expected_type: Type the target node's `fn` is required to return.
        schema: The graph schema containing the target.

    Raises:
        GraphSchemaValidationException: If the target is not a node of the schema,
            is consumed as input by another node, or returns an incompatible type.
    """
    if target_name not in schema.nodes:
        raise GraphSchemaValidationException(
            f"Graph schema specifies invalid {target_type} target '{target_name}'. "
            f"Please make sure to specify a valid node name as target."
        )

    # NLU and Core prediction are run separately, so no node may consume a target.
    if any(target_name in node.needs.values() for node in schema.nodes.values()):
        raise GraphSchemaValidationException(
            f"One graph node uses the {target_type} target '{target_name}' as input. "
            f"This is not allowed as NLU prediction and Core prediction are run "
            f"separately."
        )

    target_node = schema.nodes[target_name]
    _, target_return_type = _get_parameter_information(target_node.uses, target_node.fn)
    if not typing_utils.issubtype(target_return_type, expected_type):
        raise GraphSchemaValidationException(
            f"Your {target_type} model's output component "
            f"'{target_node.uses.__name__}' returns an invalid return "
            f"type '{target_return_type}'. This is not allowed. The {target_type} "
            f"model's last component is expected to return the type '{expected_type}'. "
            f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
        )
def _validate_needs(
    node_name: Text,
    node: SchemaNode,
    graph: GraphSchema,
    create_fn_params: Dict[Text, ParameterInfo],
    run_fn_params: Dict[Text, ParameterInfo],
) -> None:
    """Check that each entry in a node's `needs` maps onto an accepted parameter.

    For every `param -> parent` pair, the parameter must be accepted by the
    component's methods (unless they take `**kwargs`). When a type annotation is
    available, the parent's return type must match it.

    Args:
        node_name: Name of the node being validated.
        node: The schema node being validated.
        graph: The whole graph schema, used to look up parent nodes.
        create_fn_params: Parameter information of the constructor method.
        run_fn_params: Parameter information of the run method.

    Raises:
        GraphSchemaValidationException: If a needed parameter is not accepted or
            a parent returns an incompatible type.
    """
    accepted_params, accepts_kwargs = _get_available_args(
        node, create_fn_params, run_fn_params
    )

    for param, parent_node_name in node.needs.items():
        if param not in accepted_params and not accepts_kwargs:
            raise GraphSchemaValidationException(
                f"Node '{node_name}' is configured to retrieve a value for the "
                f"param '{param}' by its parent node '{parent_node_name}' although "
                f"its method '{node.fn}' does not accept a parameter with this "
                f"name. Please make sure your node's 'needs' section is "
                f"correctly specified."
            )

        parent_node = graph.nodes[parent_node_name]
        expected = accepted_params.get(param)
        if accepts_kwargs and expected is None:
            # The value ends up in `**kwargs`, so there is no annotation to check.
            continue

        _validate_parent_return_type(
            node_name, node, parent_node_name, parent_node, expected.type_annotation
        )
def _validate_run_fn_return_type(
    node_name: Text, node: SchemaNode, return_type: Type, is_training: bool
) -> None:
    """Validate the return annotation of a node's run method.

    Args:
        node_name: Name of the node being validated.
        node: The schema node being validated.
        return_type: The return annotation of `node.fn`.
        is_training: Whether the node is part of the training graph; only then is
            the return type required to be fingerprintable.

    Raises:
        GraphSchemaValidationException: If the annotation is missing, or the node
            is used in training and its return type is not `Fingerprintable`.
    """
    missing_annotation = return_type == inspect.Parameter.empty
    if missing_annotation:
        raise GraphSchemaValidationException(
            f"Node '{node_name}' uses a component '{node.uses.__name__}' whose "
            f"method '{node.fn}' does not have a type annotation for "
            f"its return value. Type annotations are required for all graph "
            f"components to validate the graph's structure."
        )

    if not is_training:
        return

    if not isinstance(return_type, Fingerprintable):
        raise GraphSchemaValidationException(
            f"Node '{node_name}' uses a component '{node.uses.__name__}' whose method "
            f"'{node.fn}' does not return a fingerprintable "
            f"output. This is required for caching. Please make sure you're "
            f"using a return type which implements the "
            f"'{Fingerprintable.__name__}' protocol."
        )
def _validate_required_packages(node: SchemaNode) -> None:
    """Ensure that all packages declared by a node's component are installed.

    Args:
        node: The schema node whose component's `required_packages()` are checked.

    Raises:
        GraphSchemaValidationException: If any required package is unavailable.
    """
    declared = node.uses.required_packages()
    unavailable = rasa.utils.common.find_unavailable_packages(declared)
    if not unavailable:
        return

    unavailable_listing = ", ".join(unavailable)
    raise GraphSchemaValidationException(
        f"Component '{node.uses.__name__}' requires the following packages which "
        f"are currently not installed: {unavailable_listing}."
    )
def _validate_required_components(schema: GraphSchema,) -> None:
    """Check that components' required components are present in the graph.

    Unmet requirements are collected for every target (as reported by
    `_recursively_check_required_components`), merged per required component type,
    and reported together in a single exception.

    Args:
        schema: The graph schema to validate.

    Raises:
        GraphSchemaValidationException: If any component's requirements are unmet.
    """
    unmet: Dict[Type, Set[Text]] = {}
    for target in schema.target_names:
        unmet_for_target, _ = _recursively_check_required_components(
            node_name=target, schema=schema,
        )
        for component_type, requiring_nodes in unmet_for_target.items():
            unmet.setdefault(component_type, set()).update(requiring_nodes)

    if not unmet:
        return

    error_lines = [
        f"The following components require a {component_type.__name__}: "
        f"{', '.join(sorted(required_by))}. "
        for component_type, required_by in unmet.items()
    ]
    errors = "\n".join(error_lines)

    # A node may miss several component types; count each node only once.
    affected_nodes: Set[Text] = set()
    for required_by in unmet.values():
        affected_nodes.update(required_by)
    num_nodes = len(affected_nodes)

    raise GraphSchemaValidationException(
        f"{num_nodes} components are missing required components which have to "
        f"run before themselves:\n"
        f"{errors}"
        f"Please add the required components to your model configuration."
    )
def _validate_interface_usage(node_name: Text, node: SchemaNode) -> None:
    """Ensure that a node's class implements the `GraphComponent` interface.

    Args:
        node_name: Name of the node being validated.
        node: The schema node whose `uses` class is checked.

    Raises:
        GraphSchemaValidationException: If `node.uses` is not a `GraphComponent`
            subclass.
    """
    if not issubclass(node.uses, GraphComponent):
        # Use the interface's name (not the class repr) in both mentions.
        raise GraphSchemaValidationException(
            f"Node '{node_name}' uses class '{node.uses.__name__}'. This class does "
            f"not implement the '{GraphComponent.__name__}' interface and can "
            f"hence not be used within the graph. Please use a different "
            f"component or implement the '{GraphComponent.__name__}' interface for "
            f"'{node.uses.__name__}'."
        )
def _validate_needs(
    node: SchemaNode,
    graph: GraphSchema,
    create_fn_params: Dict[Text, ParameterInfo],
    run_fn_params: Dict[Text, ParameterInfo],
) -> None:
    """Check that each entry in a node's `needs` is valid.

    For every `param -> parent` pair:
    - the parameter must be accepted by the component's methods, unless they
      take `**kwargs`;
    - the parent must be a reserved placeholder input or a node of the graph;
    - when an annotation is available, the parent's return type (or the
      placeholder's reserved type) must match it.

    Args:
        node: The schema node being validated.
        graph: The whole graph schema, used to look up parent nodes.
        create_fn_params: Parameter information of the constructor method.
        run_fn_params: Parameter information of the run method.

    Raises:
        GraphSchemaValidationException: If any of the checks fail.
    """
    available_args, has_kwargs = _get_available_args(
        node, create_fn_params, run_fn_params
    )

    for param_name, parent_name in node.needs.items():
        if not has_kwargs and param_name not in available_args:
            raise GraphSchemaValidationException(
                f"Your model uses a component '{node.uses.__name__}' which is "
                f"supposed to retrieve a value for the "
                f"param '{param_name}' although "
                f"its method '{node.fn}' does not accept a parameter with this "
                f"name. Please make sure that you registered "
                f"your component correctly and that your model configuration is "
                f"valid. "
                f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
            )

        is_placeholder = _is_placeholder_input(parent_name)
        if not is_placeholder and parent_name not in graph.nodes:
            raise GraphSchemaValidationException(
                f"Your model uses a component '{node.uses.__name__}' which expects "
                f"input from a previous component but this component is not part of "
                f"your model configuration. Please make sure that you registered "
                f"your component correctly and that your model configuration is "
                f"valid. "
                f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
            )

        required_type = available_args.get(param_name)
        if has_kwargs and required_type is None:
            # The value is passed via `**kwargs`; there is no annotation to check.
            continue

        parent = None
        if is_placeholder:
            parent_return_type = RESERVED_PLACEHOLDERS[parent_name]
        else:
            parent = graph.nodes[parent_name]
            _, parent_return_type = _get_parameter_information(parent.uses, parent.fn)

        _validate_parent_return_type(
            node, parent, parent_return_type, required_type.type_annotation
        )
def _validate_interface_usage(node: SchemaNode) -> None:
    """Ensure that a node's class implements the `GraphComponent` interface.

    Args:
        node: The schema node whose `uses` class is checked.

    Raises:
        GraphSchemaValidationException: If `node.uses` is not a `GraphComponent`
            subclass.
    """
    if not issubclass(node.uses, GraphComponent):
        # Use the interface's name (not the class repr) in both mentions.
        raise GraphSchemaValidationException(
            f"Your model uses a component with class '{node.uses.__name__}'. "
            f"This class does not implement the '{GraphComponent.__name__}' interface "
            f"and can hence not be run within Rasa Open Source. Please use a different "
            f"component or implement the '{GraphComponent.__name__}' interface in class "
            f"'{node.uses.__name__}'. "
            f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
        )
def _get_fn(uses: Type[GraphComponent], method_name: Text) -> Callable:
    """Look up a method on a graph component class.

    Args:
        uses: The graph component class.
        method_name: Name of the method to retrieve.

    Returns:
        The attribute `method_name` of `uses`.

    Raises:
        GraphSchemaValidationException: If `uses` has no such attribute (or it is
            `None`).
    """
    fn = getattr(uses, method_name, None)
    if fn is None:
        raise GraphSchemaValidationException(
            f"Your model uses a graph component '{uses.__name__}' which does not "
            f"have the required "
            f"method '{method_name}'. Please make sure you're either using "
            f"the right component or that your component is registered with the "
            f"correct component type. "
            f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
        )
    return fn
def _get_fn(
    node_name: Text, uses: Type[GraphComponent], method_name: Text
) -> Callable:
    """Return the method `method_name` of the graph component class `uses`.

    Args:
        node_name: Name of the node, used in the error message.
        uses: The graph component class.
        method_name: Name of the method to retrieve.

    Raises:
        GraphSchemaValidationException: If the attribute is missing or `None`.
    """
    method = getattr(uses, method_name, None)
    if method is not None:
        return method

    raise GraphSchemaValidationException(
        f"Node '{node_name}' uses graph component '{uses.__name__}' which does not "
        f"have the specified "
        f"method '{method_name}'. Please make sure you're either using "
        f"the right graph component or specifying a valid method "
        f"for the 'fn' and 'constructor_name' options."
    )
def _validate_run_fn_return_type(
    node: SchemaNode, return_type: Type, is_training: bool
) -> None:
    """Validate the return annotation of a node's run method.

    Args:
        node: The schema node being validated.
        return_type: The return annotation of `node.fn`.
        is_training: Whether the node is part of the training graph. Only then is
            the return type (or, for lists, the element type) required to be
            fingerprintable.

    Raises:
        GraphSchemaValidationException: If the annotation is missing, or the node
            is used in training and its output type is not `Fingerprintable`.
    """
    if return_type == inspect.Parameter.empty:
        raise GraphSchemaValidationException(
            f"Your model uses a component '{node.uses.__name__}' whose "
            f"method '{node.fn}' does not have a type annotation for "
            f"its return value. Type annotations are required for all "
            f"components to validate your model's structure. "
            f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
        )

    # TODO: Handle forward references here
    if typing_utils.issubtype(return_type, list):
        # For list outputs the element type is checked. A bare `list` has no type
        # arguments, so indexing them would raise `IndexError`; check it as-is.
        element_types = typing_utils.get_args(return_type)
        if element_types:
            return_type = element_types[0]

    if is_training and not isinstance(return_type, Fingerprintable):
        raise GraphSchemaValidationException(
            f"Your model uses a component '{node.uses.__name__}' whose method "
            f"'{node.fn}' does not return a fingerprintable "
            f"output. This is required for proper caching between model trainings. "
            f"Please make sure you're using a return type which implements the "
            f"'{Fingerprintable.__name__}' protocol."
        )
def _validate_supported_languages(language: Optional[Text], node: SchemaNode) -> None:
    """Check that a node's component supports the configured language.

    Args:
        language: The configured language, or `None`/empty to skip the check.
        node: The schema node whose component declares `supported_languages()`
            and `not_supported_languages()`.

    Raises:
        RasaException: If both language lists are truthy at the same time.
        GraphSchemaValidationException: If the language is outside the supported
            list or inside the not-supported list.
    """
    allowed = node.uses.supported_languages()
    disallowed = node.uses.not_supported_languages()

    if allowed and disallowed:
        raise RasaException(
            "Only one of `supported_languages` and "
            "`not_supported_languages` can return a value different from `None`."
        )

    if not language:
        return

    not_in_allowed = allowed is not None and language not in allowed
    in_disallowed = disallowed is not None and language in disallowed
    if not_in_allowed or in_disallowed:
        raise GraphSchemaValidationException(
            f"The component '{node.uses.__name__}' does not support the currently "
            f"specified language '{language}'."
        )
def _validate_types_of_reserved_keywords(
    params: Dict[Text, ParameterInfo], node_name: Text, node: SchemaNode, fn_name: Text
) -> None:
    """Check that reserved keyword parameters carry the expected type annotations.

    Only parameters whose names are keys of `KEYWORDS_EXPECTED_TYPES` are checked.

    Args:
        params: Parameter information of the method being checked.
        node_name: Name of the node, used in the error message.
        node: The schema node the method belongs to.
        fn_name: Name of the method being checked.

    Raises:
        GraphSchemaValidationException: If a reserved parameter's annotation is not
            a subtype of its expected type.
    """
    for name, info in params.items():
        if name not in KEYWORDS_EXPECTED_TYPES:
            continue

        expected_type = KEYWORDS_EXPECTED_TYPES[name]
        if typing_utils.issubtype(info.type_annotation, expected_type):
            continue

        raise GraphSchemaValidationException(
            f"Node '{node_name}' uses a graph component "
            f"'{node.uses.__name__}' which has an incompatible type "
            f"'{info.type_annotation}' for the '{name}' parameter in "
            f"its '{fn_name}' method. Expected type "
            f"'{expected_type}'."
        )
def _validate_prediction_targets(
    schema: GraphSchema, core_target: Optional[Text], nlu_target: Text
) -> None:
    """Validate the NLU target (mandatory) and the Core target (optional).

    Args:
        schema: The prediction graph schema.
        core_target: Name of the Core target node, if any.
        nlu_target: Name of the NLU target node.

    Raises:
        GraphSchemaValidationException: If no NLU target is given or a target is
            invalid.
    """
    if not nlu_target:
        raise GraphSchemaValidationException(
            "Graph schema specifies no target for the 'nlu_target'. It is required "
            "for a prediction graph to specify this. Please choose a valid node "
            "name for this."
        )

    # NLU target first, then the Core target if one is configured.
    checks = [(nlu_target, "NLU", List[Message])]
    if core_target:
        checks.append((core_target, "Core", PolicyPrediction))

    for target_name, label, expected_type in checks:
        _validate_target(target_name, label, expected_type, schema)
def _walk_and_check_for_cycles(
    visited_so_far: List[Text], node_name: Text, schema: GraphSchema
) -> None:
    """Depth-first walk along `needs`, raising if a node reappears on the path.

    `visited_so_far` holds the nodes on the current path only. Each recursive call
    receives a fresh copy extended by `node_name`, so sibling branches do not see
    each other's nodes.

    Args:
        visited_so_far: Nodes on the path leading to `node_name`.
        node_name: The node to visit.
        schema: The graph schema.

    Raises:
        GraphSchemaValidationException: If a cycle is found.
    """
    if node_name in visited_so_far:
        raise GraphSchemaValidationException(
            f"Node '{node_name}' has itself as dependency. Cycles are not allowed "
            f"in the graph. Please make sure that '{node_name}' does not have itself "
            f"specified in 'needs' and none of '{node_name}'s dependencies have "
            f"'{node_name}' specified in 'needs'."
        )

    # NOTE(review): a parent name that is not a key of `schema.nodes` raises
    # KeyError here — confirm callers validate parents first.
    parents = schema.nodes[node_name].needs.values()
    for parent_name in parents:
        _walk_and_check_for_cycles([*visited_so_far, node_name], parent_name, schema)
def _validate_types_of_reserved_keywords(
    params: Dict[Text, ParameterInfo], node: SchemaNode, fn_name: Text
) -> None:
    """Check that reserved keyword parameters carry the expected type annotations.

    Only parameters whose names are keys of `KEYWORDS_EXPECTED_TYPES` are checked.

    Args:
        params: Parameter information of the method being checked.
        node: The schema node the method belongs to.
        fn_name: Name of the method being checked.

    Raises:
        GraphSchemaValidationException: If a reserved parameter's annotation is not
            a subtype of its expected type.
    """
    for param_name, param in params.items():
        if param_name not in KEYWORDS_EXPECTED_TYPES:
            continue

        expected_type = KEYWORDS_EXPECTED_TYPES[param_name]
        if not typing_utils.issubtype(param.type_annotation, expected_type):
            raise GraphSchemaValidationException(
                f"Your model uses a component '{node.uses.__name__}' which has an "
                f"incompatible type '{param.type_annotation}' for "
                f"the '{param_name}' parameter in its '{fn_name}' method. "
                f"The expected type is '{expected_type}'. "
                f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
            )
def _validate_parent_return_type(
    node_name: Text,
    node: SchemaNode,
    parent_name: Text,
    parent: SchemaNode,
    required_type: TypeAnnotation,
) -> None:
    """Check that a parent's run method returns the type a child parameter expects.

    Args:
        node_name: Name of the child node.
        node: The child schema node.
        parent_name: Name of the parent node.
        parent: The parent schema node.
        required_type: Annotation of the child parameter fed by the parent.

    Raises:
        GraphSchemaValidationException: If the parent's return type is not a
            subtype of `required_type`.
    """
    _params, parent_return_type = _get_parameter_information(
        parent_name, parent.uses, parent.fn
    )
    if typing_utils.issubtype(parent_return_type, required_type):
        return

    raise GraphSchemaValidationException(
        f"Parent of node '{node_name}' returns type "
        f"'{parent_return_type}' but type '{required_type}' "
        f"was expected by component '{node.uses.__name__}'."
    )
def _validate_constructor(
    node_name: Text,
    node: SchemaNode,
    create_fn_params: Dict[Text, ParameterInfo],
) -> None:
    """Validate the constructor method of a node's component.

    Checks the annotations of reserved keyword parameters, rejects required
    constructor parameters for eager nodes and, for lazy (non-eager) nodes,
    ensures every required constructor parameter is listed in `needs`.

    Args:
        node_name: Name of the node being validated.
        node: The schema node whose constructor (`node.constructor_name`) is checked.
        create_fn_params: Parameter information of the constructor method.

    Raises:
        GraphSchemaValidationException: If any of the checks fail.
    """
    _validate_types_of_reserved_keywords(
        create_fn_params, node_name, node, node.constructor_name
    )

    # Computed once and reused for both checks below.
    required_args = _required_args(create_fn_params)
    if required_args and node.eager:
        raise GraphSchemaValidationException(
            f"Node '{node_name}' has a constructor which has required "
            f"parameters ('{', '.join(required_args)}'). "
            f"Extra parameters can only be supplied to the constructor if the node "
            f"is being run in lazy mode."
        )

    for param_name in required_args:
        if not node.eager and param_name not in node.needs:
            raise GraphSchemaValidationException(
                f"Node '{node_name}' uses a component '{node.uses.__name__}' which "
                f"needs the param '{param_name}' to be provided to its method "
                f"'{node.constructor_name}'. Please make sure to specify the "
                f"parameter in the node's 'needs' section."
            )
def _get_type_hints(
    node_name: Text, uses: Type[GraphComponent], fn: Callable
) -> Dict[Text, TypeAnnotation]:
    """Resolve the type hints of a component method.

    Args:
        node_name: Name of the node, used in the error message.
        uses: The graph component class `fn` belongs to.
        fn: The method whose annotations are resolved.

    Returns:
        The result of `typing.get_type_hints(fn)`.

    Raises:
        GraphSchemaValidationException: If resolving the annotations raises a
            `NameError` (e.g. an unresolvable forward reference). The original
            error is chained as the cause.
    """
    try:
        return typing.get_type_hints(fn)
    except NameError as e:
        logging.debug(
            f"Failed to retrieve type annotations for component "
            f"'{uses.__name__}' due to error:\n{e}"
        )
        raise GraphSchemaValidationException(
            f"Node '{node_name}' uses graph component '{uses.__name__}' which has "
            f"type annotations in its method '{fn.__name__}' which failed to be "
            f"retrieved. Please make sure to remove any forward "
            f"reference by removing the quotes around the type "
            f"(e.g. 'def foo() -> \"int\"' becomes 'def foo() -> int'), and make sure "
            f"all type annotations can be resolved during runtime. Note that you might "
            f"need to do a 'from __future__ import annotations' to avoid forward "
            f"references."
        ) from e
def _validate_run_fn(
    node: SchemaNode,
    run_fn_params: Dict[Text, ParameterInfo],
    run_fn_return_type: TypeAnnotation,
    is_train_graph: bool,
) -> None:
    """Validate a node's run method (`node.fn`).

    Checks the reserved keyword annotations and the return type. Also checks that
    every required parameter of the run method is provided via `needs`.

    Args:
        node: The schema node being validated.
        run_fn_params: Parameter information of the run method.
        run_fn_return_type: Return annotation of the run method.
        is_train_graph: Whether the node belongs to the training graph.

    Raises:
        GraphSchemaValidationException: If any of the checks fail.
    """
    _validate_types_of_reserved_keywords(run_fn_params, node, node.fn)
    _validate_run_fn_return_type(node, run_fn_return_type, is_train_graph)

    for param_name in _required_args(run_fn_params):
        if param_name not in node.needs:
            raise GraphSchemaValidationException(
                f"Your model uses a component '{node.uses.__name__}' which "
                f"needs the param '{param_name}' to be provided to its method "
                f"'{node.fn}'. Please make sure that you registered "
                f"your component correctly and that your model configuration is "
                f"valid. "
                f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
            )
def _validate_run_fn(
    node_name: Text,
    node: SchemaNode,
    run_fn_params: Dict[Text, ParameterInfo],
    run_fn_return_type: TypeAnnotation,
    is_train_graph: bool,
) -> None:
    """Validate a node's run method (`node.fn`).

    Checks the reserved keyword annotations and the return type. Also checks that
    every required run-method parameter appears in the node's `needs`.

    Args:
        node_name: Name of the node being validated.
        node: The schema node being validated.
        run_fn_params: Parameter information of the run method.
        run_fn_return_type: Return annotation of the run method.
        is_train_graph: Whether the node belongs to the training graph.

    Raises:
        GraphSchemaValidationException: If any of the checks fail.
    """
    _validate_types_of_reserved_keywords(run_fn_params, node_name, node, node.fn)
    _validate_run_fn_return_type(node_name, node, run_fn_return_type, is_train_graph)

    # The first required parameter that is not wired up in `needs`, if any.
    unprovided = next(
        (name for name in _required_args(run_fn_params) if name not in node.needs),
        None,
    )
    if unprovided is None:
        return

    raise GraphSchemaValidationException(
        f"Node '{node_name}' uses a component '{node.uses.__name__}' which "
        f"needs the param '{unprovided}' to be provided to its method "
        f"'{node.fn}'. Please make sure to specify the parameter in "
        f"the node's 'needs' section."
    )
def _validate_parent_return_type(
    node: SchemaNode,
    parent_node: Optional[SchemaNode],
    parent_return_type: TypeAnnotation,
    required_type: TypeAnnotation,
) -> None:
    """Check that an input's type matches what the receiving component expects.

    Args:
        node: The schema node receiving the input.
        parent_node: The node producing the input, or `None` when the input comes
            from a reserved placeholder.
        parent_return_type: Type of the incoming value.
        required_type: Annotation of the receiving parameter.

    Raises:
        GraphSchemaValidationException: If `parent_return_type` is not a subtype of
            `required_type`.
    """
    if not typing_utils.issubtype(parent_return_type, required_type):
        # Only name the producing component when there is one (not a placeholder).
        parent_node_text = ""
        if parent_node:
            parent_node_text = f" by the component '{parent_node.uses.__name__}'"

        raise GraphSchemaValidationException(
            f"Your component '{node.uses.__name__}' expects an input of type "
            f"'{required_type}' but it receives an input of type '{parent_return_type}'"
            f"{parent_node_text}. "
            f"Please make sure that you registered "
            f"your component correctly and that your model configuration is "
            f"valid. "
            f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
        )