def get_slots(self, text):
        """Extract the slots found in ``text``.

        The tokens of ``text`` are tagged by the CRF model and the decoded
        tags are converted into slots. When some slots of the intent are
        mapped to builtin entities, their CRF tags are replaced by outside
        tags and the final slots are computed by ``self._augment_slots``.

        Args:
            text: The input text to extract slots from

        Returns:
            list of dict: The list of extracted slots

        Raises:
            NotTrained: When the slot filler is not fitted
        """
        if not self.slot_name_mapping:
            # The intent declares no slot: nothing to extract
            return []

        tokens = tokenize(text, self.language)
        if not tokens:
            return []

        features = self.compute_features(tokens)
        decoded_tags = list(map(_decode_tag,
                                self.crf_model.predict_single(features)))
        slots = tags_to_slots(text, tokens, decoded_tags,
                              self.config.tagging_scheme,
                              self.slot_name_mapping)

        builtin_slots_names = {
            name for name, entity in iteritems(self.slot_name_mapping)
            if is_builtin_entity(entity)
        }
        if builtin_slots_names:
            # Builtin-entity tags predicted by the CRF become outside tags;
            # _augment_slots then decides how builtin entities are tagged
            outside_tags = _replace_builtin_tags(decoded_tags,
                                                 builtin_slots_names)
            return self._augment_slots(text, tokens, outside_tags,
                                       builtin_slots_names)
        return slots
    def get_slots(self, text):
        """Extract the slots found in ``text``.

        Args:
            text: The input text to extract slots from

        Returns:
            list of dict: The list of extracted slots. It is empty when the
            intent declares no slot or when ``text`` produces no token.

        Raises:
            NotTrained: When the slot filler is not fitted
        """
        if not self.fitted:
            raise NotTrained("CRFSlotFiller must be fitted")
        if not self.slot_name_mapping:
            # The intent declares no slot: there is nothing to extract, so
            # skip tokenization and CRF inference altogether. tags_to_slots
            # resolves every tagged slot through slot_name_mapping, which is
            # empty here.
            return []

        tokens = tokenize(text, self.language)
        if not tokens:
            return []
        features = self.compute_features(tokens)
        tags = [_decode_tag(tag) for tag in
                self.crf_model.predict_single(features)]
        slots = tags_to_slots(text, tokens, tags, self.config.tagging_scheme,
                              self.slot_name_mapping)

        builtin_slots_names = {
            slot_name
            for slot_name, entity in iteritems(self.slot_name_mapping)
            if is_builtin_entity(entity)
        }
        if not builtin_slots_names:
            return slots

        # Replace tags corresponding to builtin entities by outside tags, and
        # let _augment_slots decide how builtin entities are tagged
        tags = _replace_builtin_tags(tags, builtin_slots_names)
        return self._augment_slots(text, tokens, tags, builtin_slots_names)
    def get_slots(self, text):
        """Extract the slots found in ``text``.

        The raw CRF predictions are passed to ``logger.debug`` wrapped in a
        ``DifferedLoggingMessage`` (together with the tokens and features),
        then decoded and converted into slots.

        Args:
            text: The input text to extract slots from

        Returns:
            list of dict: The list of extracted slots

        Raises:
            NotTrained: When the slot filler is not fitted
        """
        if not self.slot_name_mapping:
            # The intent declares no slot: nothing to extract
            return []

        tokens = tokenize(text, self.language)
        if not tokens:
            return []

        features = self.compute_features(tokens)
        predicted_tags = self.crf_model.predict_single(features)
        inference_message = DifferedLoggingMessage(
            self.log_inference_weights, text,
            tokens=tokens, features=features, tags=predicted_tags)
        logger.debug(inference_message)

        slot_tags = [_decode_tag(tag) for tag in predicted_tags]
        return tags_to_slots(text, tokens, slot_tags,
                             self.config.tagging_scheme,
                             self.slot_name_mapping)
    def _augment_slots(self, text, tokens, tags, builtin_slots_names):
        """Tag builtin-entity slots by scoring candidate slot assignments.

        The builtin entities of ``text`` (restricted to the entity kinds mapped
        to ``builtin_slots_names``) are parsed, filtered against the slots
        already present in ``tags`` and disambiguated. Every candidate
        assignment of slot names to the resulting entity groups is written
        into a copy of ``tags`` and scored with the CRF; the most probable
        tagging is converted into slots.

        Args:
            text: The input text
            tokens: The tokens of ``text``, aligned with ``tags``
            tags: The decoded tag sequence to augment
            builtin_slots_names: Names of the slots mapped to builtin entities

        Returns:
            list of dict: The extracted slots, reconciled with the builtin
            entities by ``_reconciliate_builtin_slots``
        """
        tagging_scheme = self.config.tagging_scheme
        scope = set(self.slot_name_mapping[slot]
                    for slot in builtin_slots_names)
        builtin_entities = [
            be for entity_kind in scope
            for be in self.builtin_entity_parser.parse(
                text, scope=[entity_kind], use_cache=True)
        ]
        # We remove builtin entities which conflict with custom slots
        # extracted by the CRF
        builtin_entities = _filter_overlapping_builtins(
            builtin_entities, tokens, tags, tagging_scheme)

        # We resolve conflicts between builtin entities by keeping the longest
        # matches. In case when two builtin entities span the same range, we
        # keep both.
        builtin_entities = _disambiguate_builtin_entities(builtin_entities)

        # We group builtin entities based on their start position.
        # itertools.groupby only merges *consecutive* items, so the entities
        # are sorted by that same key before grouping (sorting the groups
        # afterwards would leave non-adjacent same-start entities in separate
        # groups). The sort is stable: a no-op on already sorted input.
        def match_start(entity):
            return entity[RES_MATCH_RANGE][START]

        grouped_entities = [
            list(entities) for _, entities in groupby(
                sorted(builtin_entities, key=match_start), key=match_start)
        ]

        features = self.compute_features(tokens)
        spans_ranges = [
            entities[0][RES_MATCH_RANGE] for entities in grouped_entities
        ]
        tokens_indexes = _spans_to_tokens_indexes(spans_ranges, tokens)

        # We loop on all possible slots permutations and use the CRF to find
        # the best one in terms of probability
        slots_permutations = _get_slots_permutations(grouped_entities,
                                                     self.slot_name_mapping)
        best_updated_tags = tags
        # Any real score beats -inf, so the first scored permutation is always
        # retained whatever the scale of the scores
        best_permutation_score = float("-inf")
        for slots in slots_permutations:
            updated_tags = copy(tags)
            for slot_index, slot in enumerate(slots):
                indexes = tokens_indexes[slot_index]
                sub_tags_sequence = positive_tagging(
                    tagging_scheme, slot, len(indexes))
                updated_tags[indexes[0]:indexes[-1] + 1] = sub_tags_sequence
            score = self._get_sequence_probability(features, updated_tags)
            if score > best_permutation_score:
                best_updated_tags = updated_tags
                best_permutation_score = score
        slots = tags_to_slots(text, tokens, best_updated_tags,
                              tagging_scheme, self.slot_name_mapping)

        return _reconciliate_builtin_slots(text, slots, builtin_entities)
    def _augment_slots(self, text, tokens, tags, builtin_slots_names):
        """Tag builtin-entity slots by scoring candidate slot assignments.

        The builtin entities of ``text`` (restricted to the entity kinds mapped
        to ``builtin_slots_names``) are extracted with
        ``get_builtin_entities``, filtered against the slots already present
        in ``tags`` and disambiguated. Every candidate assignment of slot
        names to the resulting entity groups is written into a copy of
        ``tags`` and scored with the CRF; the most probable tagging is
        converted into slots.

        Args:
            text: The input text
            tokens: The tokens of ``text``, aligned with ``tags``
            tags: The decoded tag sequence to augment
            builtin_slots_names: Names of the slots mapped to builtin entities

        Returns:
            list of dict: The extracted slots, reconciled with the builtin
            entities by ``_reconciliate_builtin_slots``
        """
        tagging_scheme = self.config.tagging_scheme
        scope = set(self.slot_name_mapping[slot]
                    for slot in builtin_slots_names)
        builtin_entities = [be for entity_kind in scope
                            for be in get_builtin_entities(text, self.language,
                                                           [entity_kind])]
        # We remove builtin entities which conflict with custom slots
        # extracted by the CRF
        builtin_entities = _filter_overlapping_builtins(
            builtin_entities, tokens, tags, tagging_scheme)

        # We resolve conflicts between builtin entities by keeping the longest
        # matches. In case when two builtin entities span the same range, we
        # keep both.
        builtin_entities = _disambiguate_builtin_entities(builtin_entities)

        # We group builtin entities based on their start position.
        # itertools.groupby only merges *consecutive* items, so the entities
        # are sorted by that same key before grouping (sorting the groups
        # afterwards would leave non-adjacent same-start entities in separate
        # groups). The sort is stable: a no-op on already sorted input.
        def match_start(entity):
            return entity[RES_MATCH_RANGE][START]

        grouped_entities = [
            list(entities) for _, entities in groupby(
                sorted(builtin_entities, key=match_start), key=match_start)
        ]

        features = self.compute_features(tokens)
        spans_ranges = [entities[0][RES_MATCH_RANGE]
                        for entities in grouped_entities]
        tokens_indexes = _spans_to_tokens_indexes(spans_ranges, tokens)

        # We loop on all possible slots permutations and use the CRF to find
        # the best one in terms of probability
        slots_permutations = _get_slots_permutations(
            grouped_entities, self.slot_name_mapping)
        best_updated_tags = tags
        # Any real score beats -inf, so the first scored permutation is always
        # retained whatever the scale of the scores
        best_permutation_score = float("-inf")
        for slots in slots_permutations:
            updated_tags = copy(tags)
            for slot_index, slot in enumerate(slots):
                indexes = tokens_indexes[slot_index]
                sub_tags_sequence = positive_tagging(
                    tagging_scheme, slot, len(indexes))
                updated_tags[indexes[0]:indexes[-1] + 1] = sub_tags_sequence
            score = self._get_sequence_probability(features, updated_tags)
            if score > best_permutation_score:
                best_updated_tags = updated_tags
                best_permutation_score = score
        slots = tags_to_slots(text, tokens, best_updated_tags,
                              tagging_scheme, self.slot_name_mapping)

        return _reconciliate_builtin_slots(text, slots, builtin_entities)
