Example #1
    def train(
        self,
        training_data: TrainingData,
        config: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        # checks whether there is at least one
        # example with an entity annotation
        if not training_data.entity_examples:
            logger.debug(
                "No training examples with entities present. Skip training"
                "of 'CRFEntityExtractor'.")
            return

        self.check_correct_entity_annotations(training_data)

        if self.component_config[BILOU_FLAG]:
            bilou_utils.apply_bilou_schema(training_data)

        # only keep the CRFs for tags we actually have training data for
        self._update_crf_order(training_data)

        # filter out pre-trained entity examples
        entity_examples = self.filter_trainable_entities(
            training_data.nlu_examples)

        dataset = [
            self._convert_to_crf_tokens(example) for example in entity_examples
        ]

        self._train_model(dataset)
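
For context, the `bilou_utils.apply_bilou_schema` call in this example converts character-span entity annotations into one BILOU tag per token. Below is a minimal, self-contained sketch of that idea (not Rasa's implementation; the `bilou_tags` helper and the simple whitespace tokenization are assumptions for illustration):

from typing import Dict, List, Tuple


def bilou_tags(text: str, entities: List[Dict]) -> List[str]:
    """Assign one BILOU tag per whitespace token.

    B-/I-/L- mark the first, inner and last token of a multi-token entity,
    U- marks a single-token entity and O marks tokens outside any entity.
    """
    # character offsets (start, end) of each whitespace token
    tokens: List[Tuple[int, int]] = []
    offset = 0
    for word in text.split():
        start = text.index(word, offset)
        tokens.append((start, start + len(word)))
        offset = start + len(word)

    tags = ["O"] * len(tokens)
    for entity in entities:
        # tokens fully covered by the entity's character span
        covered = [
            i
            for i, (start, end) in enumerate(tokens)
            if start >= entity["start"] and end <= entity["end"]
        ]
        if len(covered) == 1:
            tags[covered[0]] = f"U-{entity['entity']}"
        elif covered:
            tags[covered[0]] = f"B-{entity['entity']}"
            tags[covered[-1]] = f"L-{entity['entity']}"
            for i in covered[1:-1]:
                tags[i] = f"I-{entity['entity']}"
    return tags


print(bilou_tags(
    "Germany is part of the European Union",
    [
        {"start": 0, "end": 7, "entity": "location"},
        {"start": 23, "end": 37, "entity": "organisation"},
    ],
))
# ['U-location', 'O', 'O', 'O', 'O', 'B-organisation', 'L-organisation']
# (matches the expected tags for the same message in Example #4 below)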
Example #2
    def train(self, training_data: TrainingData) -> Resource:
        """Trains the extractor on a data set."""
        # checks whether there is at least one
        # example with an entity annotation
        if not training_data.entity_examples:
            logger.debug(
                "No training examples with entities present. Skip training"
                "of 'CRFEntityExtractor'.")
            return self._resource

        self.check_correct_entity_annotations(training_data)

        if self.component_config[BILOU_FLAG]:
            bilou_utils.apply_bilou_schema(training_data)

        # only keep the CRFs for tags we actually have training data for
        self._update_crf_order(training_data)

        # filter out pre-trained entity examples
        entity_examples = self.filter_trainable_entities(
            training_data.nlu_examples)
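        # keep only messages that have text features from the configured featurizers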
        entity_examples = [
            message for message in entity_examples if message.features_present(
                attribute=TEXT,
                featurizers=self.component_config.get(FEATURIZERS))
        ]
        dataset = [
            self._convert_to_crf_tokens(example) for example in entity_examples
        ]

        self._train_model(dataset)

        self.persist()

        return self._resource
Example #3
def test_apply_bilou_schema():
    tokenizer = WhitespaceTokenizer()

    message_1 = Message("Germany is part of the European Union")
    message_1.set(
        ENTITIES,
        [
            {"start": 0, "end": 7, "value": "Germany", "entity": "location"},
            {
                "start": 23,
                "end": 37,
                "value": "European Union",
                "entity": "organisation",
            },
        ],
    )

    message_2 = Message("Berlin is the capital of Germany")
    message_2.set(
        ENTITIES,
        [
            {"start": 0, "end": 6, "value": "Berlin", "entity": "location"},
            {"start": 25, "end": 32, "value": "Germany", "entity": "location"},
        ],
    )

    training_data = TrainingData([message_1, message_2])

    tokenizer.train(training_data)

    bilou_utils.apply_bilou_schema(training_data)

    assert message_1.get(BILOU_ENTITIES) == [
        "U-location",
        "O",
        "O",
        "O",
        "O",
        "B-organisation",
        "L-organisation",
        "O",
    ]
    assert message_2.get(BILOU_ENTITIES) == [
        "U-location",
        "O",
        "O",
        "O",
        "O",
        "U-location",
        "O",
    ]
Example #4
def test_apply_bilou_schema(whitespace_tokenizer: WhitespaceTokenizerGraphComponent):

    message_1 = Message.build(
        text="Germany is part of the European Union", intent="inform"
    )
    message_1.set(
        ENTITIES,
        [
            {"start": 0, "end": 7, "value": "Germany", "entity": "location"},
            {
                "start": 23,
                "end": 37,
                "value": "European Union",
                "entity": "organisation",
            },
        ],
    )

    message_2 = Message.build(text="Berlin is the capital of Germany", intent="inform")
    message_2.set(
        ENTITIES,
        [
            {"start": 0, "end": 6, "value": "Berlin", "entity": "location"},
            {"start": 25, "end": 32, "value": "Germany", "entity": "location"},
        ],
    )

    training_data = TrainingData([message_1, message_2])

    whitespace_tokenizer.process_training_data(training_data)

    bilou_utils.apply_bilou_schema(training_data)

    assert message_1.get(BILOU_ENTITIES) == [
        "U-location",
        "O",
        "O",
        "O",
        "O",
        "B-organisation",
        "L-organisation",
    ]
    assert message_2.get(BILOU_ENTITIES) == [
        "U-location",
        "O",
        "O",
        "O",
        "O",
        "U-location",
    ]
Example #5
    def preprocess_train_data(self,
                              training_data: TrainingData) -> RasaModelData:
        """Prepares data for training.

        Performs sanity checks on training data, extracts encodings for labels.
        """

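        # convert entity annotations to the BILOU tagging schema if configured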
        if self.component_config[BILOU_FLAG]:
            bilou_utils.apply_bilou_schema(training_data)

        label_id_index_mapping = self._label_id_index_mapping(training_data,
                                                              attribute=INTENT)

        if not label_id_index_mapping:
            # no labels are present to train
            return RasaModelData()

        self.index_label_id_mapping = self._invert_mapping(
            label_id_index_mapping)

        self._label_data = self._create_label_data(training_data,
                                                   label_id_index_mapping,
                                                   attribute=INTENT)

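        # map entity tag names to indices and keep the inverse mapping for decoding predictions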
        tag_id_index_mapping = self._tag_id_index_mapping(training_data)
        self.index_tag_id_mapping = self._invert_mapping(tag_id_index_mapping)

        label_attribute = (
            INTENT if self.component_config[INTENT_CLASSIFICATION] else None)

        model_data = self._create_model_data(
            training_data.training_examples,
            label_id_index_mapping,
            tag_id_index_mapping,
            label_attribute=label_attribute,
        )

        self.num_tags = len(self.index_tag_id_mapping)

        self._check_input_dimension_consistency(model_data)

        return model_data