Exemple #1
0
    def train(self, data):
        # type: (TrainingData) -> Interpreter
        """Trains the underlying pipeline by using the provided training data.

        Runs two passes over ``self.pipeline``:

        1. every component's ``pipeline_init`` is called with arguments resolved
           by ``components.fill_args`` from the shared context and the config
           dict; any dict it returns is merged into the context;
        2. ``"training_data"`` is added to the context and every component's
           ``train`` is called the same way, again merging returned dicts.

        :param data: the training data; it is also stored on ``self.training_data``.
        :return: an ``Interpreter`` over ``self.pipeline`` whose context is a copy
            taken after the ``pipeline_init`` pass (values produced during
            training are not included) and whose config is ``self.config.as_dict()``.
        """

        # Before training the component classes, let's check if all arguments are provided
        if not self.skip_validation:
            components.validate_arguments(self.pipeline, self.config)

        self.training_data = data

        # Loop-invariant: build the config dict once instead of once per
        # component in each pass (assumes components do not modify self.config
        # while training).
        config_dict = self.config.as_dict()

        context = {}

        for component in self.pipeline:
            args = components.fill_args(component.pipeline_init_args(),
                                        context, config_dict)
            updates = component.pipeline_init(*args)
            if updates:
                context.update(updates)

        # Snapshot before training: only init-time values go to the Interpreter.
        init_context = context.copy()

        context["training_data"] = data

        for component in self.pipeline:
            args = components.fill_args(component.train_args(), context,
                                        config_dict)
            updates = component.train(*args)
            if updates:
                context.update(updates)

        return Interpreter(self.pipeline,
                           context=init_context,
                           config=self.config.as_dict())
Exemple #2
0
    def parse(self, text, time=None):
        # type: (Text, Any) -> Dict[Text, Any]
        """Parse the input text, classify it and return an object containing its intent and entities.

        :param text: the message to parse. A falsy value (e.g. ``""``) short-circuits
            and returns ``self.default_output_attributes()`` without running the pipeline.
        :param time: optional value placed into the context under ``"time"`` for
            components that request it; defaults to ``None``.
        :return: the default output attributes, overwritten by any value in the
            final context whose key is a default output attribute or is listed in
            ``self.output_attributes``. No other keys are returned.
        :raises Exception: if ``components.MissingArgumentError`` is raised while
            filling or running a component's ``process`` step; the message names
            the failing component.
        """

        if not text:
            # Not all components are able to handle empty strings, so we need to prevent that.
            # This default return will not contain all output attributes of all components,
            # but in the end, no one should pass an empty string in the first place.
            return self.default_output_attributes()

        current_context = self.context.copy()
        current_context.update(self.default_output_attributes())

        current_context.update({"text": text, "time": time})

        for component in self.pipeline:
            try:
                args = components.fill_args(component.process_args(),
                                            current_context, self.config)
                updates = component.process(*args)
                if updates:
                    current_context.update(updates)
            except components.MissingArgumentError as e:
                raise Exception("Failed to parse at component '{}'. {}".format(
                    component.name, e))

        result = self.default_output_attributes()
        # Keys are read before `result` is updated, so they are exactly the
        # default attribute names (no need to call default_output_attributes again).
        all_attributes = list(result.keys()) + self.output_attributes
        # Ensure only keys of `all_attributes` are present and no other keys are returned
        result.update({
            key: current_context[key]
            for key in all_attributes if key in current_context
        })
        return result
Exemple #3
0
    def load(meta, config, component_builder=None, skip_valdation=False):
        # type: (Metadata, RasaNLUConfig, Optional[ComponentBuilder], bool) -> Interpreter
        """Load a stored model and its components defined by the provided metadata.

        :param meta: model metadata. ``meta.model_dir`` seeds the context,
            ``meta.pipeline`` lists the component names to load, and
            ``meta.metadata`` is merged over the config.
        :param config: configuration whose ``as_dict()`` forms the base of the
            model config.
        :param component_builder: builder used to load the components. When
            ``None``, a new builder is created, so no components are reused.
        :param skip_valdation: when true, ``components.validate_requirements`` is
            skipped. (The misspelled name is part of the public keyword interface
            and is kept for compatibility.)
        :return: an ``Interpreter`` built from the loaded components, the final
            context and the merged model config.
        :raises Exception: if ``components.MissingArgumentError`` is raised while
            filling or running a component's ``pipeline_init``; the message names
            the failing component.
        """
        context = {"model_dir": meta.model_dir}
        if component_builder is None:
            # If no builder is passed, every interpreter creation will result in a new builder.
            # Hence, no components are reused.
            component_builder = components.ComponentBuilder()

        # Values stored in the model metadata take precedence over the passed config.
        model_config = config.as_dict()
        model_config.update(meta.metadata)

        pipeline = []

        # Before instantiating the component classes, check that all required packages are available
        if not skip_valdation:
            components.validate_requirements(meta.pipeline)

        for component_name in meta.pipeline:
            component = component_builder.load_component(
                component_name, context, model_config, meta)
            try:
                args = components.fill_args(component.pipeline_init_args(),
                                            context, model_config)
                updates = component.pipeline_init(*args)
            except components.MissingArgumentError as e:
                # Format `e` itself (as `parse` does) rather than `e.message`:
                # built-in exceptions have no `.message` attribute on Python 3,
                # so that only works if the exception class defines it.
                raise Exception(
                    "Failed to initialize component '{}'. {}".format(
                        component.name, e))
            if updates:
                context.update(updates)
            pipeline.append(component)

        return Interpreter(pipeline, context, model_config)