Example #6
0
    def _augment_slots(self, text, tokens, tags, builtin_slots_names):
        """Tag builtin-entity slots, one entity kind at a time.

        The builtin entities of ``text`` whose kinds are mapped to
        ``builtin_slots_names`` are extracted and filtered against the slots
        already present in ``tags``. For each entity kind, candidate
        assignments of the related slot names to the matches are written into
        the current tags and scored with the CRF. The most probable tagging
        is kept and used as the starting point for the next entity kind.

        Args:
            text: The input text
            tokens: The tokens of ``text``, aligned with ``tags``
            tags: The decoded tag sequence to augment
            builtin_slots_names: Names of the slots mapped to builtin entities

        Returns:
            list of dict: The extracted slots, reconciled with the builtin
            entities by ``_reconciliate_builtin_slots``
        """
        from collections import OrderedDict

        augmented_tags = tags
        scope = [self.slot_name_mapping[slot] for slot in builtin_slots_names]
        builtin_entities = get_builtin_entities(text, self.language, scope)

        builtin_entities = _filter_overlapping_builtins(
            builtin_entities, tokens, tags, self.config.tagging_scheme)

        # Group matches by entity kind, keeping the order in which kinds first
        # appear. itertools.groupby would only merge *consecutive* matches, so
        # two matches of the same kind separated by another kind would form
        # two groups, each searched independently.
        matches_by_entity = OrderedDict()
        for match in builtin_entities:
            matches_by_entity.setdefault(match[ENTITY_KIND], []).append(match)

        # Computed lazily, at most once, and only if a permutation is scored
        features = None
        for entity, matches in matches_by_entity.items():
            spans_ranges = [match[RES_MATCH_RANGE] for match in matches]
            num_possible_builtins = len(spans_ranges)
            tokens_indexes = _spans_to_tokens_indexes(spans_ranges, tokens)
            # Sorted so that permutations are generated in the same order on
            # every run: iterating a set of strings depends on hash
            # randomization, which would make tie-breaking between equally
            # probable taggings non-deterministic.
            related_slots = sorted(
                set(s for s in builtin_slots_names
                    if self.slot_name_mapping[s] == entity))
            best_updated_tags = augmented_tags
            # Any real score beats -inf, so the first scored permutation is
            # always retained whatever the scale of the scores
            best_permutation_score = float("-inf")

            for slots in _generate_slots_permutations(
                    num_possible_builtins, related_slots,
                    self.config.exhaustive_permutations_threshold):
                updated_tags = copy(augmented_tags)
                for slot_index, slot in enumerate(slots):
                    if slot_index >= len(tokens_indexes):
                        break
                    indexes = tokens_indexes[slot_index]
                    sub_tags_sequence = positive_tagging(
                        self.config.tagging_scheme, slot, len(indexes))
                    updated_tags[indexes[0]:indexes[-1] + 1] = \
                        sub_tags_sequence
                if features is None:
                    features = self.compute_features(tokens)
                score = self._get_sequence_probability(features, updated_tags)
                if score > best_permutation_score:
                    best_updated_tags = updated_tags
                    best_permutation_score = score
            augmented_tags = best_updated_tags
        slots = tags_to_slots(text, tokens, augmented_tags,
                              self.config.tagging_scheme,
                              self.slot_name_mapping)
        return _reconciliate_builtin_slots(text, slots, builtin_entities)
