def test_compare_frames(self) -> None:
    for i, example in enumerate(TEST_EXAMPLES):
        # Bracket-based comparison: frames are matched on their flat
        # bracket representation.
        self.assertEqual(
            compare_frames(
                example["predicted"], example["expected"], tree_based=False
            ),
            IntentSlotConfusions(
                intent_confusions=Confusions(
                    **example["bracket_confusions"]["intent_confusion"]
                ),
                slot_confusions=Confusions(
                    **example["bracket_confusions"]["slot_confusion"]
                ),
            ),
            f"bracket confusions mismatch for example {i}",
        )
        # Tree-based comparison: nodes must also agree on their position
        # within the frame tree.
        self.assertEqual(
            compare_frames(
                example["predicted"], example["expected"], tree_based=True
            ),
            IntentSlotConfusions(
                intent_confusions=Confusions(
                    **example["tree_confusions"]["intent_confusion"]
                ),
                slot_confusions=Confusions(
                    **example["tree_confusions"]["slot_confusion"]
                ),
            ),
            f"tree confusions mismatch for example {i}",
        )
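# For context, each TEST_EXAMPLES entry pairs a predicted frame with an
# expected one, plus the confusion counts both comparison modes should
# produce. A minimal sketch of one entry's shape (hypothetical counts;
# the real fixtures build `predicted`/`expected` from frame objects
# defined elsewhere in the test module, and Confusions takes TP/FP/FN
# keyword counts, as used in calculate_metric below):
EXAMPLE_ENTRY_SHAPE = {
    "predicted": ...,  # predicted intent-slot frame (placeholder)
    "expected": ...,  # gold intent-slot frame (placeholder)
    "bracket_confusions": {
        "intent_confusion": {"TP": 1, "FP": 0, "FN": 0},
        "slot_confusion": {"TP": 2, "FP": 1, "FN": 0},
    },
    "tree_confusions": {
        "intent_confusion": {"TP": 1, "FP": 0, "FN": 0},
        "slot_confusion": {"TP": 1, "FP": 2, "FN": 1},
    },
}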
def calculate_metric(self) -> PRF1Metrics:
    all_confusions = AllConfusions()
    for pred, expect in zip(self.all_preds, self.all_targets):
        # Drop padding positions and map label indices back to names.
        pred_seq, expect_seq = [], []
        for p, e in zip(pred, expect):
            if e != self.pad_idx:
                pred_seq.append(self.label_names[p])
                expect_seq.append(self.label_names[e])
        # Convert BIO tag sequences to labeled spans; exact-match
        # comparison then reduces to set algebra.
        expect_spans_set = set(convert_bio_to_spans(expect_seq))
        pred_spans_set = set(convert_bio_to_spans(pred_seq))
        true_positive = expect_spans_set & pred_spans_set
        false_positive = pred_spans_set - expect_spans_set
        false_negative = expect_spans_set - pred_spans_set
        # Accumulate global counts, then per-label counts.
        all_confusions.confusions += Confusions(
            TP=len(true_positive),
            FP=len(false_positive),
            FN=len(false_negative),
        )
        for span in true_positive:
            all_confusions.per_label_confusions.update(span.label, "TP", 1)
        for span in false_positive:
            all_confusions.per_label_confusions.update(span.label, "FP", 1)
        for span in false_negative:
            all_confusions.per_label_confusions.update(span.label, "FN", 1)
    return all_confusions.compute_metrics()
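# calculate_metric depends on convert_bio_to_spans returning hashable,
# labeled spans so that set intersection/difference yields TP/FP/FN
# directly. A minimal, self-contained sketch of that conversion under
# assumed semantics: spans are (label, start, end) tuples with an
# exclusive end index, and a stray "I-" tag opens a new span. The actual
# span type in the codebase may differ, but it must be hashable and
# expose a `.label` attribute for the per-label bookkeeping above.
from typing import List, NamedTuple


class Span(NamedTuple):
    """Labeled token span; hashable, so spans compare as set members."""

    label: str
    start: int  # inclusive token index
    end: int  # exclusive token index


def convert_bio_to_spans(bio_seq: List[str]) -> List[Span]:
    """Collapse a BIO tag sequence into labeled spans (sketch)."""
    spans: List[Span] = []
    label, start = None, 0
    for i, tag in enumerate(bio_seq):
        cur = tag[2:] if tag.startswith(("B-", "I-")) else None
        # Close the open span on "O", on a new "B-", or on a label change.
        if label is not None and (cur != label or tag.startswith("B-")):
            spans.append(Span(label, start, i))
            label = None
        # Open a new span on "B-", or on an "I-" with no span in progress.
        if cur is not None and label is None:
            label, start = cur, i
    if label is not None:
        spans.append(Span(label, start, len(bio_seq)))
    return spans


# Example: convert_bio_to_spans(["B-PER", "I-PER", "O", "B-LOC"])
# -> [Span("PER", 0, 2), Span("LOC", 3, 4)]. From the accumulated counts,
# precision = TP / (TP + FP) and recall = TP / (TP + FN), computed both
# per label and over all spans.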