Example #1
def test_crf_recogniser_analyse(mock_tokeniser):
    recogniser = CrfRecogniser(
        ["PER", "LOC"],
        ["en"],
        "fake_path",
        tokeniser_setup={
            "name": "fake_tokeniser",
            "config": {
                "fake_param": "fake_value"
            },
        },
    )

    actual = recogniser.analyse("fake_text", entities=["PER"])
    mock_tokeniser.assert_called_with("fake_tokeniser",
                                      {"fake_param": "fake_value"})
    assert actual == [Entity("PER", 8, 11)]

    actual = recogniser.analyse("fake_text", entities=["PER", "LOC"])
    assert actual == [Entity("PER", 8, 11), Entity("LOC", 17, 26)]

    with pytest.raises(AssertionError) as err:
        recogniser.analyse("fake_text", entities=["PER", "LOC", "TIME"])
    assert (str(err.value) ==
            "Only support ['PER', 'LOC'], but got ['PER', 'LOC', 'TIME']")
Example #2
def test_registry_for_comprehend(mock_analyse, mock_session):
    mock_analyse.return_value = [Entity("test", 0, 4)]

    recogniser_name = "ComprehendRecogniser"
    recogniser_params = {
        "supported_entities": [
            "COMMERCIAL_ITEM",
            "DATE",
            "EVENT",
            "LOCATION",
            "ORGANIZATION",
            "OTHER",
            "PERSON",
            "QUANTITY",
            "TITLE",
        ],
        "supported_languages": ["en"],
        "model_name":
        "pii",
    }

    recogniser = recogniser_registry.create_instance(recogniser_name,
                                                     recogniser_params)
    actual = recogniser.analyse("test text", recogniser.supported_entities)

    assert actual == [Entity("test", 0, 4)]
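In Examples #2 and #3, recogniser_registry.create_instance maps a class name plus a parameter dict to a constructed recogniser. A plausible sketch of that lookup, assuming a simple name-to-class registry (hypothetical, inferred only from how the tests call it):

class RecogniserRegistry:
    def __init__(self):
        self.registry = {}

    def add_recogniser(self, recogniser_class):
        # key classes by name, e.g. "ComprehendRecogniser"
        self.registry[recogniser_class.__name__] = recogniser_class

    def create_instance(self, name, params):
        # params becomes the keyword arguments of the recogniser constructor
        return self.registry[name](**params)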
Example #3
def test_registry_for_google_recogniser(mock_analyse, mock_client):
    mock_analyse.return_value = [Entity("test", 0, 4)]

    recogniser_name = "GoogleRecogniser"
    recogniser_params = {
        "supported_entities": [
            "UNKNOWN",
            "PERSON",
            "LOCATION",
            "ORGANIZATION",
            "EVENT",
            "WORK_OF_ART",
            "CONSUMER_GOOD",
            "OTHER",
            "PHONE_NUMBER",
            "ADDRESS",
            "DATE",
            "NUMBER",
            "PRICE",
        ],
        "supported_languages": ["en"],
    }

    recogniser = recogniser_registry.create_instance(recogniser_name,
                                                     recogniser_params)
    actual = recogniser.analyse("test text", recogniser.supported_entities)

    assert actual == [Entity("test", 0, 4)]
Example #4
def test_compute_entity_precisions_for_prediction_no_pred_entities():
    actual = compute_entity_precisions_for_prediction(
        50,
        [Entity("PER", 5, 10), Entity("LOC", 15, 25)],
        [],
        {"PER": 1, "LOC": 2},
    )
    assert actual == []
Example #5
def test_compute_entity_recalls_for_ground_truth_no_true_entities():
    actual = compute_entity_recalls_for_ground_truth(
        50,
        [],
        [Entity("PER", 5, 10), Entity("LOC", 15, 25)],
        {"PER": 1, "LOC": 2},
    )
    assert actual == []
Example #6
def mock_recogniser():
    recogniser = Mock()
    recogniser.analyse.return_value = [
        Entity("PER", 8, 11),
        Entity("LOC", 17, 26),
    ]
    recogniser.supported_entities = ["PER", "LOC"]
    return recogniser
Example #7
def test_label_encoder_for_missing_label_in_mapping():
    spans = [
        Entity(entity_type="LOC", start=5, end=8),
        Entity(entity_type="PER", start=10, end=15),
    ]

    with pytest.raises(Exception) as error:
        label_encoder(20, spans, {"LOC": 1})
    assert str(error.value) == "Missing label 'PER' in 'label_to_int' mapping."
Example #8
def test_calculate_precisions_and_recalls_with_predictions(data):
    data.items[0].pred_labels = [Entity("BIRTHDAY", 0, 10)]
    data.items[1].pred_labels = [
        Entity("ORGANIZATION", 20, 30),
        Entity("LOCATION", 30, 46),
    ]
    grouped_targeted_labels = [{"BIRTHDAY"}, {"ORGANIZATION"}, {"LOCATION"}]

    unwrapped = calculate_precisions_and_recalls(data, grouped_targeted_labels)
    actual = unwrapped["scores"]

    assert len(actual) == 2
    assert actual[0] == TextScore(
        text="It's like that since 12/17/1967",
        precisions=[EntityPrecision(Entity("BIRTHDAY", 0, 10), 0.0)],
        recalls=[EntityRecall(Entity("BIRTHDAY", 21, 31), 0.0)],
    )
    assert actual[1] == TextScore(
        text="The address of Balefire Global is Valadouro 3, Ubide 48145",
        precisions=[
            EntityPrecision(Entity("ORGANIZATION", 20, 30), 1.0),
            EntityPrecision(Entity("LOCATION", 30, 46), 0.75),
        ],
        recalls=[
            EntityRecall(Entity("ORGANIZATION", 15, 30), 2 / 3),
            EntityRecall(Entity("LOCATION", 34, 58), 0.5),
        ],
    )
Example #9
def test_label_encoder_for_multi_labels():
    spans = [
        Entity(entity_type="LOC", start=5, end=8),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="PERSON", start=2, end=5),
    ]

    # labels PER and PERSON map to the same int
    actual = label_encoder(20, spans, {"LOC": 1, "PER": 2, "PERSON": 2})
    assert actual == [
        0, 0, 2, 2, 2, 1, 1, 1, 0, 0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0
    ]
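Examples #7, #9, and #15 together pin down label_encoder's contract: one integer per character (0 outside any span), an Exception for a label missing from the mapping, and a ValueError for a span past the text length. A minimal sketch that satisfies those three tests, assuming character-level encoding (not necessarily the project's implementation):

def label_encoder(text_length, spans, label_to_int):
    encoded = [0] * text_length  # 0 means "outside any entity"
    for span in spans:
        if span.entity_type not in label_to_int:
            raise Exception(
                f"Missing label '{span.entity_type}' in 'label_to_int' mapping.")
        if span.end > text_length:
            raise ValueError(
                f"Entity span index is out of range: text length is "
                f"{text_length} but got span index {span.end}.")
        for i in range(span.start, span.end):
            encoded[i] = label_to_int[span.entity_type]
    return encoded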
Example #10
def test_compute_precisions_recalls_for_type_doesnt_matter():
    # Overlaps but wrong type
    true_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
    ]
    pred_entities = [
        Entity(entity_type="PER", start=3, end=7),
        Entity(entity_type="LOC", start=10, end=15),
    ]
    # type DOES NOT matter
    label_to_int = {"LOC": 1, "PER": 1}

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    assert precisions == [
        EntityPrecision(Entity(entity_type="PER", start=3, end=7), 1.0),
        EntityPrecision(Entity(entity_type="LOC", start=10, end=15), 1.0),
    ]
    assert recalls == [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 1.0),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 1.0),
    ]
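This test passes because the scoring compares encoded integer arrays, not entity types: mapping LOC and PER to the same integer makes them interchangeable. A hedged sketch of the character-overlap precision the expected values imply (compute_entity_recalls_for_ground_truth would mirror it, scoring each true span against the encoded predictions):

def compute_entity_precisions_for_prediction(text_length, true_entities,
                                             pred_entities, label_to_int):
    # precision of a predicted span = fraction of its characters whose
    # encoded ground-truth label matches the prediction's encoded label
    true_encoded = label_encoder(text_length, true_entities, label_to_int)
    precisions = []
    for pred in pred_entities:
        code = label_to_int[pred.entity_type]
        hits = sum(1 for i in range(pred.start, pred.end)
                   if true_encoded[i] == code)
        precisions.append(EntityPrecision(pred, hits / (pred.end - pred.start)))
    return precisions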
Example #11
def test_spacy_recogniser(text):
    recogniser = SpacyRecogniser(["PER", "LOC"], ["en"], model_name="fake_model")

    actual = recogniser.analyse(text, entities=["PER"])
    assert actual == [Entity("PER", 8, 11)]

    actual = recogniser.analyse(text, entities=["PER", "LOC"])
    assert actual == [Entity("PER", 8, 11), Entity("LOC", 17, 26)]

    with pytest.raises(AssertionError) as err:
        recogniser.analyse(text, entities=["PER", "LOC", "TIME"])
    assert (
        str(err.value) == "Only support ['PER', 'LOC'], but got ['PER', 'LOC', 'TIME']"
    )
Example #12
def predictions():
    # this case tests the recall threshold
    entities_1st_text = [
        Entity("LOCATION", 36, 48),
        Entity("OTHER", 75, 91),
    ]

    # this case tests empty precisions
    entities_2nd_text = [
        Entity("TITLE", 84, 87),
        Entity("EVENT", 88, 105),
    ]

    entities_3rd_text = [
        Entity("PERSON", 13, 25),
        Entity("PERSON", 28, 35),
    ]

    # this case tests empty recalls
    entities_4th_text = [
        Entity("PERSON", 11, 17),
    ]

    # this case tests empty precisions and recalls
    entities_5th_text = [
        Entity("ORGANIZATION", 11, 21),
    ]

    return [
        entities_1st_text,
        entities_2nd_text,
        entities_3rd_text,
        entities_4th_text,
        entities_5th_text,
    ]
Example #13
def test_first_letter_uppercase_analyse(mock_tokeniser):
    recogniser = FirstLetterUppercaseRecogniser(
        ["PER"],
        ["en"],
        tokeniser_setup={
            "name": "fake_tokeniser",
            "config": {"fake_param": "fake_value"},
        },
    )
    mock_tokeniser.assert_called_with("fake_tokeniser", {"fake_param": "fake_value"})
    actual = recogniser.analyse("fake_text", entities=["PER"])
    assert actual == [
        Entity("PER", 0, 4),
        Entity("PER", 8, 11),
        Entity("PER", 17, 26),
    ]
Example #14
def mock_bad_recogniser():
    # fails to predict the location entity
    recogniser = Mock()
    recogniser.analyse.return_value = [
        Entity("PER", 8, 11),
    ]
    recogniser.supported_entities = ["PER", "LOC"]
    return recogniser
Example #15
def test_label_encoder_for_span_beyond_range():
    spans = [Entity(entity_type="LOC", start=3, end=7)]

    with pytest.raises(ValueError) as error:
        label_encoder(5, spans, {"LOC": 1})
    assert str(error.value) == (
        "Entity span index is out of range: text length is 5 but got span index 7."
    )
Example #16
def data():
    items = [
        DataItem("It's like that since 12/17/1967",
                 true_labels=[Entity("BIRTHDAY", 21, 31)]),
        DataItem(
            "The address of Balefire Global is Valadouro 3, Ubide 48145",
            true_labels=[
                Entity("ORGANIZATION", 15, 30),
                Entity("LOCATION", 34, 58)
            ],
        ),
    ]

    return Data(
        items,
        supported_entities={"BIRTHDAY", "ORGANIZATION", "LOCATION"},
        is_io_schema=False,
    )
Example #17
def test_compute_precisions_recalls_for_no_true():
    true_entities: List = []
    pred_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
    ]
    # type matters
    label_to_int = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    assert precisions == [
        EntityPrecision(Entity(entity_type="LOC", start=3, end=7), 0.0),
        EntityPrecision(Entity(entity_type="PER", start=10, end=15), 0.0),
    ]
    assert recalls == []
Example #18
def test_flair_analyse(text):
    recogniser = FlairRecogniser(
        supported_entities=["PER", "LOC", "ORG", "MISC"],
        supported_languages=["en"],
        model_name="fake_model",
    )

    actual = recogniser.analyse(text, entities=["PER"])
    assert actual == [Entity("PER", 8, 11)]

    actual = recogniser.analyse(text, entities=["PER", "LOC"])
    assert actual == [Entity("PER", 8, 11), Entity("LOC", 17, 26)]

    with pytest.raises(AssertionError) as err:
        recogniser.analyse(text, entities=["PER", "LOC", "TIME"])
    assert (
        str(err.value)
        == "Only support ['PER', 'LOC', 'ORG', 'MISC'], but got ['PER', 'LOC', 'TIME']"
    )
Example #19
def test_stanza_analyse(text):
    recogniser = StanzaRecogniser(
        supported_entities=["PERSON", "LOC", "ORG"],
        supported_languages=["en"],
        model_name="en",
    )

    actual = recogniser.analyse(text, entities=["PERSON"])
    assert actual == [Entity("PERSON", 8, 11)]

    actual = recogniser.analyse(text, entities=["PERSON", "LOC"])
    assert actual == [Entity("PERSON", 8, 11), Entity("LOC", 17, 26)]

    with pytest.raises(AssertionError) as err:
        recogniser.analyse(text, entities=["PERSON", "LOC", "TIME"])
    assert (
        str(err.value) ==
        "Only support ['PERSON', 'LOC', 'ORG'], but got ['PERSON', 'LOC', 'TIME']"
    )
Example #20
    def analyse(self, text: str, entities: List[str]) -> List[Entity]:
        self.validate_entities(entities)

        results = self.model(text)

        span_labels = []
        for entity in results.entities:
            if entity.type in entities:
                span_labels.append(
                    Entity(entity.type, entity.start_char, entity.end_char))
        return span_labels
Example #21
    def analyse(self, text: str, entities: List[str]) -> List[Entity]:
        self.validate_entities(entities)

        sentence = Sentence(text)
        self.model.predict(sentence)

        span_labels = []
        for entity in sentence.get_spans("ner"):
            if entity.tag in entities:
                span_labels.append(Entity(entity.tag, entity.start_pos, entity.end_pos))

        return span_labels
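Both analyse methods open with self.validate_entities. Given the assertion message checked throughout the tests, a plausible base-class sketch (an assumption, not quoted from the source) is:

    def validate_entities(self, entities: List[str]) -> None:
        # every requested entity must be one the recogniser supports
        assert set(entities).issubset(self.supported_entities), (
            f"Only support {self.supported_entities}, but got {entities}")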
Example #22
def test_identify_pii_entities(mock_registry, data):
    mock_registry.create_instance.return_value.analyse.return_value = [
        Entity("test", 0, 4)
    ]

    actual = identify_pii_entities(
        data,
        "test_recogniser",
        {
            "supported_entities": ["test"],
            "supported_languages": ["test"]
        },
    )

    assert [item.text for item in actual.items] == [
        "It's like that since 12/17/1967",
        "The address of Balefire Global is Valadouro 3, Ubide 48145",
    ]
    assert [item.true_labels for item in actual.items] == [
        [Entity("BIRTHDAY", 21, 31)],
        [Entity("ORGANIZATION", 15, 30),
         Entity("LOCATION", 34, 58)],
    ]
    assert [item.pred_labels for item in actual.items] == [
        [Entity("test", 0, 4)],
        [Entity("test", 0, 4)],
    ]
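The test mocks the registry and the recogniser, so identify_pii_entities only has to create the recogniser and attach predictions to each item, leaving texts and true labels untouched. A sketch consistent with that behaviour (an assumption, not the project's code):

def identify_pii_entities(data, recogniser_name, recogniser_params):
    recogniser = recogniser_registry.create_instance(recogniser_name,
                                                     recogniser_params)
    for item in data.items:
        item.pred_labels = recogniser.analyse(item.text,
                                              recogniser.supported_entities)
    return data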
Example #23
def test_comprehend_recogniser_analyse(text, fake_response):
    # mock API call
    mock_model = MagicMock()
    mock_model.return_value = fake_response

    recogniser = ComprehendRecogniser(
        supported_entities=["LOCATION", "OTHER"],
        supported_languages=["en"],
        model_name="ner",
    )
    recogniser.model_func = mock_model

    spans = recogniser.analyse(text, recogniser.supported_entities)
    assert spans == [Entity("OTHER", 83, 99), Entity("LOCATION", 137, 171)]

    spans = recogniser.analyse(text, ["OTHER"])
    assert spans == [
        Entity("OTHER", 83, 99),
    ]

    spans = recogniser.analyse(text, ["LOCATION"])
    assert spans == [Entity("LOCATION", 137, 171)]
Example #24
def test_get_span_based_prediction(mock_recogniser, mock_tokeniser, text):
    # test 1: succeed
    evaluator = ModelEvaluator(
        recogniser=mock_recogniser,
        tokeniser=mock_tokeniser,
        target_entities=["PER", "LOC"],
    )
    actual = evaluator.get_span_based_prediction(text)
    assert actual == [
        Entity(entity_type="PER", start=8, end=11),
        Entity(entity_type="LOC", start=17, end=26),
    ]

    # test 2: raise assertion error
    evaluator = ModelEvaluator(
        recogniser=mock_recogniser,
        tokeniser=mock_tokeniser,
        target_entities=["PER"],
    )
    with pytest.raises(AssertionError) as err:
        evaluator.get_span_based_prediction(text)
    assert str(err.value) == "Predictions contain unasked entities ['LOC']"
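The second half of the test fixes get_span_based_prediction's contract: predictions outside target_entities trigger an AssertionError naming the offenders. A minimal sketch that satisfies both halves (hypothetical):

    def get_span_based_prediction(self, text: str) -> List[Entity]:
        predictions = self.recogniser.analyse(text, self.target_entities)
        unasked = sorted({p.entity_type for p in predictions}
                         - set(self.target_entities))
        assert not unasked, f"Predictions contain unasked entities {unasked}"
        return predictions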
Example #25
def test_calculate_precisions_and_recalls_with_nontargeted_labels(data):
    grouped_targeted_labels = [{"ORGANIZATION"}, {"LOCATION"}]
    nontargeted_labels = {"BIRTHDAY", "DATE"}

    unwrapped = calculate_precisions_and_recalls(data, grouped_targeted_labels,
                                                 nontargeted_labels)
    actual = unwrapped["scores"]

    assert len(actual) == 2
    assert actual[0] == TextScore(
        text="It's like that since 12/17/1967",
        precisions=[],
        recalls=[],
    )
    assert actual[1] == TextScore(
        text="The address of Balefire Global is Valadouro 3, Ubide 48145",
        precisions=[],
        recalls=[
            EntityRecall(Entity("ORGANIZATION", 15, 30), 0.0),
            EntityRecall(Entity("LOCATION", 34, 58), 0.0),
        ],
    )
Example #26
def test_google_recogniser_for_analyse(mock_client, text, response):
    mock_client.analyze_entities.return_value = response

    recogniser = GoogleRecogniser(
        supported_entities=[
            "UNKNOWN",
            "PERSON",
            "LOCATION",
            "ORGANIZATION",
            "EVENT",
            "WORK_OF_ART",
            "CONSUMER_GOOD",
            "OTHER",
            "PHONE_NUMBER",
            "ADDRESS",
            "DATE",
            "NUMBER",
            "PRICE",
        ],
        supported_languages=["en"],
    )
    actual = recogniser.analyse(text, recogniser.supported_entities)
    assert actual == [Entity("OTHER", 14, 21), Entity("NUMBER", 42, 44)]
Example #27
    def analyse(self, text: str, entities: List[str]) -> List[Entity]:
        self.validate_entities(entities)

        doc = self.model(text)
        spacy_entities = list(doc.ents)

        filtered_entities = list(
            filter(lambda x: x.label_ in entities, spacy_entities))

        return [
            Entity(entity_type=entity.label_,
                   start=entity.start_char,
                   end=entity.end_char) for entity in filtered_entities
        ]
Example #28
    def analyse(self, text: str, entities: List[str]) -> List[Entity]:
        self.validate_entities(entities)

        # TODO: Add multilingual support
        # per the boto3 docs, Comprehend supports
        # 'en'|'es'|'fr'|'de'|'it'|'pt'|'ar'|'hi'|'ja'|'ko'|'zh'|'zh-TW'
        DEFAULT_LANG = "en"

        response = self.model_func(Text=text, LanguageCode=DEFAULT_LANG)

        # parse response
        predicted_entities = response["Entities"]
        # drop entities we are not interested in
        filtered = filter(lambda ent: ent["Type"] in entities, predicted_entities)
        span_labels = map(
            lambda ent: Entity(ent["Type"], ent["BeginOffset"], ent["EndOffset"]),
            filtered,
        )

        return list(span_labels)
Example #29
    def _parse_response(
        self, response: AnalyzeEntitiesResponse, indexer: TextIndexer
    ) -> List[Entity]:
        span_labels = []
        for entity in response.entities:
            entity_type = entity.type_.name
            for mention in entity.mentions:
                # Three types of mention: PROPER, COMMON and TYPE_UNKNOWN. We are not
                # interested in COMMON.
                # https://cloud.google.com/natural-language/docs/basics#entity_analysis
                if mention.type_.name != "COMMON":
                    # Google reports byte offsets
                    byte_begin_offset = mention.text.begin_offset
                    start = indexer.byte_index_to_utf8_index(byte_begin_offset)
                    # mention.text.content is decoded text, so len() counts characters
                    text_length = len(mention.text.content)
                    end = start + text_length
                    span_labels.append(
                        Entity(entity_type=entity_type, start=start, end=end)
                    )

        return span_labels
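Google's API reports begin_offset in bytes of the encoded request, while Entity spans are character indices; the indexer bridges the two. One way such a lookup could be built (a hypothetical TextIndexer; the project's own class may differ):

class TextIndexer:
    def __init__(self, text: str):
        # byte offset of each character when the text is UTF-8 encoded
        self._byte_to_char = {}
        byte_pos = 0
        for char_index, char in enumerate(text):
            self._byte_to_char[byte_pos] = char_index
            byte_pos += len(char.encode("utf-8"))

    def byte_index_to_utf8_index(self, byte_index: int) -> int:
        return self._byte_to_char[byte_index]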
Example #30
def scores():
    scores = []
    scores.append(
        TextScore(
            text="It's like that since 9/23/1993",
            precisions=[EntityPrecision(Entity("BIRTHDAY", 0, 10), 0.0)],
            recalls=[EntityRecall(Entity("BIRTHDAY", 21, 31), 0.0)],
        ))
    scores.append(
        TextScore(
            text="The address of Balefire Global is Valadouro 3, Ubide 48145",
            precisions=[
                EntityPrecision(Entity("ORGANIZATION", 20, 30), 1.0),
                EntityPrecision(Entity("LOCATION", 30, 46), 0.75),
            ],
            recalls=[
                EntityRecall(Entity("ORGANIZATION", 15, 30), 2 / 3),
                EntityRecall(Entity("LOCATION", 34, 58), 0.5),
            ],
        ))

    return scores