Example #7
0
    def get_slots(self, text):
        """Extract the slots found in ``text``.

        Args:
            text: The input text to extract slots from

        Returns:
            list of dict: The list of extracted slots; empty when the intent
            declares no slot or when ``text`` produces no token

        Raises:
            NotTrained: When the slot filler is not fitted
        """
        tokens = None
        if self.slot_name_mapping:
            tokens = tokenize(text, self.language)
        if not tokens:
            # Either the intent has no slot, or there is nothing to tag
            return []

        crf_output = self.crf_model.predict_single(
            self.compute_features(tokens))
        decoded = [_decode_tag(label) for label in crf_output]
        return tags_to_slots(text, tokens, decoded,
                             self.config.tagging_scheme,
                             self.slot_name_mapping)
Example #8
0
    def test_bilou_tags_to_slots(self):
        # Given
        language = LANGUAGE_EN
        slot_name = "animal"
        intent_slots_mapping = {"animal": "animal"}

        # Shorthands for the BILOU tags of the single "animal" slot
        b_tag = BEGINNING_PREFIX + slot_name
        i_tag = INSIDE_PREFIX + slot_name
        l_tag = LAST_PREFIX + slot_name
        u_tag = UNIT_PREFIX + slot_name

        def animal_slot(start, end, value):
            return unresolved_slot(match_range=(start, end), value=value,
                                   entity=slot_name, slot_name=slot_name)

        # Each case is (text, tags, expected_slots). The last two cases use
        # tag sequences that are not well-formed BILOU (a chunk closed by an
        # INSIDE tag, a sequence opened by a LAST tag).
        test_cases = [
            ("", [], []),
            ("nothing here", [OUTSIDE, OUTSIDE], []),
            ("i am a blue bird",
             [OUTSIDE, OUTSIDE, OUTSIDE, b_tag, l_tag],
             [animal_slot(7, 16, "blue bird")]),
            ("i am a bird",
             [OUTSIDE, OUTSIDE, OUTSIDE, u_tag],
             [animal_slot(7, 11, "bird")]),
            ("bird", [u_tag], [animal_slot(0, 4, "bird")]),
            ("blue bird", [b_tag, l_tag], [animal_slot(0, 9, "blue bird")]),
            ("light blue bird blue bird",
             [b_tag, i_tag, l_tag, b_tag, l_tag],
             [animal_slot(0, 15, "light blue bird"),
              animal_slot(16, 25, "blue bird")]),
            ("bird birdy",
             [u_tag, u_tag],
             [animal_slot(0, 4, "bird"), animal_slot(5, 10, "birdy")]),
            ("light bird bird blue bird",
             [b_tag, i_tag, u_tag, b_tag, i_tag],
             [animal_slot(0, 10, "light bird"),
              animal_slot(11, 15, "bird"),
              animal_slot(16, 25, "blue bird")]),
            ("bird bird bird",
             [l_tag, b_tag, u_tag],
             [animal_slot(0, 4, "bird"),
              animal_slot(5, 9, "bird"),
              animal_slot(10, 14, "bird")]),
        ]

        for text, text_tags, expected_slots in test_cases:
            # When
            slots = tags_to_slots(text, tokenize(text, language), text_tags,
                                  TaggingScheme.BILOU, intent_slots_mapping)
            # Then
            self.assertEqual(slots, expected_slots)
