def test_compute_metrics_agg_scenario_3():
    """A lone true PER entity with zero predictions is counted as missed
    under every evaluation schema in the aggregated (per-type) results."""
    true_named_entities = [Entity('PER', 59, 69)]
    pred_named_entities = []

    results, results_agg = compute_metrics(true_named_entities, pred_named_entities)

    # Every schema sees the same outcome: one possible entity, fully missed,
    # nothing predicted (actual == 0).
    fully_missed = {
        'correct': 0,
        'incorrect': 0,
        'partial': 0,
        'missed': 1,
        'spurious': 0,
        'actual': 0,
        'possible': 1,
    }
    expected_agg = {
        'PER': {
            schema: dict(fully_missed)
            for schema in ('strict', 'ent_type', 'partial', 'exact')
        }
    }

    assert results_agg['PER']['strict'] == expected_agg['PER']['strict']
    assert results_agg['PER']['ent_type'] == expected_agg['PER']['ent_type']
    assert results_agg['PER']['partial'] == expected_agg['PER']['partial']
    assert results_agg['PER']['exact'] == expected_agg['PER']['exact']
} # overall results results = { 'strict': deepcopy(metrics_results), 'ent_type': deepcopy(metrics_results), 'partial': deepcopy(metrics_results), 'exact': deepcopy(metrics_results) } # results aggregated by entity type evaluation_agg_entities_type = {e: deepcopy(results) for e in ['Symptom']} for true_ents, pred_ents in zip(y_test_, y_pred_): tmp_results, tmp_agg_results = compute_metrics( collect_named_entities(true_ents), collect_named_entities(pred_ents), ['Symptom']) for eval_schema in results.keys(): for metric in metrics_results.keys(): results[eval_schema][metric] += tmp_results[eval_schema][metric] # Calculate global precision and recall results = compute_precision_recall_wrapper(results) # aggregate results by entity type for e_type in ['Symptom']: for eval_schema in tmp_agg_results[e_type]:
def test_compute_metrics_no_predictions():
    """Three true entities and an empty prediction list: all three are
    missed under every schema, and precision/recall stay at zero."""
    true_named_entities = [
        Entity('PER', 50, 52),
        Entity('ORG', 59, 69),
        Entity('MISC', 71, 72),
    ]
    pred_named_entities = []

    results, results_agg = compute_metrics(
        true_named_entities, pred_named_entities, ['PER', 'ORG', 'MISC'])

    # Identical counts for every schema: nothing predicted, everything missed.
    all_missed = {
        'correct': 0,
        'incorrect': 0,
        'partial': 0,
        'missed': 3,
        'spurious': 0,
        'actual': 0,
        'possible': 3,
        'precision': 0,
        'recall': 0,
    }
    expected = {
        schema: dict(all_missed)
        for schema in ('strict', 'ent_type', 'partial', 'exact')
    }

    assert results['strict'] == expected['strict']
    assert results['ent_type'] == expected['ent_type']
    assert results['partial'] == expected['partial']
    assert results['exact'] == expected['exact']
def test_compute_metrics_extra_tags_in_true():
    """True entities include a type ('MISC') that is absent from the tag
    filter; predictions mix a wrong-type match, a correct match, and a
    spurious entity."""
    true_named_entities = [
        Entity('PER', 50, 52),
        Entity('ORG', 59, 69),
        Entity('MISC', 71, 72),
    ]
    pred_named_entities = [
        Entity('LOC', 50, 52),  # right span, wrong type
        Entity('ORG', 59, 69),  # exact match
        Entity('ORG', 71, 72),  # spurious (MISC filtered out of true set)
    ]

    results, results_agg = compute_metrics(
        true_named_entities, pred_named_entities, ['PER', 'LOC', 'ORG'])

    # strict / ent_type both penalise the type mismatch as incorrect.
    type_sensitive = {
        'correct': 1,
        'incorrect': 1,
        'partial': 0,
        'missed': 0,
        'spurious': 1,
        'actual': 3,
        'possible': 2,
        'precision': 0,
        'recall': 0,
    }
    # partial / exact only look at boundaries, so both span matches count.
    boundary_only = {
        'correct': 2,
        'incorrect': 0,
        'partial': 0,
        'missed': 0,
        'spurious': 1,
        'actual': 3,
        'possible': 2,
        'precision': 0,
        'recall': 0,
    }
    expected = {
        'strict': dict(type_sensitive),
        'ent_type': dict(type_sensitive),
        'partial': dict(boundary_only),
        'exact': dict(boundary_only),
    }

    assert results['strict'] == expected['strict']
    assert results['ent_type'] == expected['ent_type']
    assert results['partial'] == expected['partial']
    assert results['exact'] == expected['exact']
def test_compute_metrics_case_1():
    """End-to-end check of compute_metrics followed by the precision/recall
    wrapper on a six-entity scenario mixing correct, wrong-type,
    wrong-boundary, missed, and spurious predictions."""
    true_named_entities = [
        Entity('PER', 59, 69),
        Entity('LOC', 127, 134),
        Entity('LOC', 164, 174),
        Entity('LOC', 197, 205),
        Entity('LOC', 208, 219),
        Entity('MISC', 230, 240),
    ]
    pred_named_entities = [
        Entity('PER', 24, 30),
        Entity('LOC', 124, 134),
        Entity('PER', 164, 174),
        Entity('LOC', 197, 205),
        Entity('LOC', 208, 219),
        Entity('LOC', 225, 243),
    ]

    results, results_agg = compute_metrics(
        true_named_entities, pred_named_entities, ['PER', 'LOC', 'MISC'])
    results = compute_precision_recall_wrapper(results)

    def schema_result(correct, incorrect, partial, precision, recall):
        # missed/spurious/possible/actual are identical across all schemas.
        return {
            'correct': correct,
            'incorrect': incorrect,
            'partial': partial,
            'missed': 1,
            'spurious': 1,
            'possible': 6,
            'actual': 6,
            'precision': precision,
            'recall': recall,
        }

    expected = {
        'strict': schema_result(2, 3, 0, 0.3333333333333333, 0.3333333333333333),
        'ent_type': schema_result(3, 2, 0, 0.5, 0.5),
        'partial': schema_result(3, 0, 2, 0.6666666666666666, 0.6666666666666666),
        'exact': schema_result(3, 2, 0, 0.5, 0.5),
    }

    assert results == expected