Exemple #4
0
    def parse(self, text, time=None):
        # type: (Text, Any) -> Dict[Text, Any]
        """Parse the input text, classify it and return an object containing its intent and entities.

        :param text: the message to parse. A falsy value (e.g. ``""``) short-circuits and returns
            ``self.default_output_attributes()`` without running the pipeline.
        :param time: optional value placed into the context under ``"time"``; defaults to ``None``.
        :return: the default output attributes, overwritten by any value in the final context whose key
            is a default output attribute or is listed in ``self.output_attributes``.
        :raises Exception: if ``components.MissingArgumentError`` is raised while filling or running a
            component's ``process`` step; the message names the failing component.
        """

        if not text:
            # Not all components are able to handle empty strings, so we need to prevent that.
            # This default return will not contain all output attributes of all components,
            # but in the end, no one should pass an empty string in the first place.
            return self.default_output_attributes()

        current_context = self.context.copy()
        current_context.update(self.default_output_attributes())

        current_context.update({
            "text": text,
            "time": time
        })

        for component in self.pipeline:
            try:
                args = components.fill_args(component.process_args(), current_context, self.config)
                updates = component.process(*args)
                if updates:
                    current_context.update(updates)
            except components.MissingArgumentError as e:
                raise Exception("Failed to parse at component '{}'. {}".format(component.name, e))

        result = self.default_output_attributes()
        # Read the keys before `result` is updated: they are exactly the default attribute names.
        all_attributes = list(result.keys()) + self.output_attributes
        # Ensure only keys of `all_attributes` are present and no other keys are returned
        result.update({key: current_context[key] for key in all_attributes if key in current_context})
        return result
Exemple #5
0
    def train(self, data):
        # type: (TrainingData) -> Interpreter
        """Trains the underlying pipeline by using the provided training data.

        First calls ``pipeline_init()`` (no arguments) on every component and merges
        any returned dict into a shared context. Then adds ``"training_data"`` to
        that context and calls each component's ``train`` with arguments resolved
        by ``components.fill_args``, merging returned dicts as well.

        :param data: the training data; it is also stored on ``self.training_data``.
        :return: an ``Interpreter`` over ``self.pipeline`` whose context is a copy
            taken after the ``pipeline_init`` pass (training results are not
            included) and whose config is ``self.config.as_dict()``.
        """

        # Before training the component classes, let's check if all arguments are provided
        if not self.skip_validation:
            components.validate_arguments(self.pipeline, self.config)

        self.training_data = data

        # Loop-invariant: build the config dict once rather than per component
        # (assumes components do not modify self.config while training).
        config_dict = self.config.as_dict()

        context = {}        # type: Dict[Text, Any]

        for component in self.pipeline:
            updates = component.pipeline_init()
            if updates:
                context.update(updates)

        # Snapshot before training: only init-time values go to the Interpreter.
        init_context = context.copy()

        context["training_data"] = data

        for component in self.pipeline:
            args = components.fill_args(component.train_args(), context, config_dict)
            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logger.info("Starting to train component %s", component.name)
            updates = component.train(*args)
            logger.info("Finished training component.")
            if updates:
                context.update(updates)

        return Interpreter(self.pipeline, context=init_context, config=self.config.as_dict())
Exemple #6
0
def test_all_arguments_can_be_satisfied_during_init(component_class):
    """Dry-run `pipeline_init` argument filling for a default-constructed component.

    The context holds every key any registered component declares it provides
    during `pipeline_init`; each declared argument must then be fillable.
    """
    component = component_class()

    # Every context key that will ever be generated during init, all mapped to None.
    available = {
        provided_key: None
        for registered in registry.component_classes
        for provided_key in registered.context_provides.get("pipeline_init", [])
    }

    filled = fill_args(component.pipeline_init_args(), available, config.DEFAULT_CONFIG)
    expected_count = len(component.pipeline_init_args())
    assert len(filled) == expected_count
def test_all_arguments_can_be_satisfied_during_parse(component_class, default_config):
    """Check that `parse` method parameters can be filled from the context. Similar to `pipeline_init` test."""

    # Keys parse itself seeds...
    context_arguments = dict.fromkeys(["text", "time"])
    # ...plus every key any registered component provides during init or process.
    for registered in registry.component_classes:
        for stage in ("pipeline_init", "process"):
            for provided_key in registered.context_provides.get(stage, []):
                context_arguments[provided_key] = None

    config_dict = default_config.as_dict()
    filled_args = fill_args(component_class.process_args(), context_arguments, config_dict)
    assert len(filled_args) == len(component_class.process_args())