Example #9
0
    def test_bilou_tags_to_slots(self):
        # Given
        language = LANGUAGE_EN
        slot_name = "animal"
        intent_slots_mapping = {"animal": "animal"}

        # One-letter codes standing for the BILOU tags of the "animal" slot
        tag_by_code = {
            "O": OUTSIDE,
            "B": BEGINNING_PREFIX + slot_name,
            "I": INSIDE_PREFIX + slot_name,
            "L": LAST_PREFIX + slot_name,
            "U": UNIT_PREFIX + slot_name,
        }

        def bilou_tags(codes):
            """Turn a space-separated string of tag codes into tags"""
            return [tag_by_code[code] for code in codes.split()]

        def animal_slots(*matches):
            """Build the expected slots from (start, end, value) triples"""
            return [
                unresolved_slot(match_range=(start, end), value=value,
                                entity=slot_name, slot_name=slot_name)
                for start, end, value in matches
            ]

        # (text, tags, expected slots). The last two cases use sequences that
        # are not well-formed BILOU (a chunk closed by an INSIDE tag, a
        # sequence opened by a LAST tag).
        cases = [
            ("", bilou_tags(""), animal_slots()),
            ("nothing here", bilou_tags("O O"), animal_slots()),
            ("i am a blue bird", bilou_tags("O O O B L"),
             animal_slots((7, 16, "blue bird"))),
            ("i am a bird", bilou_tags("O O O U"),
             animal_slots((7, 11, "bird"))),
            ("bird", bilou_tags("U"), animal_slots((0, 4, "bird"))),
            ("blue bird", bilou_tags("B L"),
             animal_slots((0, 9, "blue bird"))),
            ("light blue bird blue bird", bilou_tags("B I L B L"),
             animal_slots((0, 15, "light blue bird"),
                          (16, 25, "blue bird"))),
            ("bird birdy", bilou_tags("U U"),
             animal_slots((0, 4, "bird"), (5, 10, "birdy"))),
            ("light bird bird blue bird", bilou_tags("B I U B I"),
             animal_slots((0, 10, "light bird"), (11, 15, "bird"),
                          (16, 25, "blue bird"))),
            ("bird bird bird", bilou_tags("L B U"),
             animal_slots((0, 4, "bird"), (5, 9, "bird"),
                          (10, 14, "bird"))),
        ]

        for text, text_tags, expected_slots in cases:
            # When
            tokens = tokenize(text, language)
            slots = tags_to_slots(text, tokens, text_tags,
                                  TaggingScheme.BILOU, intent_slots_mapping)
            # Then
            self.assertEqual(slots, expected_slots)