示例#1
0
def test_determine_token_labels_with_extractors():
    """Label the first Chinese token when the Spacy and Mitie extractors are given.

    Checks both that the call succeeds with these extractors and that the
    resulting label is the expected entity type.
    """
    label = determine_token_labels(
        CH_correct_segmentation[0],
        [CH_correct_entity, CH_wrong_entity],
        [SpacyEntityExtractor.name, MitieEntityExtractor.name],
    )
    # NOTE(review): "direction" mirrors the expectation asserted for the same
    # token and entities in the no-extractor variant of this test; confirm
    # against the CH_wrong_entity / CH_correct_entity fixture definitions.
    assert label == "direction"
示例#2
0
def test_determine_token_labels_no_extractors():
    """Passing ``None`` as extractors for these two entities must raise ValueError."""
    first_token = CH_correct_segmentation[0]
    candidate_entities = [CH_correct_entity, CH_wrong_entity]
    with pytest.raises(ValueError):
        determine_token_labels(first_token, candidate_entities, None)
示例#3
0
def test_determine_token_labels_no_extractors_no_overlap():
    """With no extractors, the first Chinese token gets no entity from EN_targets.

    The returned label is asserted instead of being discarded, so the test
    fails if any entity tag is assigned.
    """
    label = determine_token_labels(CH_correct_segmentation[0], EN_targets, None)
    assert label == NO_ENTITY_TAG
示例#4
0
def test_determine_token_labels_no_extractors_no_overlap():
    """Without extractors, the first Chinese token is tagged NO_ENTITY_TAG against EN_targets."""
    token = CH_correct_segmentation[0]
    actual = determine_token_labels(token, EN_targets, None)
    assert actual == NO_ENTITY_TAG
示例#5
0
def test_determine_token_labels_throws_error():
    """Passing only "CRFEntityExtractor" for these two entities must raise ValueError."""
    token = CH_correct_segmentation[0]
    entities = [CH_correct_entity, CH_wrong_entity]
    extractors = ["CRFEntityExtractor"]
    with pytest.raises(ValueError):
        determine_token_labels(token, entities, extractors)
示例#6
0
    def _create_model_data(
        self,
        training_data: List[Message],
        label_id_dict: Optional[Dict[Text, int]] = None,
        tag_id_dict: Optional[Dict[Text, int]] = None,
        label_attribute: Optional[Text] = None,
    ) -> RasaModelData:
        """Prepare data for training and create a RasaModelData object.

        Args:
            training_data: messages to convert into model features.
            label_id_dict: mapping from label value to integer id; when given,
                one label id is collected per labelled message.
            tag_id_dict: mapping from entity tag to integer id; tag ids are
                only built when ENTITY_RECOGNITION is enabled in the component
                config and this mapping is provided.
            label_attribute: message attribute holding the label. When set,
                messages whose value for it is falsy are skipped for every
                feature list (text, label, tag ids). When ``None``, every
                message contributes text features.

        Returns:
            A ``RasaModelData`` holding text and label features, label ids,
            tag ids, and masks for the text and label features.
        """

        X_sparse = []
        X_dense = []
        Y_sparse = []
        Y_dense = []
        label_ids = []
        tag_ids = []

        for e in training_data:
            if label_attribute is not None and not e.get(label_attribute):
                # Skip unlabelled messages for *all* per-example lists. The tag
                # ids used to be collected for these messages even though their
                # text features were not, which misaligned TAG_IDS rows with
                # TEXT_FEATURES rows.
                continue

            _sparse, _dense = self._extract_features(e, TEXT)
            if _sparse is not None:
                X_sparse.append(_sparse)
            if _dense is not None:
                X_dense.append(_dense)

            if e.get(label_attribute):
                _sparse, _dense = self._extract_features(e, label_attribute)
                if _sparse is not None:
                    Y_sparse.append(_sparse)
                if _dense is not None:
                    Y_dense.append(_dense)

                if label_id_dict:
                    label_ids.append(label_id_dict[e.get(label_attribute)])

            if self.component_config.get(ENTITY_RECOGNITION) and tag_id_dict:
                if self.component_config[BILOU_FLAG]:
                    _tags = bilou_utils.tags_to_ids(e, tag_id_dict)
                else:
                    # one tag id per token of the message text
                    _tags = []
                    for t in e.get(TOKENS_NAMES[TEXT]):
                        _tag = determine_token_labels(t, e.get(ENTITIES), None)
                        _tags.append(tag_id_dict[_tag])
                # transpose to have seq_len x 1
                tag_ids.append(np.array([_tags]).T)

        X_sparse = np.array(X_sparse)
        X_dense = np.array(X_dense)
        Y_sparse = np.array(Y_sparse)
        Y_dense = np.array(Y_dense)
        label_ids = np.array(label_ids)
        tag_ids = np.array(tag_ids)

        model_data = RasaModelData(label_key=self.label_key)
        model_data.add_features(TEXT_FEATURES, [X_sparse, X_dense])
        model_data.add_features(LABEL_FEATURES, [Y_sparse, Y_dense])
        if label_attribute and model_data.feature_not_exist(LABEL_FEATURES):
            # no label features are present, get default features from _label_data
            model_data.add_features(
                LABEL_FEATURES, self._use_default_label_features(label_ids))

        # explicitly add last dimension to label_ids
        # to track correctly dynamic sequences
        model_data.add_features(LABEL_IDS, [np.expand_dims(label_ids, -1)])
        model_data.add_features(TAG_IDS, [tag_ids])

        model_data.add_mask(TEXT_MASK, TEXT_FEATURES)
        model_data.add_mask(LABEL_MASK, LABEL_FEATURES)

        return model_data
示例#7
0
def test_determine_token_labels_no_extractors():
    """With no extractors, the first Chinese token is labelled "direction"."""
    token = CH_correct_segmentation[0]
    entities = [CH_correct_entity, CH_wrong_entity]
    result = determine_token_labels(token, entities, None)
    assert result == "direction"