예제 #1
0
def test_compute_metrics_agg_scenario_3():
    """A lone gold PER entity with no predictions is expected to be missed.

    Only the per-tag results for 'PER' are inspected, under each of the
    'strict', 'ent_type', 'partial' and 'exact' entries.
    """
    gold_entities = [{"label": "PER", "start": 59, "end": 69}]
    predicted_entities = []

    # The overall results (first element of the returned pair) are not
    # checked by this test.
    _, per_tag_results = compute_metrics(
        gold_entities, predicted_entities, ['PER']
    )

    # All four entries are expected to carry identical counts: one possible
    # entity, nothing predicted, one missed, and zeroed score fields.
    expected_counts = {
        'correct': 0,
        'incorrect': 0,
        'partial': 0,
        'missed': 1,
        'spurious': 0,
        'actual': 0,
        'possible': 1,
        'precision': 0,
        'recall': 0,
        'f1': 0,
    }

    for entry in ('strict', 'ent_type', 'partial', 'exact'):
        assert per_tag_results['PER'][entry] == expected_counts
예제 #2
0
def test_compute_metrics_case_1():
    """Compare overall results for six gold and six predicted entities.

    The overall results from compute_metrics are passed through
    compute_precision_recall_wrapper and then compared, as a whole, with the
    expected counts and scores.
    """

    def as_entities(spans):
        # Turn (label, start, end) tuples into the dict form handed to
        # compute_metrics.
        return [
            {"label": label, "start": start, "end": end}
            for label, start, end in spans
        ]

    gold_entities = as_entities([
        ("PER", 59, 69),
        ("LOC", 127, 134),
        ("LOC", 164, 174),
        ("LOC", 197, 205),
        ("LOC", 208, 219),
        ("MISC", 230, 240),
    ])

    predicted_entities = as_entities([
        ("PER", 24, 30),
        ("LOC", 124, 134),
        ("PER", 164, 174),
        ("LOC", 197, 205),
        ("LOC", 208, 219),
        ("LOC", 225, 243),
    ])

    overall, _ = compute_metrics(
        gold_entities, predicted_entities, ['PER', 'LOC', 'MISC']
    )
    overall = compute_precision_recall_wrapper(overall)

    def expected_entry(correct, incorrect, partial, score):
        # 'missed', 'spurious', 'possible' and 'actual' are the same in every
        # entry, and precision, recall and f1 share one value per entry.
        return {
            'correct': correct,
            'incorrect': incorrect,
            'partial': partial,
            'missed': 1,
            'spurious': 1,
            'possible': 6,
            'actual': 6,
            'precision': score,
            'recall': score,
            'f1': score,
        }

    expected = {
        'strict': expected_entry(2, 3, 0, 0.3333333333333333),
        'ent_type': expected_entry(3, 2, 0, 0.5),
        'partial': expected_entry(3, 0, 2, 0.6666666666666666),
        'exact': expected_entry(3, 2, 0, 0.5),
    }

    assert overall == expected
예제 #3
0
def test_compute_metrics_no_predictions():
    """With an empty prediction list, all three gold entities are expected
    to be counted as missed.

    The overall results are checked under each of the 'strict', 'ent_type',
    'partial' and 'exact' entries.
    """
    gold_entities = [
        {"label": label, "start": start, "end": end}
        for label, start, end in (
            ("PER", 50, 52),
            ("ORG", 59, 69),
            ("MISC", 71, 72),
        )
    ]
    predicted_entities = []

    # The per-tag results (second element of the returned pair) are not
    # checked by this test.
    overall, _ = compute_metrics(
        gold_entities, predicted_entities, ['PER', 'ORG', 'MISC']
    )

    # Every entry is expected to hold the same counts.
    expected_counts = {
        'correct': 0,
        'incorrect': 0,
        'partial': 0,
        'missed': 3,
        'spurious': 0,
        'actual': 0,
        'possible': 3,
        'precision': 0,
        'recall': 0,
        'f1': 0,
    }

    for entry in ('strict', 'ent_type', 'partial', 'exact'):
        assert overall[entry] == expected_counts
예제 #4
0
def test_compute_metrics_extra_tags_in_true():
    """Gold data includes a MISC entity while the tag list passed to
    compute_metrics is ['PER', 'LOC', 'ORG'].

    The overall results are checked under each of the 'strict', 'ent_type',
    'partial' and 'exact' entries.
    """
    gold_entities = [
        {"label": "PER", "start": 50, "end": 52},
        {"label": "ORG", "start": 59, "end": 69},
        {"label": "MISC", "start": 71, "end": 72},
    ]

    predicted_entities = [
        {"label": "LOC", "start": 50, "end": 52},  # Wrong type
        {"label": "ORG", "start": 59, "end": 69},  # Correct
        {"label": "ORG", "start": 71, "end": 72},  # Spurious
    ]

    # The per-tag results (second element of the returned pair) are not
    # checked by this test.
    overall, _ = compute_metrics(
        gold_entities, predicted_entities, ['PER', 'LOC', 'ORG']
    )

    # Counts expected under 'strict' and 'ent_type': the wrong-type
    # prediction is recorded as incorrect.
    type_sensitive = {
        'correct': 1,
        'incorrect': 1,
        'partial': 0,
        'missed': 0,
        'spurious': 1,
        'actual': 3,
        'possible': 2,
        'precision': 0,
        'recall': 0,
        'f1': 0,
    }

    # Counts expected under 'partial' and 'exact': both predictions whose
    # spans coincide with a counted gold span are recorded as correct.
    span_only = {
        'correct': 2,
        'incorrect': 0,
        'partial': 0,
        'missed': 0,
        'spurious': 1,
        'actual': 3,
        'possible': 2,
        'precision': 0,
        'recall': 0,
        'f1': 0,
    }

    expected_by_entry = (
        ('strict', type_sensitive),
        ('ent_type', type_sensitive),
        ('partial', span_only),
        ('exact', span_only),
    )

    for entry, expected_counts in expected_by_entry:
        assert overall[entry] == expected_counts
