def calculate_precisions_and_recalls(
    data: Data,
    grouped_targeted_labels: List[Set[str]],
    nontargeted_labels: Optional[Set[str]] = None,
) -> Dict[str, List[TextScore]]:
    """Build one ``TextScore`` per item in ``data.items``.

    Args:
        data: Container whose ``items`` each expose ``text``,
            ``true_labels`` and ``pred_labels``.
        grouped_targeted_labels: Forwarded to ``build_label_mapping``.
        nontargeted_labels: Forwarded to ``build_label_mapping``.

    Returns:
        A dict with the single key ``"scores"`` holding the ``TextScore``
        objects in the same order as ``data.items``.
    """
    mapping = build_label_mapping(grouped_targeted_labels, nontargeted_labels)

    def score_item(item):
        # pred_labels could be None (or empty); score against an empty list.
        predicted = item.pred_labels or []
        text_length = len(item.text)
        precisions = compute_entity_precisions_for_prediction(
            text_length, item.true_labels, predicted, mapping
        )
        recalls = compute_entity_recalls_for_ground_truth(
            text_length, item.true_labels, predicted, mapping
        )
        return TextScore(text=item.text, precisions=precisions, recalls=recalls)

    return {"scores": [score_item(item) for item in data.items]}
Example #2
0
def test_compute_precisions_recalls_for_exact_match():
    """Identical true and predicted entities must all score 1.0."""
    spans = [("LOC", 3, 7), ("PER", 10, 15), ("LOC", 23, 32), ("PER", 37, 48)]

    def make_entities():
        # Fresh objects on every call, so the expected values never share
        # identity with the inputs handed to the functions under test.
        return [
            Entity(entity_type=label, start=begin, end=stop)
            for label, begin, stop in spans
        ]

    # The very same list serves as both ground truth and prediction.
    entities = make_entities()
    # type matters
    label_ids = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, entities, entities, label_ids)
    recalls = compute_entity_recalls_for_ground_truth(
        50, entities, entities, label_ids)

    assert precisions == [
        EntityPrecision(entity, 1.0) for entity in make_entities()
    ]
    assert recalls == [
        EntityRecall(entity, 1.0) for entity in make_entities()
    ]
Example #3
0
def test_compute_precisions_recalls_for_pred_subset_of_true():
    """Predictions nested inside true entities: precision 1.0, partial recall."""
    # Every predicted entity is a subset of one true entity
    true_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="LOC", start=23, end=32),
        Entity(entity_type="PER", start=37, end=48),
    ]
    pred_entities = [
        Entity(entity_type="LOC", start=4, end=7),
        Entity(entity_type="PER", start=13, end=15),
        Entity(entity_type="LOC", start=25, end=30),
        Entity(entity_type="PER", start=40, end=46),
    ]
    # type matters
    label_to_int = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    # Each prediction lies entirely within a same-type true entity.
    assert precisions == [
        EntityPrecision(Entity(entity_type="LOC", start=4, end=7), 1.0),
        EntityPrecision(Entity(entity_type="PER", start=13, end=15), 1.0),
        EntityPrecision(Entity(entity_type="LOC", start=25, end=30), 1.0),
        EntityPrecision(Entity(entity_type="PER", start=40, end=46), 1.0),
    ]
    # Expected recall equals covered length / true entity length:
    # 3/4, 2/5, 5/9 and 6/11 respectively.
    assert recalls == [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 0.75),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 0.4),
        EntityRecall(Entity(entity_type="LOC", start=23, end=32), 5 / 9),
        EntityRecall(Entity(entity_type="PER", start=37, end=48), 6 / 11),
    ]
Example #4
0
def test_compute_entity_precisions_for_prediction_no_pred_entities():
    """An empty prediction list yields an empty list of precisions."""
    ground_truth = [Entity("PER", 5, 10), Entity("LOC", 15, 25)]
    label_ids = {"PER": 1, "LOC": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, ground_truth, [], label_ids)

    assert precisions == []
Example #5
0
def test_compute_precisions_recalls_for_type_doesnt_matter():
    """Labels mapped to the same id are interchangeable when scoring."""
    # Overlaps but wrong type: identical spans, LOC/PER labels swapped.
    spans = [(3, 7), (10, 15)]
    true_types = ["LOC", "PER"]
    pred_types = ["PER", "LOC"]

    def build(types):
        # New Entity objects on each call, pairing labels with the spans.
        return [
            Entity(entity_type=label, start=begin, end=stop)
            for label, (begin, stop) in zip(types, spans)
        ]

    gold = build(true_types)
    predicted = build(pred_types)
    # type DOES NOT matter
    label_ids = {"LOC": 1, "PER": 1}

    precisions = compute_entity_precisions_for_prediction(
        50, gold, predicted, label_ids)
    recalls = compute_entity_recalls_for_ground_truth(
        50, gold, predicted, label_ids)

    assert precisions == [
        EntityPrecision(entity, 1.0) for entity in build(pred_types)
    ]
    assert recalls == [
        EntityRecall(entity, 1.0) for entity in build(true_types)
    ]
