コード例 #1
0
def test_compute_precisions_recalls_for_pred_subset_of_true():
    """Each predicted entity lies strictly inside one same-typed true entity.

    Every prediction is fully covered by a matching ground-truth span, so
    every expected precision is 1.0. Each expected recall is the share of
    the true span covered by its prediction, using half-open spans
    (e.g. pred 4-7 covers 3 of the 4 characters of true 3-7 -> 0.75).
    """
    true_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="LOC", start=23, end=32),
        Entity(entity_type="PER", start=37, end=48),
    ]
    pred_entities = [
        Entity(entity_type="LOC", start=4, end=7),
        Entity(entity_type="PER", start=13, end=15),
        Entity(entity_type="LOC", start=25, end=30),
        Entity(entity_type="PER", start=40, end=46),
    ]
    # type matters: LOC and PER map to distinct integer labels.
    label_to_int = {"LOC": 1, "PER": 2}

    # 50 is presumably the text length; every span here ends before it.
    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    assert precisions == [
        EntityPrecision(Entity(entity_type="LOC", start=4, end=7), 1.0),
        EntityPrecision(Entity(entity_type="PER", start=13, end=15), 1.0),
        EntityPrecision(Entity(entity_type="LOC", start=25, end=30), 1.0),
        EntityPrecision(Entity(entity_type="PER", start=40, end=46), 1.0),
    ]
    assert recalls == [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 0.75),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 0.4),
        EntityRecall(Entity(entity_type="LOC", start=23, end=32), 5 / 9),
        EntityRecall(Entity(entity_type="PER", start=37, end=48), 6 / 11),
    ]
コード例 #2
0
def test_compute_precisions_recalls_for_type_doesnt_matter():
    """Spans match exactly but types are swapped; both types share a label.

    LOC and PER are both mapped to the integer label 1. The expected scores
    below therefore treat each swapped-type prediction as a full match.
    """
    # Same spans as the ground truth, with LOC/PER swapped.
    true_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
    ]
    pred_entities = [
        Entity(entity_type="PER", start=3, end=7),
        Entity(entity_type="LOC", start=10, end=15),
    ]
    # type DOES NOT matter: both names collapse onto the same label.
    label_to_int = {"LOC": 1, "PER": 1}

    expected_precisions = [
        EntityPrecision(Entity(entity_type="PER", start=3, end=7), 1.0),
        EntityPrecision(Entity(entity_type="LOC", start=10, end=15), 1.0),
    ]
    expected_recalls = [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 1.0),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 1.0),
    ]

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(
        50, true_entities, pred_entities, label_to_int)

    assert precisions == expected_precisions
    assert recalls == expected_recalls
コード例 #3
0
def test_compute_precisions_recalls_for_exact_match():
    """Predictions identical to the ground truth score 1.0 everywhere."""
    entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="LOC", start=23, end=32),
        Entity(entity_type="PER", start=37, end=48),
    ]
    # One list object deliberately serves as both truth and prediction.
    true_entities = pred_entities = entities
    # type matters: LOC and PER map to distinct integer labels.
    label_to_int = {"LOC": 1, "PER": 2}

    expected_precisions = [
        EntityPrecision(Entity(entity_type="LOC", start=3, end=7), 1.0),
        EntityPrecision(Entity(entity_type="PER", start=10, end=15), 1.0),
        EntityPrecision(Entity(entity_type="LOC", start=23, end=32), 1.0),
        EntityPrecision(Entity(entity_type="PER", start=37, end=48), 1.0),
    ]
    expected_recalls = [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 1.0),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 1.0),
        EntityRecall(Entity(entity_type="LOC", start=23, end=32), 1.0),
        EntityRecall(Entity(entity_type="PER", start=37, end=48), 1.0),
    ]

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(
        50, true_entities, pred_entities, label_to_int)

    assert precisions == expected_precisions
    assert recalls == expected_recalls
