import os
import sys

import numpy.testing as npt
from tabulate import tabulate

from pyannote.core import Segment
from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.metrics.identification import (IdentificationErrorRate,
                                             IdentificationPrecision,
                                             IdentificationRecall)

# get_reports, reindex and read_annotation are project-specific helpers
# assumed to be defined elsewhere in this repository.


def get_identification_metrics(reference, hypothesis, uem=None):
    """Return a dict mapping each metric's name to its value."""
    metric_dict = {}

    metric = IdentificationErrorRate()
    met = metric(reference, hypothesis, uem=uem)
    metric_dict[metric.metric_name()] = met

    metric = IdentificationPrecision()
    met = metric(reference, hypothesis, uem=uem)
    metric_dict[metric.metric_name()] = met

    metric = IdentificationRecall()
    met = metric(reference, hypothesis, uem=uem)
    metric_dict[metric.metric_name()] = met

    return metric_dict
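
# Illustrative sketch (not part of the original script): how
# get_identification_metrics() might be called on two toy annotations.
# The labels and time stamps below are made up for demonstration only.
def _example_identification_metrics():
    from pyannote.core import Annotation, Segment

    reference = Annotation()
    reference[Segment(0.0, 10.0)] = 'CHI'
    reference[Segment(12.0, 20.0)] = 'FEM'

    hypothesis = Annotation()
    hypothesis[Segment(0.0, 9.0)] = 'CHI'
    hypothesis[Segment(12.0, 20.0)] = 'MAL'

    # returns a dict keyed by each metric's metric_name()
    return get_identification_metrics(reference, hypothesis)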
def identification(protocol, subset, hypotheses, collar=0.0,
                   skip_overlap=False):
    options = {
        'collar': collar,
        'skip_overlap': skip_overlap,
        'parallel': True
    }

    metrics = {
        'error': IdentificationErrorRate(**options),
        'precision': IdentificationPrecision(**options),
        'recall': IdentificationRecall(**options)
    }

    reports = get_reports(protocol, subset, hypotheses, metrics)

    report = metrics['error'].report(display=False)
    precision = metrics['precision'].report(display=False)
    recall = metrics['recall'].report(display=False)

    report['precision', '%'] = precision[metrics['precision'].name, '%']
    report['recall', '%'] = recall[metrics['recall'].name, '%']

    # reorder columns: first column, then precision and recall, then the rest
    columns = list(report.columns)
    report = report[[columns[0]] + columns[-2:] + columns[1:-2]]
    report = reindex(report)

    summary = 'Identification (collar = {0:g} ms{1})'.format(
        1000 * collar, ', no overlap' if skip_overlap else '')

    headers = [summary] + \
              [report.columns[i][0] for i in range(3)] + \
              ['%' if c[1] == '%' else c[0] for c in report.columns[3:]]

    print(tabulate(report, headers=headers, tablefmt="simple",
                   floatfmt=".2f", numalign="decimal", stralign="left",
                   missingval="", showindex="default",
                   disable_numparse=False))
def test_detailed(reference, hypothesis):
    identificationErrorRate = IdentificationErrorRate()
    details = identificationErrorRate(reference, hypothesis, detailed=True)

    confusion = details['confusion']
    npt.assert_almost_equal(confusion, 7.0, decimal=7)

    correct = details['correct']
    npt.assert_almost_equal(correct, 22.0, decimal=7)

    rate = details['identification error rate']
    npt.assert_almost_equal(rate, 0.5161290322580645, decimal=7)

    false_alarm = details['false alarm']
    npt.assert_almost_equal(false_alarm, 7.0, decimal=7)

    missed_detection = details['missed detection']
    npt.assert_almost_equal(missed_detection, 2.0, decimal=7)

    total = details['total']
    npt.assert_almost_equal(total, 31.0, decimal=7)
def test_error_rate(reference, hypothesis):
    identificationErrorRate = IdentificationErrorRate()
    error_rate = identificationErrorRate(reference, hypothesis)
    npt.assert_almost_equal(error_rate, 0.5161290322580645, decimal=7)
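
# Added illustration (not part of the original tests): the identification
# error rate is (false alarm + missed detection + confusion) / total, so the
# component values asserted in test_detailed already determine the expected
# rate used above: (7.0 + 2.0 + 7.0) / 31.0 = 16/31 ~= 0.5161290322580645.
def test_error_rate_matches_components():
    npt.assert_almost_equal((7.0 + 2.0 + 7.0) / 31.0,
                            0.5161290322580645, decimal=7)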
def main(reference_dir, hypothesis_dir, output_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    flist = os.listdir(reference_dir)
    total_references = len(flist)
    total_hypotheses = len(os.listdir(hypothesis_dir))

    if total_references == 0:
        # no references available
        score_f = os.path.join(output_dir, 'score.seconds')
        score = open(score_f, 'w')
        score.write('No references available.\n')
        score.write('references {0}\n'.format(total_references))
        score.write('hypotheses {0}\n'.format(total_hypotheses))
        score.close()
        sys.exit(0)

    collar = 0.1  # collar in seconds
    der_eval = DiarizationErrorRate(collar=collar)
    ier_eval = IdentificationErrorRate(collar=collar)
    prec_eval = IdentificationPrecision(collar=collar)
    rec_eval = IdentificationRecall(collar=collar)

    skip_tokens = ['OVERLAP', 'SPN']
    skip_tokens_child = ['OVERLAP', 'SPN', 'SLT']

    missing_hypotheses = 0
    missing_hypotheses_seconds = 0
    utt_scores = []

    for f in flist:
        ref_f = os.path.join(reference_dir, f)
        hyp_f = os.path.join(hypothesis_dir, f)

        reference = read_annotation(
            ref_f, annotation_type='reference', skip_tokens=skip_tokens)
        reference_child = read_annotation(
            ref_f, annotation_type='reference', skip_tokens=skip_tokens_child)

        if not os.path.isfile(hyp_f):
            missing_hypotheses += 1
            missed_sum = sum(
                [i.end - i.start for i in reference.itersegments()])
            missing_hypotheses_seconds += missed_sum

        # read_annotation can handle non-existing files
        hypothesis = read_annotation(
            hyp_f, annotation_type='hypothesis', skip_tokens=skip_tokens)
        hypothesis_child = read_annotation(
            hyp_f, annotation_type='hypothesis', skip_tokens=skip_tokens_child)

        # find global min and max over reference and hypothesis segments
        time_ref = [[i.start, i.end] for i in reference.itersegments()]
        time_hyp = [[i.start, i.end] for i in hypothesis.itersegments()]
        min_f = min([i for i, e in time_hyp] + [i for i, e in time_ref])
        max_f = max([e for i, e in time_hyp] + [e for i, e in time_ref])

        # evaluate DER
        der = der_eval(reference, hypothesis,
                       uem=Segment(min_f, max_f), detailed=True)

        # find global min and max over the child annotations
        time_ref = [[i.start, i.end] for i in reference_child.itersegments()]
        time_hyp = [[i.start, i.end] for i in hypothesis_child.itersegments()]
        min_f = min([i for i, e in time_hyp] + [i for i, e in time_ref])
        max_f = max([e for i, e in time_hyp] + [e for i, e in time_ref])

        # evaluate IER, precision and recall
        ier = ier_eval(reference_child, hypothesis_child,
                       uem=Segment(min_f, max_f), detailed=True)
        prec = prec_eval(reference_child, hypothesis_child,
                         uem=Segment(min_f, max_f))
        rec = rec_eval(reference_child, hypothesis_child,
                       uem=Segment(min_f, max_f))
        f1 = 0 if prec == 0 or rec == 0 else 2 * (prec * rec) / (prec + rec)

        ref_labs = ' '.join(
            [label for _, _, label in reference.itertracks(yield_label=True)])
        hyp_labs = ' '.join(
            [label for _, _, label in hypothesis.itertracks(yield_label=True)])
        if not hyp_labs:
            hyp_labs = 'no_alignment'

        utt_scores.append([f, prec, rec, f1, der, ier, ref_labs, hyp_labs])

    # global scores
    ier = abs(ier_eval)
    der = abs(der_eval)
    precision = abs(prec_eval)
    recall = abs(rec_eval)
    f1 = 0 if precision == 0 or recall == 0 else \
        2 * (precision * recall) / (precision + recall)

    # keys to intermediate metrics
    keys = ['correct', 'missed detection', 'false alarm',
            'confusion', 'total', 'diarization error rate']
    aggregate = {k: 0 for k in keys}

    # global correct, missed, false alarm, confusion
    for item in utt_scores:
        der_errors = item[4]
        for key in keys:
            aggregate[key] += der_errors[key]
        ier_errors = item[5]
        item_ier = ier_errors['identification error rate']
        # per-utterance IER (overwritten each iteration, not used below)
        aggregate['der'] = item_ier

    if aggregate['total'] == 0:
        aggregate['total'] = 1

    # write global scores to file
    score_f = os.path.join(output_dir, 'score.seconds')
    score = open(score_f, 'w')
    score.write('precision {0:.3f}\n'.format(precision))
    score.write('recall {0:.3f}\n'.format(recall))
    score.write('f1 score {0:.3f}\n\n'.format(f1))
    score.write('IER {0:.3f}\n\n'.format(ier))
    score.write('DER {0:.3f}\n'.format(der))
    score.write(' missed {0:.3f}\n'.format(
        aggregate['missed detection'] / aggregate['total']))
    score.write(' false alarm {0:.3f}\n'.format(
        aggregate['false alarm'] / aggregate['total']))
    score.write(' confusion {0:.3f}\n'.format(
        aggregate['confusion'] / aggregate['total']))
    score.write(' correct {0:.3f}\n'.format(
        aggregate['correct'] / aggregate['total']))
    score.write('\n')
    score.write('total files {0}\n'.format(total_references))
    score.write('alignment failures\n')
    score.write(' total utterances: {0}\n'.format(missing_hypotheses))
    score.write(' total seconds in failed utterances: {0}\n\n'.format(
        missing_hypotheses_seconds))
    score.write('precision details\n')
    for i in prec_eval[:]:
        score.write(' {0} {1}\n'.format(i, prec_eval[:][i]))
    score.write('\n')
    score.write('recall details\n')
    for i in rec_eval[:]:
        score.write(' {0} {1}\n'.format(i, rec_eval[:][i]))
    score.close()

    # write detailed scores to file, sorted by DER
    # columns: filename, precision, recall, f1, DER components, IER,
    # reference_words, hypothesis_words
    report_f = os.path.join(output_dir, 'report.seconds')
    report = open(report_f, 'w')
    header = ['filename', 'precision', 'recall', 'f1', 'correct', 'missed',
              'false_alarm', 'confusion', 'total', 'der', 'ier',
              'reference_words', 'hypothesis_words']
    report.write('\t'.join(header) + '\n')
    for item in sorted(utt_scores,
                       key=lambda x: x[4]['diarization error rate']):
        data = []
        # filename
        data.append(item[0])
        # precision, recall, f1
        for i in range(1, 3 + 1):
            data.append('{0:.3f}'.format(item[i]))
        # DER-related scores
        errors = item[4]
        for key in keys:
            data.append('{0:.3f}'.format(errors[key]))
        # IER score
        ier = item[5]['identification error rate']
        data.append('{0:.3f}'.format(ier))
        data.append(item[-2])
        data.append(item[-1])
        report.write('\t'.join(data) + '\n')
    report.close()
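
# Possible command-line entry point (an assumption -- the original invocation
# is not shown in this listing): pass the reference, hypothesis and output
# directories as positional arguments.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Score hypothesis annotations against references.')
    parser.add_argument('reference_dir', help='directory with reference files')
    parser.add_argument('hypothesis_dir', help='directory with hypothesis files')
    parser.add_argument('output_dir', help='directory for score/report files')
    args = parser.parse_args()

    main(args.reference_dir, args.hypothesis_dir, args.output_dir)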