예제 #5
0
def test_compute_metrics_no_predictions():
    """Three gold entities and no predictions: each is expected to be
    reported as missed.

    The overall results are compared entry by entry for "strict",
    "ent_type", "partial" and "exact".
    """
    entry_names = ("strict", "ent_type", "partial", "exact")

    gold_entities = [
        {"label": "PER", "start": 50, "end": 52},
        {"label": "ORG", "start": 59, "end": 69},
        {"label": "MISC", "start": 71, "end": 72},
    ]
    predicted_entities = []

    # Only the overall results are inspected; the per-tag results are
    # discarded.
    overall, _ = compute_metrics(
        gold_entities, predicted_entities, ["PER", "ORG", "MISC"]
    )

    # Each entry gets its own copy of the same expected counts.
    expected = {
        name: {
            "correct": 0,
            "incorrect": 0,
            "partial": 0,
            "missed": 3,
            "spurious": 0,
            "actual": 0,
            "possible": 3,
            "precision": 0,
            "recall": 0,
            "f1": 0,
        }
        for name in entry_names
    }

    for name in entry_names:
        assert overall[name] == expected[name]
예제 #6
0
def test_compute_metrics_agg_scenario_6():
    """One gold PER entity (59-69) against one predicted LOC entity (54-69).

    Checks the per-tag results: the "PER" entries one by one, and the "LOC"
    results as a whole, which are expected to be all zeros.
    """
    entry_names = ("strict", "ent_type", "partial", "exact")

    gold_entities = [{"label": "PER", "start": 59, "end": 69}]
    predicted_entities = [{"label": "LOC", "start": 54, "end": 69}]

    # The overall results (first element of the returned pair) are not
    # checked by this test.
    _, per_tag_results = compute_metrics(
        gold_entities, predicted_entities, ["PER", "LOC"]
    )

    def counts(incorrect=0, partial=0, actual=0, possible=0):
        # Build one expected entry; the fields not exposed as parameters are
        # zero in every expectation of this test.
        return {
            "correct": 0,
            "incorrect": incorrect,
            "partial": partial,
            "missed": 0,
            "spurious": 0,
            "actual": actual,
            "possible": possible,
            "precision": 0,
            "recall": 0,
            "f1": 0,
        }

    # Under "partial" the prediction is expected to count as a partial
    # match; under the other three entries it is expected to be incorrect.
    expected_per = {
        "strict": counts(incorrect=1, actual=1, possible=1),
        "ent_type": counts(incorrect=1, actual=1, possible=1),
        "partial": counts(partial=1, actual=1, possible=1),
        "exact": counts(incorrect=1, actual=1, possible=1),
    }
    expected_loc = {name: counts() for name in entry_names}

    for name in entry_names:
        assert per_tag_results["PER"][name] == expected_per[name]

    assert per_tag_results["LOC"] == expected_loc
예제 #7
0
def test_compute_metrics_case_1():
    """Check overall counts and scores for six gold and six predicted
    entities.

    The overall results of compute_metrics are run through
    compute_precision_recall_wrapper before being compared with the expected
    mapping in a single equality assertion.
    """
    entity_fields = ("label", "start", "end")

    gold_rows = (
        ("PER", 59, 69),
        ("LOC", 127, 134),
        ("LOC", 164, 174),
        ("LOC", 197, 205),
        ("LOC", 208, 219),
        ("MISC", 230, 240),
    )
    predicted_rows = (
        ("PER", 24, 30),
        ("LOC", 124, 134),
        ("PER", 164, 174),
        ("LOC", 197, 205),
        ("LOC", 208, 219),
        ("LOC", 225, 243),
    )

    gold_entities = [dict(zip(entity_fields, row)) for row in gold_rows]
    predicted_entities = [
        dict(zip(entity_fields, row)) for row in predicted_rows
    ]

    # Only the overall results are used; the per-tag results are discarded.
    overall, _ = compute_metrics(
        gold_entities, predicted_entities, ["PER", "LOC", "MISC"]
    )
    overall = compute_precision_recall_wrapper(overall)

    # Fields whose expected value is the same in every entry.
    shared_counts = {"missed": 1, "spurious": 1, "possible": 6, "actual": 6}

    # (entry name, correct, incorrect, partial, precision/recall/f1 value)
    expected_rows = (
        ("strict", 2, 3, 0, 0.3333333333333333),
        ("ent_type", 3, 2, 0, 0.5),
        ("partial", 3, 0, 2, 0.6666666666666666),
        ("exact", 3, 2, 0, 0.5),
    )

    expected = {}
    for name, correct, incorrect, partial, score in expected_rows:
        entry = {
            "correct": correct,
            "incorrect": incorrect,
            "partial": partial,
        }
        entry.update(shared_counts)
        entry.update({"precision": score, "recall": score, "f1": score})
        expected[name] = entry

    assert overall == expected