コード例 #4
0
def test_compute_precisions_recalls_for_no_overlap():
    """No predicted entity overlaps any true entity, so every score is 0.0.

    Pred LOC 15-20 only touches true PER 10-15 at offset 15. Spans are
    half-open [start, end), as the expected ratios elsewhere in this file
    imply, so that pair does not overlap (and the types differ anyway).
    """
    true_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="LOC", start=23, end=32),
        Entity(entity_type="PER", start=37, end=48),
    ]
    pred_entities = [
        Entity(entity_type="LOC", start=15, end=20),
        Entity(entity_type="PER", start=33, end=35),
    ]
    # type matters: LOC and PER map to distinct integer labels.
    label_to_int = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    assert precisions == [
        EntityPrecision(Entity(entity_type="LOC", start=15, end=20), 0.0),
        EntityPrecision(Entity(entity_type="PER", start=33, end=35), 0.0),
    ]
    assert recalls == [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 0.0),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 0.0),
        EntityRecall(Entity(entity_type="LOC", start=23, end=32), 0.0),
        EntityRecall(Entity(entity_type="PER", start=37, end=48), 0.0),
    ]
コード例 #5
0
def test_compute_precisions_recalls_for_many_preds_to_one_true():
    """Each true entity overlaps several predictions, not all of its type.

    Only same-typed overlap counts toward the expected scores:
    - true LOC 3-20 overlaps pred LOC 3-7 (4 chars, counted) and
      pred PER 10-15 (type mismatch, not counted) -> recall 4/17;
    - true PER 28-43 overlaps pred LOC 23-32 (type mismatch) and
      pred PER 37-48 (6 chars, counted) -> recall 6/15 = 0.4.
    """
    true_entities = [
        Entity(entity_type="LOC", start=3, end=20),
        Entity(entity_type="PER", start=28, end=43),
    ]
    pred_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="LOC", start=23, end=32),
        Entity(entity_type="PER", start=37, end=48),
    ]
    # type matters: LOC and PER map to distinct integer labels.
    label_to_int = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    assert precisions == [
        EntityPrecision(Entity(entity_type="LOC", start=3, end=7), 1.0),
        EntityPrecision(Entity(entity_type="PER", start=10, end=15), 0.0),
        EntityPrecision(Entity(entity_type="LOC", start=23, end=32), 0.0),
        EntityPrecision(Entity(entity_type="PER", start=37, end=48), 6 / 11),
    ]
    assert recalls == [
        EntityRecall(Entity(entity_type="LOC", start=3, end=20), 4 / 17),
        EntityRecall(Entity(entity_type="PER", start=28, end=43), 0.4),
    ]
コード例 #6
0
def test_compute_precisions_recalls_for_pred_overlap_true():
    """Each prediction partially overlaps exactly one same-typed true entity.

    Expected precision = overlap / predicted length, and expected recall =
    overlap / true length. For example, pred LOC 1-4 and true LOC 3-7 share
    only offset 3, giving precision 1/3 and recall 1/4.
    """
    true_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="LOC", start=23, end=32),
        Entity(entity_type="PER", start=37, end=48),
    ]
    pred_entities = [
        Entity(entity_type="LOC", start=1, end=4),
        Entity(entity_type="PER", start=13, end=18),
        Entity(entity_type="LOC", start=28, end=35),
        Entity(entity_type="PER", start=45, end=49),
    ]
    # type matters: LOC and PER map to distinct integer labels.
    label_to_int = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    assert precisions == [
        EntityPrecision(Entity(entity_type="LOC", start=1, end=4), 1 / 3),
        EntityPrecision(Entity(entity_type="PER", start=13, end=18), 0.4),
        EntityPrecision(Entity(entity_type="LOC", start=28, end=35), 4 / 7),
        EntityPrecision(Entity(entity_type="PER", start=45, end=49), 0.75),
    ]
    assert recalls == [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 0.25),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 0.4),
        EntityRecall(Entity(entity_type="LOC", start=23, end=32), 4 / 9),
        EntityRecall(Entity(entity_type="PER", start=37, end=48), 3 / 11),
    ]
