import argparse
import glob
import logging
import os
import shutil
from typing import Tuple

import numpy as np
import pandas as pd
from pyannote.metrics import detection, diarization, identification

# NOTE: project-specific helpers used below (check_if_param_valid, pred_rttm_map, get_parameter_grid,
# generate_vad_segment_table, generate_vad_segment_table_per_file, vad_construct_pyannote_object_per_file,
# get_couple_files, run_metrics) are assumed to be provided by the surrounding modules of this excerpt.


def vad_tune_threshold_on_dev(
    params: dict,
    vad_pred: str,
    groundtruth_RTTM: str,
    result_file: str = "res",
    vad_pred_method: str = "frame",
    focus_metric: str = "DetER",
    shift_length_in_sec: float = 0.01,
    num_workers: int = 20,
) -> Tuple[dict, dict]:
    """
    Tune thresholds on a dev set. Return the parameter set that gives the lowest detection error rate (DetER).

    Args:
        params (dict): dictionary of parameters to be tuned, each mapped to a list of candidate values.
        vad_pred (str): directory of VAD predictions, or a file containing their paths.
        groundtruth_RTTM (str): directory of ground-truth RTTM files, or a file containing their paths.
        result_file (str): basename of the text file to which per-parameter results are appended.
        vad_pred_method (str): suffix of the prediction file, used to locate it. Should be "frame", "mean" or "median".
        focus_metric (str): metric to optimize when tuning. Should be "DetER", "FA" or "MISS".
        shift_length_in_sec (float): frame shift (in seconds) used when generating segment tables.
        num_workers (int): number of workers used to generate segment tables.

    Returns:
        best_threshold (dict): parameter set that gives the lowest value of focus_metric.
        optimal_scores (dict): DetER/FA/MISS scores obtained with best_threshold.
    """
    min_score = 100
    all_perf = {}
    try:
        check_if_param_valid(params)
    except Exception as e:
        raise ValueError("Please check if the parameters are valid") from e

    paired_filenames, groundtruth_RTTM_dict, vad_pred_dict = pred_rttm_map(vad_pred, groundtruth_RTTM, vad_pred_method)
    metric = detection.DetectionErrorRate()
    params_grid = get_parameter_grid(params)

    for param in params_grid:
        for i in param:
            # Cast numpy scalars to plain floats so parameters serialize and compare cleanly.
            if type(param[i]) == np.float64 or type(param[i]) == np.int64:
                param[i] = float(param[i])
        try:
            # Generate speech segments by performing binarization on the VAD prediction according to param,
            # filter the segments according to param, and write the result to an RTTM-like table.
            vad_table_dir = generate_vad_segment_table(
                vad_pred, param, shift_length_in_sec=shift_length_in_sec, num_workers=num_workers
            )
            # Add reference and hypothesis to the metric.
            for filename in paired_filenames:
                groundtruth_RTTM_file = groundtruth_RTTM_dict[filename]
                vad_table_filepath = os.path.join(vad_table_dir, filename + ".txt")
                reference, hypothesis = vad_construct_pyannote_object_per_file(
                    vad_table_filepath, groundtruth_RTTM_file
                )
                metric(reference, hypothesis)  # accumulation

            # Delete the temporary table files.
            shutil.rmtree(vad_table_dir, ignore_errors=True)

            report = metric.report(display=False)
            DetER = report.iloc[[-1]][('detection error rate', '%')].item()
            FA = report.iloc[[-1]][('false alarm', '%')].item()
            MISS = report.iloc[[-1]][('miss', '%')].item()

            assert (
                focus_metric == "DetER" or focus_metric == "FA" or focus_metric == "MISS"
            ), "The metric to optimize should be one of 'DetER', 'FA' or 'MISS'!"

            all_perf[str(param)] = {'DetER (%)': DetER, 'FA (%)': FA, 'MISS (%)': MISS}
            logging.info(f"parameter {param}, {all_perf[str(param)]}")

            score = all_perf[str(param)][focus_metric + ' (%)']

            del report
            metric.reset()  # reset internal accumulator

            # Save results for later analysis.
            with open(result_file + ".txt", "a", encoding='utf-8') as fp:
                fp.write(f"{param}, {all_perf[str(param)]}\n")

            if score < min_score:
                best_threshold = param
                optimal_scores = all_perf[str(param)]
                min_score = score
            print("Current best", best_threshold, optimal_scores)

        except RuntimeError as e:
            print(f"Pass {param}, with error {e}")
        except pd.errors.EmptyDataError as e1:
            print(f"Pass {param}, with error {e1}")

    return best_threshold, optimal_scores
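# A minimal usage sketch (not part of the original source): how the grid-based tuner above might be
# called. The directory paths are placeholders, and the parameter keys ("onset", "offset",
# "min_duration_on", "min_duration_off") are assumed post-processing knobs; check_if_param_valid()
# defines the exact set accepted by your version.
def _example_vad_param_grid_search():
    example_params = {
        "onset": [0.4, 0.5, 0.6],        # candidate speech-onset thresholds
        "offset": [0.3, 0.4],            # candidate speech-offset thresholds
        "min_duration_on": [0.1, 0.2],   # drop speech segments shorter than this (seconds)
        "min_duration_off": [0.1, 0.2],  # fill non-speech gaps shorter than this (seconds)
    }
    best_params, best_scores = vad_tune_threshold_on_dev(
        example_params,
        vad_pred="vad_frame_predictions/",  # placeholder directory of frame-level VAD outputs
        groundtruth_RTTM="dev_rttms/",      # placeholder directory of reference RTTM files
        result_file="tune_res",
        vad_pred_method="frame",
        focus_metric="DetER",
    )
    print("Best parameters:", best_params, "with scores:", best_scores)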
def main():
    parser = argparse.ArgumentParser(
        description="Script that computes metrics between reference and hypothesis files. "
        "Inputs can be either paths to folders or single files.")
    parser.add_argument('-ref', '--reference', type=str, required=True,
                        help="Path of the reference.")
    parser.add_argument(
        '-hyp', '--hypothesis', type=str, required=False, default=None,
        help="Path of the hypothesis. "
        "If None, consider that the hypothesis is stored where the reference is.")
    parser.add_argument(
        '-p', '--prefix', required=False, default="",
        choices=[
            "lena", "noisemesSad", "opensmileSad", "tocomboSad",
            "yunitator_old", "yunitator_english", "yunitator_universal",
            "diartk_noisemesSad", "diartk_tocomboSad", "diartk_opensmileSad",
            "diartk_goldSad", "yuniseg_noisemesSad", "yuniseg_opensmileSad",
            "yuniseg_tocomboSad", "yuniseg_goldSad"
        ],
        help="Prefix that filenames of the hypothesis must match.")
    parser.add_argument('-t', '--task', type=str, required=True,
                        choices=["detection", "diarization", "identification"])
    parser.add_argument('-m', '--metrics', required=True, nargs='+', type=str,
                        choices=[
                            "diaer", "coverage", "completeness", "homogeneity", "purity",
                            "accuracy", "precision", "recall", "deter", "ider", "idea"
                        ],
                        help="Metrics that need to be run.")
    parser.add_argument('--visualization', action='store_true')
    parser.add_argument('--identification', action='store_true')
    parser.add_argument(
        '--class_to_keep', default=None,
        choices=['OCH', 'CHI', 'ELE', 'FEM', 'MAL', 'OVL', 'SIL'],
        help="If not None, will only keep labels corresponding to the specified class.")
    args = parser.parse_args()
    class_to_keep = args.class_to_keep

    if args.identification:
        args.task = "identification"

    # Let's create the metrics
    metrics = {}
    for m in args.metrics:
        if m == "accuracy":
            # All 3 tasks can be evaluated as a detection task.
            metrics[m] = detection.DetectionAccuracy(parallel=True)
        elif m == "precision":
            metrics[m] = detection.DetectionPrecision(parallel=True)
        elif m == "recall":
            metrics[m] = detection.DetectionRecall(parallel=True)
        elif m == "deter":
            metrics[m] = detection.DetectionErrorRate(parallel=True)
        elif args.task == "diarization" or args.task == "identification":
            # The diarization and the identification tasks can both be evaluated as a diarization task.
            if m == "diaer":
                metrics[m] = diarization.DiarizationErrorRate(parallel=True)
            elif m == "coverage":
                metrics[m] = diarization.DiarizationCoverage(parallel=True)
            elif m == "completeness":
                metrics[m] = diarization.DiarizationCompleteness(parallel=True)
            elif m == "homogeneity":
                metrics[m] = diarization.DiarizationHomogeneity(parallel=True)
            elif m == "purity":
                metrics[m] = diarization.DiarizationPurity(parallel=True)
            elif args.task == "identification":
                # Only the identification task can be evaluated as an identification task.
                if m == "ider":
                    metrics[m] = identification.IdentificationErrorRate(parallel=True)
                elif m == "precision":
                    metrics[m] = identification.IdentificationPrecision(parallel=True)
                elif m == "recall":
                    metrics[m] = identification.IdentificationRecall(parallel=True)
                else:
                    print("Filtering out %s, which is not available for the %s task." % (m, args.task))
        else:
            print("Filtering out %s, which is not available for the %s task." % (m, args.task))

    # Get files and run the metrics
    references_f, hypothesis_f = get_couple_files(args.reference, args.hypothesis, args.prefix)

    # print("Pairs that have been found : ")
    # for ref, hyp in zip(references_f, hypothesis_f):
    #     print("%s / %s " % (os.path.basename(ref), os.path.basename(hyp)))

    metrics = run_metrics(references_f, hypothesis_f, metrics, args.visualization, class_to_keep)

    # --hypothesis defaults to None ("stored where the reference is"), so fall back to the
    # reference path when building the output directory name.
    hypothesis_path = args.hypothesis if args.hypothesis is not None else args.reference
    output_dir = os.path.join(
        args.reference,
        (os.path.basename(args.reference) + '_' + os.path.basename(hypothesis_path)).replace('mapped_', ''))
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    # Write a report for each metric.
    for name, m in metrics.items():
        # print("\n%s report" % name)
        # print(m)
        rep = m.report(display=False)

        # Flatten the two-level column index, marking percentage columns with a '%' suffix.
        colnames = list(rep.columns.get_level_values(0))
        percent_or_count = rep.columns.get_level_values(1)
        for i in range(0, len(percent_or_count)):
            if percent_or_count[i] == '%':
                colnames[i] = colnames[i] + ' %'
        rep.columns = colnames

        if args.prefix != "":
            dest_output = os.path.join(output_dir, name + '_' + args.prefix + "_report.csv")
        else:
            dest_output = os.path.join(output_dir, name + "_report.csv")

        if class_to_keep is not None:
            dest_output = os.path.join(
                os.path.dirname(dest_output),
                ("only_%s_" % class_to_keep) + os.path.basename(dest_output))

        rep.to_csv(dest_output, float_format="%.2f")

    print("Done computing metrics between %s and %s." % (args.reference, args.hypothesis))
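# Hedged entry point (the original excerpt does not show one): the script above is argparse-driven
# and would typically be launched from the shell, e.g.
#     python compute_metrics.py -ref ./gold_rttms -hyp ./sys_rttms -t detection -m deter accuracy
# The script name and the paths in that command are illustrative placeholders; see the choices
# declared in main() for the accepted flags.
if __name__ == "__main__":
    main()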
# Earlier, simpler variant of the tuner above: sweeps a flat list of onset thresholds rather than a
# full parameter grid.
def vad_tune_threshold_on_dev(thresholds, vad_pred_method, vad_pred_dir, groundtruth_RTTM_dir, focus_metric="DetER"):
    """
    Tune the threshold on a dev set. Return the threshold that gives the lowest detection error rate (DetER).

    Args:
        thresholds (list): list of candidate thresholds.
        vad_pred_method (str): suffix of the prediction file, used to locate it. Should be "frame", "mean" or "median".
        vad_pred_dir (str): directory of VAD predictions.
        groundtruth_RTTM_dir (str): directory of ground-truth RTTM files.
        focus_metric (str): metric to optimize when tuning. Should be "DetER", "FA" or "MISS".

    Returns:
        best_threshold (float): threshold that gives the lowest DetER.
    """
    threshold_perf = {}
    best_threshold = thresholds[0]
    min_score = 100

    if not (thresholds[0] >= 0 and thresholds[-1] <= 1):
        raise ValueError("Invalid threshold! Should be in [0, 1]")

    for threshold in thresholds:
        metric = detection.DetectionErrorRate()
        filenames = [
            os.path.basename(f).split(".")[0] for f in glob.glob(os.path.join(groundtruth_RTTM_dir, "*.rttm"))
        ]

        for filename in filenames:
            vad_pred_filepath = os.path.join(vad_pred_dir, filename + "." + vad_pred_method)
            table_out_dir = os.path.join(vad_pred_dir, "table_output_" + str(threshold))
            if not os.path.exists(table_out_dir):
                os.mkdir(table_out_dir)

            per_args = {"threshold": threshold, "shift_len": 0.01, "out_dir": table_out_dir}
            vad_table_filepath = generate_vad_segment_table_per_file(vad_pred_filepath, per_args)

            groundtruth_RTTM_file = os.path.join(groundtruth_RTTM_dir, filename + '.rttm')
            reference, hypothesis = vad_construct_pyannote_object_per_file(vad_table_filepath, groundtruth_RTTM_file)
            metric(reference, hypothesis)  # accumulation

        report = metric.report(display=False)
        DetER = report.iloc[[-1]][('detection error rate', '%')].item()
        FA = report.iloc[[-1]][('false alarm', '%')].item()
        MISS = report.iloc[[-1]][('miss', '%')].item()

        if focus_metric == "DetER":
            score = DetER
        elif focus_metric == "FA":
            score = FA
        elif focus_metric == "MISS":
            score = MISS
        else:
            raise ValueError("The metric to optimize should be one of 'DetER', 'FA' or 'MISS'!")

        threshold_perf[threshold] = {'DetER (%)': DetER, 'FA (%)': FA, 'MISS (%)': MISS}
        logging.info(f"threshold {threshold}, {threshold_perf[threshold]}")
        del report
        metric.reset()  # reset internal accumulator

        if score < min_score:
            min_score = score
            best_threshold = threshold
    return best_threshold
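# A minimal usage sketch (not part of the original source) for the threshold-list tuner above.
# Directory names are placeholders: the ground-truth directory is expected to contain <uttid>.rttm
# files and the prediction directory matching <uttid>.<vad_pred_method> files.
def _example_threshold_sweep():
    candidate_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    best = vad_tune_threshold_on_dev(
        candidate_thresholds,
        vad_pred_method="median",
        vad_pred_dir="vad_predictions/",    # placeholder
        groundtruth_RTTM_dir="dev_rttms/",  # placeholder
        focus_metric="DetER",
    )
    print("Best threshold:", best)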