Code example #1
def test_convert_tags_to_entities(
    entities: List[Dict[Text, Any]],
    keep: bool,
    expected_entities: List[Dict[Text, Any]],
):
    extractor = EntityExtractor()

    updated_entities = extractor.clean_up_entities(entities, keep)

    assert updated_entities == expected_entities
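
The excerpt above omits its pytest parametrization. A minimal sketch of what the missing decorator could look like, assuming Rasa's usual entity dictionaries with "entity", "value", "start" and "end" keys (the concrete values below are illustrative, not the project's actual test data):

import pytest

@pytest.mark.parametrize(
    "entities, keep, expected_entities",
    [
        # nothing to clean up: a single aligned entity passes through unchanged
        (
            [{"entity": "city", "value": "Berlin", "start": 10, "end": 16}],
            True,
            [{"entity": "city", "value": "Berlin", "start": 10, "end": 16}],
        ),
    ],
)
def test_convert_tags_to_entities(entities, keep, expected_entities):
    ...  # body as in the excerpt above
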
Code example #2
File: test_extractor.py, Project: zuiwanting/rasa
def test_check_check_correct_entity_annotations(text: Text, warnings: int):
    reader = MarkdownReader()
    tokenizer = WhitespaceTokenizer()

    training_data = reader.reads(text)
    tokenizer.train(training_data)

    with pytest.warns(UserWarning) as record:
        EntityExtractor.check_correct_entity_annotations(training_data)

    assert len(record) == warnings
    assert all(
        excerpt in record[0].message.args[0]
        for excerpt in ["Misaligned entity annotation in sentence"]
    )
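
For context, here is a hedged sketch of an input that would trigger exactly one warning in the test above, assuming Rasa's Markdown training-data format with [value](entity) annotations; the example data is an assumption, not the project's own parametrization:

import pytest

@pytest.mark.parametrize(
    "text, warnings",
    [
        # the annotated span ends mid-token ("Berli" inside "Berlin"), which is
        # the kind of misaligned entity annotation the warning complains about
        ("## intent:inform\n- I want to fly to [Berli](city)n tomorrow", 1),
    ],
)
def test_check_check_correct_entity_annotations(text, warnings):
    ...  # body as in the excerpt above
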
Code example #3
File: test_extractor.py, Project: zuiwanting/rasa
def test_convert_tags_to_entities(
    text: Text,
    tags: Dict[Text, List[Text]],
    confidences: Dict[Text, List[float]],
    expected_entities: List[Dict[Text, Any]],
):
    extractor = EntityExtractor()
    tokenizer = WhitespaceTokenizer()

    message = Message(text)
    tokens = tokenizer.tokenize(message, TEXT)

    actual_entities = extractor.convert_predictions_into_entities(
        text, tokens, tags, confidences)
    assert actual_entities == expected_entities
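
The shapes of tags and confidences are easy to misread from the type hints alone. A small illustrative set of inputs, assuming one tag and one confidence per whitespace token, keyed by the "entity" attribute (the values and the expected output are assumptions, not the project's test data):

text = "I live in Berlin"
tags = {"entity": ["O", "O", "O", "city"]}       # one tag per token
confidences = {"entity": [1.0, 1.0, 1.0, 0.98]}  # one confidence per token
# a single "city" entity covering the character span of "Berlin" (start 10,
# end 16) would be expected back from convert_predictions_into_entities
expected_entities = [{"entity": "city", "start": 10, "end": 16, "value": "Berlin"}]
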
Code example #4
def test_clean_up_entities(
    text: Text,
    tokens: List[Token],
    entities: List[Dict[Text, Any]],
    keep: bool,
    expected_entities: List[Dict[Text, Any]],
):
    extractor = EntityExtractor()

    message = Message(text)
    message.set("tokens", tokens)

    updated_entities = extractor.clean_up_entities(message, entities, keep)

    assert updated_entities == expected_entities
Code example #5
File: test_extractor.py, Project: attgua/Geco
def test_convert_tags_to_entities(
    text: Text,
    tags: Dict[Text, List[Text]],
    confidences: Dict[Text, List[float]],
    expected_entities: List[Dict[Text, Any]],
):
    extractor = EntityExtractor()
    tokenizer = WhitespaceTokenizer()

    message = Message(data={TEXT: text})
    tokens = tokenizer.tokenize(message, TEXT)

    split_entities_config = {SPLIT_ENTITIES_BY_COMMA: True}
    actual_entities = extractor.convert_predictions_into_entities(
        text, tokens, tags, split_entities_config, confidences)
    assert actual_entities == expected_entities
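
Compared to code example #3, this newer version also passes a split_entities_config. A hedged illustration of what that flag controls, with made-up values rather than the project's test data: when SPLIT_ENTITIES_BY_COMMA is True, a predicted span covering a comma-separated value is expected to come back as separate entities.

text = "I want to visit Berlin, Munich"
tags = {"entity": ["O", "O", "O", "O", "city", "city"]}
confidences = {"entity": [1.0, 1.0, 1.0, 1.0, 0.95, 0.97]}
# with {SPLIT_ENTITIES_BY_COMMA: True}, the comma-separated span should yield
# two "city" entities ("Berlin" and "Munich") rather than one combined span
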
Code example #6
File: _ted_policy.py, Project: praneethgb/rasa
    def _create_optional_event_for_entities(
        self,
        prediction_output: Dict[Text, tf.Tensor],
        is_e2e_prediction: bool,
        interpreter: NaturalLanguageInterpreter,
        tracker: DialogueStateTracker,
    ) -> Optional[List[Event]]:
        if tracker.latest_action_name != ACTION_LISTEN_NAME or not is_e2e_prediction:
            # Entities belong only to the last user message, and only if the
            # user text was used for the prediction; a user message always
            # comes after action_listen.
            return None

        if not self.config[ENTITY_RECOGNITION]:
            # Entity recognition is not turned on, so no entities can be predicted.
            return None

        # The batch dimension of the entity prediction is not the batch size;
        # it is the number of text inputs in the batch (only the last ones if a
        # max history featurizer is used, otherwise all of them). To pick the
        # entities of the latest user message, we therefore take the last batch
        # dimension of the entity prediction.
        predicted_tags, confidence_values = rasa.utils.train_utils.entity_label_to_tags(
            prediction_output,
            self._entity_tag_specs,
            self.config[BILOU_FLAG],
            prediction_index=-1,
        )

        if ENTITY_ATTRIBUTE_TYPE not in predicted_tags:
            # no entities detected
            return None

        # entities belong to the last message of the tracker
        # convert the predicted tags to actual entities
        text = tracker.latest_message.text
        parsed_message = interpreter.featurize_message(
            Message(data={TEXT: text}))
        tokens = parsed_message.get(TOKENS_NAMES[TEXT])
        entities = EntityExtractor.convert_predictions_into_entities(
            text,
            tokens,
            predicted_tags,
            self.split_entities_config,
            confidences=confidence_values,
        )

        # add the extractor name
        for entity in entities:
            entity[EXTRACTOR] = "TEDPolicy"

        return [EntitiesAdded(entities)]