コード例 #7
0
def test_calculate_precisions_and_recalls_with_predictions(data):
    """Score hand-set predictions against the ground truth in `data`.

    `data` is presumably a pytest fixture defined outside this view; its
    first two items have their `pred_labels` overwritten before scoring.
    """
    first_item, second_item = data.items[0], data.items[1]
    # Misplaced BIRTHDAY prediction: it does not overlap the true 21-31 span.
    first_item.pred_labels = [Entity("BIRTHDAY", 0, 10)]
    # ORGANIZATION sits inside the true 15-30 span; LOCATION 30-46 only
    # partially overlaps the true 34-58 span.
    second_item.pred_labels = [
        Entity("ORGANIZATION", 20, 30),
        Entity("LOCATION", 30, 46),
    ]
    label_groups = [{"BIRTHDAY"}, {"ORGANIZATION"}, {"LOCATION"}]

    result = calculate_precisions_and_recalls(data, label_groups)
    text_scores = result["scores"]

    expected_first = TextScore(
        text="It's like that since 12/17/1967",
        precisions=[EntityPrecision(Entity("BIRTHDAY", 0, 10), 0.0)],
        recalls=[EntityRecall(Entity("BIRTHDAY", 21, 31), 0.0)],
    )
    expected_second = TextScore(
        text="The address of Balefire Global is Valadouro 3, Ubide 48145",
        precisions=[
            EntityPrecision(Entity("ORGANIZATION", 20, 30), 1.0),
            EntityPrecision(Entity("LOCATION", 30, 46), 0.75),
        ],
        recalls=[
            EntityRecall(Entity("ORGANIZATION", 15, 30), 2 / 3),
            EntityRecall(Entity("LOCATION", 34, 58), 0.5),
        ],
    )

    assert len(text_scores) == 2
    assert text_scores[0] == expected_first
    assert text_scores[1] == expected_second
コード例 #8
0
def test_compute_entity_recalls_for_ground_truth_no_pred_entities():
    """With an empty prediction list, every true entity has recall 0.0."""
    ground_truth = [Entity("PER", 5, 10), Entity("LOC", 15, 25)]
    no_predictions = []
    label_ids = {"PER": 1, "LOC": 2}

    result = compute_entity_recalls_for_ground_truth(50, ground_truth,
                                                     no_predictions,
                                                     label_ids)

    expected = [
        EntityRecall(Entity("PER", 5, 10), 0.0),
        EntityRecall(Entity("LOC", 15, 25), 0.0),
    ]
    assert result == expected
コード例 #9
0
def test_compute_precisions_recalls_for_no_pred():
    """No predictions yields no precisions and zero recall for each truth."""
    true_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
    ]
    pred_entities: List = []
    # type matters: LOC and PER map to distinct integer labels.
    label_to_int = {"LOC": 1, "PER": 2}

    expected_recalls = [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 0.0),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 0.0),
    ]

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(
        50, true_entities, pred_entities, label_to_int)

    # Precision is per prediction, so an empty prediction list gives [].
    assert precisions == []
    assert recalls == expected_recalls
コード例 #10
0
def scores():
    """Return a fixed list of two precomputed TextScore records.

    Presumably used as a pytest fixture; any decorator lies outside this
    view.
    """
    birthday_score = TextScore(
        text="It's like that since 9/23/1993",
        precisions=[EntityPrecision(Entity("BIRTHDAY", 0, 10), 0.0)],
        # NOTE(review): this text is 30 characters long, yet the span below
        # ends at 31 (possibly copied from the 10-char "12/17/1967" variant).
        # Kept as-is; confirm no consumer slices the text with it.
        recalls=[EntityRecall(Entity("BIRTHDAY", 21, 31), 0.0)],
    )
    address_score = TextScore(
        text="The address of Balefire Global is Valadouro 3, Ubide 48145",
        precisions=[
            EntityPrecision(Entity("ORGANIZATION", 20, 30), 1.0),
            EntityPrecision(Entity("LOCATION", 30, 46), 0.75),
        ],
        recalls=[
            EntityRecall(Entity("ORGANIZATION", 15, 30), 2 / 3),
            EntityRecall(Entity("LOCATION", 34, 58), 0.5),
        ],
    )
    return [birthday_score, address_score]
