示例#1
0
def test_remove_bilou_prefixes():
    """Prefixed BILOU tags lose their prefix; plain "O" tags pass through."""
    bilou_tags = [
        "U-location",
        "O",
        "O",
        "O",
        "O",
        "B-organisation",
        "L-organisation",
    ]
    expected = ["location"] + ["O"] * 4 + ["organisation"] * 2

    assert bilou_utils.remove_bilou_prefixes(bilou_tags) == expected
示例#2
0
    def _predict_entities(self, predict_out: Optional[Dict[Text, tf.Tensor]],
                          message: Message) -> List[Dict]:
        """Decode predicted entity tag ids into entity dicts for ``message``.

        Reads ``predict_out["e_ids"]``, maps the ids to tag names via
        ``self.index_tag_id_mapping``, strips BILOU prefixes when the BILOU
        flag is enabled, converts the tags into entities over the message
        tokens, and appends them (tagged with the extractor name) to the
        entities already stored on the message.

        Returns an empty list when ``predict_out`` is ``None``.
        """
        if predict_out is None:
            return []

        # Only the first row of ``e_ids`` is decoded (presumably a batch of a
        # single message — TODO confirm against the caller).
        tag_ids = predict_out["e_ids"].numpy()[0]
        tag_names = [self.index_tag_id_mapping[tag_id] for tag_id in tag_ids]

        if self.component_config[BILOU_FLAG]:
            tag_names = bilou_utils.remove_bilou_prefixes(tag_names)

        message_tokens = message.get(TOKENS_NAMES[TEXT], [])
        new_entities = self.add_extractor_name(
            self._convert_tags_to_entities(message.text, message_tokens, tag_names)
        )

        # New entities go after any entities already on the message.
        return message.get(ENTITIES, []) + new_entities
示例#3
0
    def _tag_confidences(
        self, tokens: List[Token], predictions: Dict[Text, List[Dict[Text, float]]]
    ) -> Tuple[Dict[Text, List[Text]], Dict[Text, List[float]]]:
        """Get most likely tag predictions with confidence values for tokens.

        Args:
            tokens: the tokens the predictions refer to.
            predictions: for each tag name, a list holding one ``{tag: score}``
                dict per token; it must have the same length as ``tokens``.

        Returns:
            A ``(tags, confidences)`` pair, both keyed by tag name. ``tags``
            holds the tags chosen by ``self._most_likely_tag`` (BILOU-cleaned
            and prefix-stripped when the BILOU flag is enabled);
            ``confidences`` holds the matching values from the same call.

        Raises:
            ValueError: if any prediction list's length differs from the
                number of tokens. ``ValueError`` subclasses ``Exception``, so
                existing ``except Exception`` handlers still catch it.
        """
        tags: Dict[Text, List[Text]] = {}
        confidences: Dict[Text, List[float]] = {}

        for tag_name, predicted_tags in predictions.items():
            if len(tokens) != len(predicted_tags):
                raise ValueError(
                    "Inconsistency in amount of tokens between crfsuite and message: "
                    f"{len(tokens)} tokens vs {len(predicted_tags)} predictions "
                    f"for '{tag_name}'"
                )

            best_tags, best_confidences = self._most_likely_tag(predicted_tags)

            if self.component_config[BILOU_FLAG]:
                # Make the BILOU sequence consistent before stripping the
                # prefixes.
                best_tags = bilou_utils.ensure_consistent_bilou_tagging(best_tags)
                best_tags = bilou_utils.remove_bilou_prefixes(best_tags)

            confidences[tag_name] = best_confidences
            tags[tag_name] = best_tags

        return tags, confidences