def train(self, data):
    # type: (TrainingData) -> Interpreter
    """Trains the underlying pipeline by using the provided training data."""

    if not self.skip_validation:
        # Check up front that every component's declared arguments can be
        # satisfied before doing any expensive work.
        components.validate_arguments(self.pipeline, self.config)

    self.training_data = data
    ctx = {}

    # Phase 1: pipeline initialisation — each component may contribute
    # values to the shared context.
    for component in self.pipeline:
        init_args = components.fill_args(component.pipeline_init_args(), ctx, self.config.as_dict())
        contributed = component.pipeline_init(*init_args)
        if contributed:
            ctx.update(contributed)

    # The interpreter is created from the context as it was *before* the
    # training data was injected.
    init_context = ctx.copy()
    ctx["training_data"] = data

    # Phase 2: training — components may again extend the context.
    for component in self.pipeline:
        train_args = components.fill_args(component.train_args(), ctx, self.config.as_dict())
        contributed = component.train(*train_args)
        if contributed:
            ctx.update(contributed)

    return Interpreter(self.pipeline, context=init_context, config=self.config.as_dict())
def parse(self, text, time):
    # type: (Text) -> Dict[Text, Any]
    """Parse the input text, classify it and return an object containing its intent and entities."""

    if not text:
        # Not all components are able to handle empty strings, so short-circuit
        # with the defaults. This return will not contain all output attributes
        # of all components, but nobody should pass an empty string anyway.
        return self.default_output_attributes()

    # Working context: interpreter context, overlaid with defaults, overlaid
    # with the request-specific values.
    ctx = self.context.copy()
    ctx.update(self.default_output_attributes())
    ctx["text"] = text
    ctx["time"] = time

    for component in self.pipeline:
        try:
            call_args = components.fill_args(component.process_args(), ctx, self.config)
            produced = component.process(*call_args)
        except components.MissingArgumentError as e:
            raise Exception("Failed to parse at component '{}'. {}".format(
                    component.name, e))
        if produced:
            ctx.update(produced)

    result = self.default_output_attributes()
    allowed = list(self.default_output_attributes().keys()) + self.output_attributes
    # Ensure only keys of `allowed` are present and no other keys are returned
    for key in allowed:
        if key in ctx:
            result[key] = ctx[key]
    return result
def load(meta, config, component_builder=None, skip_valdation=False):
    # type: (Metadata, RasaNLUConfig, Optional[ComponentBuilder], bool) -> Interpreter
    """Load a stored model and its components defined by the provided metadata.

    Note: the ``skip_valdation`` (sic) parameter name is kept as-is for
    backward compatibility with existing keyword callers.
    """

    context = {"model_dir": meta.model_dir}

    if component_builder is None:
        # If no builder is passed, every interpreter creation will result in a
        # new builder. Hence, no components are reused.
        component_builder = components.ComponentBuilder()

    # The stored model's metadata overrides the passed-in configuration.
    model_config = config.as_dict()
    model_config.update(meta.metadata)

    pipeline = []

    # Before instantiating the component classes, lets check if all required packages are available
    if not skip_valdation:
        components.validate_requirements(meta.pipeline)

    for component_name in meta.pipeline:
        component = component_builder.load_component(
                component_name, context, model_config, meta)
        try:
            args = components.fill_args(component.pipeline_init_args(), context, model_config)
            updates = component.pipeline_init(*args)
            if updates:
                context.update(updates)
            pipeline.append(component)
        except components.MissingArgumentError as e:
            # BUG FIX: `e.message` does not exist on Python 3 exceptions (the
            # attribute was deprecated in 2.6 and removed in 3.0), so the
            # original code raised AttributeError instead of the intended
            # error. Format the exception itself instead.
            raise Exception(
                    "Failed to initialize component '{}'. {}".format(
                        component.name, e))

    return Interpreter(pipeline, context, model_config)
def parse(self, text, time=None):
    # type: (Text) -> Dict[Text, Any]
    """Parse the input text, classify it and return an object containing its intent and entities."""

    if not text:
        # Not all components can cope with empty strings, so bail out early
        # with the default attributes. The result will not contain every
        # component's output attributes, but an empty string should never be
        # passed in the first place.
        return self.default_output_attributes()

    working = self.context.copy()
    working.update(self.default_output_attributes())
    working["text"] = text
    working["time"] = time

    for component in self.pipeline:
        try:
            proc_args = components.fill_args(component.process_args(), working, self.config)
            produced = component.process(*proc_args)
        except components.MissingArgumentError as e:
            raise Exception("Failed to parse at component '{}'. {}".format(component.name, e))
        if produced:
            working.update(produced)

    # Ensure only whitelisted attribute keys make it into the result.
    result = self.default_output_attributes()
    whitelist = list(self.default_output_attributes().keys()) + self.output_attributes
    for key in whitelist:
        if key in working:
            result[key] = working[key]
    return result
def train(self, data):
    # type: (TrainingData) -> Interpreter
    """Trains the underlying pipeline by using the provided training data."""

    if not self.skip_validation:
        # Verify that every component's declared arguments can be satisfied
        # before any training work starts.
        components.validate_arguments(self.pipeline, self.config)

    self.training_data = data

    ctx = {}  # type: Dict[Text, Any]
    # Initialisation phase: components may contribute context values.
    for component in self.pipeline:
        contributed = component.pipeline_init()
        if contributed:
            ctx.update(contributed)

    # Snapshot the pre-training context; the Interpreter is built from it.
    init_context = ctx.copy()
    ctx["training_data"] = data

    # Training phase: each component trains and may extend the context.
    for component in self.pipeline:
        train_args = components.fill_args(component.train_args(), ctx, self.config.as_dict())
        logger.info("Starting to train component {}".format(component.name))
        contributed = component.train(*train_args)
        logger.info("Finished training component.")
        if contributed:
            ctx.update(contributed)

    return Interpreter(self.pipeline, context=init_context, config=self.config.as_dict())