Exemple #8
0
def test_all_arguments_can_be_satisfied_during_parse(component_class, default_config, component_builder):
    """Check that `parse` method parameters can be filled from the context. Similar to `pipeline_init` test.

    The `parse` implementations that accept a `time` argument put both "text" and
    "time" into the context before running the pipeline, so both keys belong in the
    dry-run context. Without "time", a component that asks for it would fail
    here even though it works at runtime.
    """

    component = component_builder.create_component(component_class.name, default_config)
    # All available context arguments that will ever be generated during parse
    context_arguments = {"text": None, "time": None}
    for clz in registry.component_classes:
        for ctx_arg in clz.context_provides.get("pipeline_init", []):
            context_arguments[ctx_arg] = None
        for ctx_arg in clz.context_provides.get("process", []):
            context_arguments[ctx_arg] = None

    filled_args = fill_args(component.process_args(), context_arguments, default_config.as_dict())
    assert len(filled_args) == len(component.process_args())
Exemple #9
0
def test_all_arguments_can_be_satisfied_during_train(component_class):
    """Dry-run `train` argument filling for a default-constructed component."""
    component = component_class()

    # Union of what training seeds ("training_data") and what any registered component
    # provides during pipeline_init or train. This is an upper bound: a concrete
    # pipeline configuration might still fail to satisfy some argument.
    context_arguments = {"training_data": None}
    for registered in registry.component_classes:
        provides = registered.context_provides
        context_arguments.update(dict.fromkeys(provides.get("pipeline_init", [])))
        context_arguments.update(dict.fromkeys(provides.get("train", [])))

    requested = component.train_args()
    filled_args = fill_args(requested, context_arguments, config.DEFAULT_CONFIG)
    assert len(filled_args) == len(component.train_args())
Exemple #10
0
def test_all_arguments_can_be_satisfied_during_train(component_class, default_config):
    """Check that `train` method parameters can be filled from the context. Similar to `pipeline_init` test."""

    # Every context key that may exist during training. A specific pipeline
    # configuration might still leave some argument unsatisfied!
    context_arguments = {"training_data": None}
    for registered in registry.component_classes:
        for stage in ("pipeline_init", "train"):
            context_arguments.update(
                (provided_key, None) for provided_key in registered.context_provides.get(stage, []))

    train_params = component_class.train_args()
    filled_args = fill_args(train_params, context_arguments, default_config.as_dict())
    assert len(filled_args) == len(component_class.train_args())
Exemple #11
0
def test_all_arguments_can_be_satisfied_during_train(component_class, default_config, component_builder):
    """Check that `train` method parameters can be filled from the context. Similar to `pipeline_init` test."""

    component = component_builder.create_component(component_class.name, default_config)

    # Collect every context key that can appear during training. This is an upper
    # bound; a concrete pipeline configuration may still not satisfy everything.
    context_arguments = {"training_data": None}
    for registered in registry.component_classes:
        init_keys = registered.context_provides.get("pipeline_init", [])
        train_keys = registered.context_provides.get("train", [])
        for provided_key in list(init_keys) + list(train_keys):
            context_arguments[provided_key] = None

    config_dict = default_config.as_dict()
    filled_args = fill_args(component.train_args(), context_arguments, config_dict)
    assert len(filled_args) == len(component.train_args())
Exemple #12
0
def test_all_arguments_can_be_satisfied_during_init(component_class, default_config, component_builder):
    """Check that `pipeline_init` method parameters can be filled from the context.

    The parameters declared on `pipeline_init` are not filled directly; instead the method is called via reflection.
    During the reflection, the parameters are filled from a so-called context that is created when creating the
    pipeline and gets initialized with the configuration values. To make sure all arguments `pipeline_init` declares
    can be provided during the reflection, we do a 'dry run' where we check all parameters are part of the context."""

    component = component_builder.create_component(component_class.name, default_config)

    # All available context arguments that will ever be generated during init
    init_keys = [provided_key
                 for registered in registry.component_classes
                 for provided_key in registered.context_provides.get("pipeline_init", [])]
    context_arguments = dict.fromkeys(init_keys)

    filled_args = fill_args(component.pipeline_init_args(), context_arguments, default_config.as_dict())
    assert len(filled_args) == len(component.pipeline_init_args())
Exemple #13
0
def test_fill_args_with_unsatisfiable_param_from_config():
    """The error names only the parameter that neither context nor config supplies."""
    requested = ["good_one", "bad_one"]
    with pytest.raises(MissingArgumentError) as excinfo:
        fill_args(requested, {}, {"good_one": 1})
    message = str(excinfo.value)
    assert "bad_one" in message
    assert "good_one" not in message
Exemple #14
0
def test_fill_args_with_unsatisfiable_param_from_config():
    # "good_one" is available from the config, "bad_one" from nowhere: the
    # raised error must mention only the missing one.
    params = ["good_one", "bad_one"]
    empty_context = {}
    config_values = {"good_one": 1}
    with pytest.raises(MissingArgumentError) as excinfo:
        fill_args(params, empty_context, config_values)
    error_text = str(excinfo.value)
    assert "bad_one" in error_text
    assert "good_one" not in error_text