コード例 #11
0
def test_calculate_precisions_and_recalls_with_nontargeted_labels(data):
    """Entities with non-targeted labels are dropped from the scores.

    `data` is presumably a pytest fixture defined outside this view. No
    predictions are set here, so the targeted ground-truth entities get
    recall 0.0 and there are no precisions.
    """
    targeted_groups = [{"ORGANIZATION"}, {"LOCATION"}]
    ignored_labels = {"BIRTHDAY", "DATE"}

    result = calculate_precisions_and_recalls(data, targeted_groups,
                                              ignored_labels)
    text_scores = result["scores"]

    # The BIRTHDAY entity is non-targeted, so this text keeps nothing.
    expected_first = TextScore(
        text="It's like that since 12/17/1967",
        precisions=[],
        recalls=[],
    )
    expected_second = TextScore(
        text="The address of Balefire Global is Valadouro 3, Ubide 48145",
        precisions=[],
        recalls=[
            EntityRecall(Entity("ORGANIZATION", 15, 30), 0.0),
            EntityRecall(Entity("LOCATION", 34, 58), 0.0),
        ],
    )

    assert len(text_scores) == 2
    assert text_scores[0] == expected_first
    assert text_scores[1] == expected_second
コード例 #12
0
def complex_scores():
    """Return six precomputed TextScore records covering aggregation edge cases.

    Presumably used as a pytest fixture; any decorator lies outside this
    view. The character offsets match the literal texts, including the
    misspelled "addrress" in case 3, so that text must not be corrected
    without also shifting the spans.
    """
    return [
        # 1. Grouping: a DATE prediction against a BIRTHDAY ground truth.
        TextScore(
            text="It's like that since 12/17/1967",
            precisions=[EntityPrecision(Entity("DATE", 0, 10), 0.0)],
            recalls=[EntityRecall(Entity("BIRTHDAY", 21, 31), 0.0)],
        ),
        # 2. Removal of a non-targeted type, e.g. ORGANIZATION.
        TextScore(
            text="The address of Balefire Global is Valadouro 3, Ubide 48145",
            precisions=[
                EntityPrecision(Entity("ORGANIZATION", 20, 30), 1.0),
                EntityPrecision(Entity("LOCATION", 30, 46), 0.75),
            ],
            recalls=[
                EntityRecall(Entity("ORGANIZATION", 15, 30), 2 / 3),
                EntityRecall(Entity("LOCATION", 34, 58), 0.5),
            ],
        ),
        # 3. Multiple occurrences of one type: LOCATION also appears in
        #    cases 4 and 5, which exercises aggregation across texts.
        TextScore(
            text=("Please update billing addrress with Slovenčeva 71, "
                  "Dol pri Ljubljani 1262 for this card: 4539881821557738"),
            precisions=[
                EntityPrecision(Entity("LOCATION", 26, 66), 0.75),
                EntityPrecision(Entity("CREDIT_CARD", 89, 105), 1.0),
            ],
            recalls=[
                EntityRecall(Entity("LOCATION", 36, 73), 30 / 37),
                EntityRecall(Entity("CREDIT_CARD", 89, 105), 1.0),
            ],
        ),
        # 4. Empty precisions.
        TextScore(
            text=("I once lived in Árpád fejedelem útja 89., Bicske 2063. "
                  "I now live in Sarandi 5156, 25 de Agosto 94002"),
            precisions=[],
            recalls=[
                EntityRecall(Entity("LOCATION", 16, 53), 0.0),
                EntityRecall(Entity("LOCATION", 69, 101), 0.0),
            ],
        ),
        # 5. Empty recalls.
        TextScore(
            text="rory is from revelstone",
            precisions=[EntityPrecision(Entity("LOCATION", 13, 23), 0.0)],
            recalls=[],
        ),
        # 6. Empty precisions and empty recalls.
        TextScore(text="How can I request a new credit card pin ?",
                  precisions=[],
                  recalls=[]),
    ]