Example #6
0
def test_compute_precisions_recalls_for_many_preds_to_one_true():
    """One true entity overlapped by several predictions of mixed types."""
    # Every true entity overlaps more than one predicted entity,
    # but its type may not match all of the overlapping predicted entities
    true_entities = [
        Entity(entity_type="LOC", start=3, end=20),
        Entity(entity_type="PER", start=28, end=43),
    ]
    pred_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="LOC", start=23, end=32),
        Entity(entity_type="PER", start=37, end=48),
    ]
    # type matters
    label_to_int = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    # Predictions whose type differs from the overlapped true entity are
    # expected to score 0.0; PER 37-48 shares 6 of its 11 positions with
    # the true PER 28-43.
    assert precisions == [
        EntityPrecision(Entity(entity_type="LOC", start=3, end=7), 1.0),
        EntityPrecision(Entity(entity_type="PER", start=10, end=15), 0.0),
        EntityPrecision(Entity(entity_type="LOC", start=23, end=32), 0.0),
        EntityPrecision(Entity(entity_type="PER", start=37, end=48), 6 / 11),
    ]
    # Only same-type overlap counts: 4 of 17 positions for LOC 3-20 and
    # 6 of 15 (= 0.4) for PER 28-43.
    assert recalls == [
        EntityRecall(Entity(entity_type="LOC", start=3, end=20), 4 / 17),
        EntityRecall(Entity(entity_type="PER", start=28, end=43), 0.4),
    ]
Example #7
0
def test_compute_precisions_recalls_for_no_overlap():
    """Disjoint true and predicted entities score 0.0 on both sides."""
    # No predicted entity overlaps with any true entity
    true_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="LOC", start=23, end=32),
        Entity(entity_type="PER", start=37, end=48),
    ]
    # The first prediction starts at 15, exactly where the true PER 10-15
    # ends; the expectations below treat these as non-overlapping.
    pred_entities = [
        Entity(entity_type="LOC", start=15, end=20),
        Entity(entity_type="PER", start=33, end=35),
    ]
    # type matters
    label_to_int = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    assert precisions == [
        EntityPrecision(Entity(entity_type="LOC", start=15, end=20), 0.0),
        EntityPrecision(Entity(entity_type="PER", start=33, end=35), 0.0),
    ]
    assert recalls == [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 0.0),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 0.0),
        EntityRecall(Entity(entity_type="LOC", start=23, end=32), 0.0),
        EntityRecall(Entity(entity_type="PER", start=37, end=48), 0.0),
    ]
Example #8
0
def test_compute_precisions_recalls_for_pred_overlap_true():
    """Partially overlapping same-type entities give fractional scores."""
    # Every predicted entity overlaps with one true entity
    true_entities = [
        Entity(entity_type="LOC", start=3, end=7),
        Entity(entity_type="PER", start=10, end=15),
        Entity(entity_type="LOC", start=23, end=32),
        Entity(entity_type="PER", start=37, end=48),
    ]
    pred_entities = [
        Entity(entity_type="LOC", start=1, end=4),
        Entity(entity_type="PER", start=13, end=18),
        Entity(entity_type="LOC", start=28, end=35),
        Entity(entity_type="PER", start=45, end=49),
    ]
    # type matters
    label_to_int = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, true_entities, pred_entities, label_to_int)
    recalls = compute_entity_recalls_for_ground_truth(50, true_entities,
                                                      pred_entities,
                                                      label_to_int)
    # Expected precision equals overlap length / predicted entity length:
    # 1/3, 2/5, 4/7 and 3/4 respectively.
    assert precisions == [
        EntityPrecision(Entity(entity_type="LOC", start=1, end=4), 1 / 3),
        EntityPrecision(Entity(entity_type="PER", start=13, end=18), 0.4),
        EntityPrecision(Entity(entity_type="LOC", start=28, end=35), 4 / 7),
        EntityPrecision(Entity(entity_type="PER", start=45, end=49), 0.75),
    ]
    # Expected recall equals overlap length / true entity length:
    # 1/4, 2/5, 4/9 and 3/11 respectively.
    assert recalls == [
        EntityRecall(Entity(entity_type="LOC", start=3, end=7), 0.25),
        EntityRecall(Entity(entity_type="PER", start=10, end=15), 0.4),
        EntityRecall(Entity(entity_type="LOC", start=23, end=32), 4 / 9),
        EntityRecall(Entity(entity_type="PER", start=37, end=48), 3 / 11),
    ]
Example #9
0
def test_compute_precisions_recalls_for_no_true():
    """With no ground truth, predictions score 0.0 and recalls are empty."""
    ground_truth: List = []
    spans = [("LOC", 3, 7), ("PER", 10, 15)]
    predicted = [
        Entity(entity_type=label, start=begin, end=stop)
        for label, begin, stop in spans
    ]
    # type matters
    label_ids = {"LOC": 1, "PER": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, ground_truth, predicted, label_ids)
    recalls = compute_entity_recalls_for_ground_truth(
        50, ground_truth, predicted, label_ids)

    # Expectations use freshly constructed entities, not the input objects.
    expected_precisions = [
        EntityPrecision(Entity(entity_type=label, start=begin, end=stop), 0.0)
        for label, begin, stop in spans
    ]
    assert precisions == expected_precisions
    assert recalls == []
Example #10
0
def test_compute_entity_precisions_for_prediction_no_true_no_pred_entities():
    """No true and no predicted entities yields an empty precision list."""
    label_ids = {"PER": 1, "LOC": 2}

    precisions = compute_entity_precisions_for_prediction(
        50, [], [], label_ids)

    assert precisions == []