def test_all_arguments_can_be_satisfied_during_init(component_class):
    """Check that every `pipeline_init` parameter can be filled from the context."""

    component = component_class()
    # Every context key that any registered component can ever provide
    # during pipeline init.
    context_arguments = {
        ctx_arg: None
        for clz in registry.component_classes
        for ctx_arg in clz.context_provides.get("pipeline_init", [])
    }
    filled_args = fill_args(component.pipeline_init_args(), context_arguments, config.DEFAULT_CONFIG)
    assert len(filled_args) == len(component.pipeline_init_args())
def test_all_arguments_can_be_satisfied_during_parse(component_class, default_config):
    """Check that `parse` method parameters can be filled from the context.

    Similar to the `pipeline_init` test."""

    # Seed with the values the interpreter itself injects, then add every
    # context key any component can provide during init or process.
    context_arguments = {"text": None, "time": None}
    for clz in registry.component_classes:
        for phase in ("pipeline_init", "process"):
            for ctx_arg in clz.context_provides.get(phase, []):
                context_arguments[ctx_arg] = None

    filled_args = fill_args(component_class.process_args(), context_arguments, default_config.as_dict())
    assert len(filled_args) == len(component_class.process_args())
def test_all_arguments_can_be_satisfied_during_parse(component_class, default_config, component_builder):
    """Check that `parse` method parameters can be filled from the context.

    Similar to the `pipeline_init` test."""

    component = component_builder.create_component(component_class.name, default_config)
    # Interpreter-injected value plus every context key any component can
    # provide during init or process.
    # NOTE(review): unlike the sibling parse test, "time" is not seeded here —
    # confirm no component's process_args still requires it.
    context_arguments = {"text": None}
    for clz in registry.component_classes:
        for phase in ("pipeline_init", "process"):
            for ctx_arg in clz.context_provides.get(phase, []):
                context_arguments[ctx_arg] = None

    filled_args = fill_args(component.process_args(), context_arguments, default_config.as_dict())
    assert len(filled_args) == len(component.process_args())
def test_all_arguments_can_be_satisfied_during_train(component_class):
    """Check that every `train` parameter can be filled from the context.

    It might still happen that in a certain pipeline configuration
    arguments can not be satisfied!"""

    component = component_class()
    # Trainer-injected value plus every context key any component can
    # provide during init or train.
    context_arguments = {"training_data": None}
    for clz in registry.component_classes:
        for phase in ("pipeline_init", "train"):
            for ctx_arg in clz.context_provides.get(phase, []):
                context_arguments[ctx_arg] = None

    filled_args = fill_args(component.train_args(), context_arguments, config.DEFAULT_CONFIG)
    assert len(filled_args) == len(component.train_args())
def test_all_arguments_can_be_satisfied_during_train(component_class, default_config):
    """Check that `train` method parameters can be filled from the context.

    Similar to the `pipeline_init` test. It might still happen that in a
    certain pipeline configuration arguments can not be satisfied!"""

    # Trainer-injected value plus every context key any component can
    # provide during init or train.
    context_arguments = {"training_data": None}
    for clz in registry.component_classes:
        for phase in ("pipeline_init", "train"):
            for ctx_arg in clz.context_provides.get(phase, []):
                context_arguments[ctx_arg] = None

    filled_args = fill_args(component_class.train_args(), context_arguments, default_config.as_dict())
    assert len(filled_args) == len(component_class.train_args())
def test_all_arguments_can_be_satisfied_during_train(component_class, default_config, component_builder):
    """Check that `train` method parameters can be filled from the context.

    Similar to the `pipeline_init` test. It might still happen that in a
    certain pipeline configuration arguments can not be satisfied!"""

    component = component_builder.create_component(component_class.name, default_config)
    # Trainer-injected value plus every context key any component can
    # provide during init or train.
    context_arguments = {"training_data": None}
    for clz in registry.component_classes:
        for phase in ("pipeline_init", "train"):
            for ctx_arg in clz.context_provides.get(phase, []):
                context_arguments[ctx_arg] = None

    filled_args = fill_args(component.train_args(), context_arguments, default_config.as_dict())
    assert len(filled_args) == len(component.train_args())
def test_all_arguments_can_be_satisfied_during_init(component_class, default_config, component_builder):
    """Check that `pipeline_init` method parameters can be filled from the context.

    The parameters declared on `pipeline_init` are not filled directly;
    the method is invoked via reflection, with its arguments looked up in a
    context that is created when the pipeline is built and seeded with
    configuration values. This 'dry run' verifies that every declared
    parameter is present in that context."""

    component = component_builder.create_component(component_class.name, default_config)
    # Every context key any registered component can provide during init.
    context_arguments = {
        ctx_arg: None
        for clz in registry.component_classes
        for ctx_arg in clz.context_provides.get("pipeline_init", [])
    }
    filled_args = fill_args(component.pipeline_init_args(), context_arguments, default_config.as_dict())
    assert len(filled_args) == len(component.pipeline_init_args())
def test_fill_args_with_unsatisfiable_param_from_config():
    """The error for an unfillable argument names only the missing one."""

    with pytest.raises(MissingArgumentError) as excinfo:
        fill_args(["good_one", "bad_one"], {}, {"good_one": 1})
    message = str(excinfo.value)
    assert "bad_one" in message
    assert "good_one" not in message