def dev_tuner(full_csv, split_type):
    """Tune the number of components (n_lambdas) on the dev set.

    Diarizes the whole dev set once per candidate speaker count and keeps
    the count with the lowest DER. NOTE: this is a very basic search for
    the nn-based affinity and is work in progress until a better way is
    found.
    """
    ref_rttm = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
    der_per_count = []
    for candidate in range(1, params["max_num_spkrs"] + 1):
        # Process the whole dataset for this candidate count (pval unused).
        sys_rttm = diarize_dataset(full_csv, split_type, candidate, None)
        _, _, _, der_value = DER(
            ref_rttm,
            sys_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        der_per_count.append(der_value)
    # Candidates start at 1, so shift the argmin index by one.
    return der_per_count.index(min(der_per_count)) + 1
def dev_p_tuner(full_csv, split_type):
    """Tune the p_value used to prune the affinity matrix on the dev set."""
    candidates = np.arange(0.002, 0.015, 0.001)
    ref_rttm = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
    scores = []
    for p_v in candidates:
        # Process the whole dataset for this p_value (n_lambdas unused).
        sys_rttm = diarize_dataset(full_csv, split_type, None, p_v)
        _, _, _, der_value = DER(
            ref_rttm,
            sys_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        scores.append(der_value)
    # Keep the p_value that produced the lowest DER on the dev set.
    return candidates[scores.index(min(scores))]
def dev_nn_tuner(full_csv, split_type):
    """Tune n_neighbors on the dev set, assuming the oracle number of
    speakers.
    """
    ref_rttm = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
    scored = []
    for n_neighbors in range(5, 15):
        # TODO: fix later — for now assume the oracle number of speakers.
        n_lambdas = 4
        # Process the whole dataset for this neighbor count (pval unused).
        sys_rttm = diarize_dataset(
            full_csv, split_type, n_lambdas, None, n_neighbors
        )
        _, _, _, der_value = DER(
            ref_rttm,
            sys_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        scored.append((n_neighbors, der_value))
    # Lowest DER wins; ties resolve to the smaller neighbor count, which
    # matches the original stable sort on DER.
    best_nn, _ = min(scored, key=lambda pair: pair[1])
    return best_nn
def dev_nn_tuner(full_meta, split_type):
    """Tune n_neighbors on the dev set when nn-based affinity is selected.

    Assumes the oracle number of speakers.
    """
    ref_rttm = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
    n_lambdas = 4  # Oracle number of speakers assumed for now.
    scored = []
    for n_neighbors in range(5, 15):
        # Process the whole dataset for this neighbor count (pval unused).
        sys_rttm = diarize_dataset(
            full_meta, split_type, n_lambdas, None, n_neighbors
        )
        _, _, _, der_value = DER(
            ref_rttm,
            sys_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        scored.append((n_neighbors, der_value))
        if params["oracle_n_spkrs"] is True and params["backend"] == "kmeans":
            # Nothing further to search with oracle counts + kmeans backend.
            break
    # Lowest DER wins; ties resolve to the smaller neighbor count, which
    # matches the original stable sort on DER.
    best_nn, _ = min(scored, key=lambda pair: pair[1])
    return best_nn
def dev_ahc_threshold_tuner(full_meta, split_type):
    """Tune the affinity-matrix threshold on the dev set.

    Called when AHC is the clustering backend; the generic p_val slot is
    reused here as the AHC threshold.
    """
    candidates = np.arange(0.0, 1.0, 0.1)
    ref_rttm = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
    ders = []
    for threshold in candidates:
        # Process the whole dataset at this threshold (n_lambdas=None is
        # used as a flag downstream).
        sys_rttm = diarize_dataset(full_meta, split_type, None, threshold)
        _, _, _, der_value = DER(
            ref_rttm,
            sys_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        ders.append(der_value)
        if params["oracle_n_spkrs"] is True:
            # With oracle speaker counts the threshold search is unnecessary.
            break
    # Return the threshold with the lowest dev-set DER.
    return candidates[ders.index(min(ders))]
def dev_pval_tuner(full_meta, split_type):
    """Tune the p_value for the affinity matrix on the dev set.

    The p_value controls the fraction of values retained in each row of
    the affinity matrix.
    """
    candidates = np.arange(0.002, 0.015, 0.001)
    ref_rttm = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
    ders = []
    for p_v in candidates:
        # Process the whole dataset for this p_value (n_lambdas=None is
        # used as a flag downstream).
        sys_rttm = diarize_dataset(full_meta, split_type, None, p_v)
        _, _, _, der_value = DER(
            ref_rttm,
            sys_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        ders.append(der_value)
        if params["oracle_n_spkrs"] is True and params["backend"] == "kmeans":
            # The search is only skippable for oracle + kmeans: spectral
            # clustering needs p_val for both oracle and estimated speaker
            # counts, and kmeans needs it when oracle_n_spkrs is False.
            break
    # Return the p_value with the lowest dev-set DER.
    return candidates[ders.index(min(ders))]
# NOTE(review): top-level script fragment. `n_lambdas`, `best_pval`, `best_nn`,
# `params`, `logger`, `diar`, `tag`, `final_DERs`, and `split_type` are expected
# to be defined earlier in the file — confirm against the full source.

# Diarize the AMI dev set with the tuned hyper-parameters.
out_boundaries = diarize_dataset(
    full_csv,
    "dev",
    n_lambdas=n_lambdas,
    pval=best_pval,
    n_neighbors=best_nn,
)

# Evaluating on DEV set
logger.info("Evaluating for AMI Dev. set")
ref_rttm_dev = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
sys_rttm_dev = out_boundaries
[MS_dev, FA_dev, SER_dev, DER_dev] = DER(
    ref_rttm_dev,
    sys_rttm_dev,
    params["ignore_overlap"],
    params["forgiveness_collar"],
    individual_file_scores=True,
)
# Presumably DER_dev[-1] is the aggregate score when per-file scores are
# requested — confirm in DER().
msg = "AMI Dev set: Diarization Error Rate = %s %%\n" % (str(round(DER_dev[-1], 2)))
logger.info(msg)

# AMI Eval Set
# Load the eval-split CSV rows that drive diarization.
full_csv = []
with open(params["csv_diary_eval"], "r") as csv_file:
    reader = csv.reader(csv_file, delimiter=",")
    for row in reader:
        full_csv.append(row)

# NOTE(review): this call passes only `split_type` positionally, unlike the
# dev-set call above which passes the CSV first — verify against
# `diarize_dataset`'s signature; a positional argument may be missing here.
out_boundaries = diarize_dataset(
    split_type,
    n_lambdas=n_lambdas,
    pval=best_pval,
    n_neighbors=best_nn,
)

# Computing DER.
msg = "Computing DERs for " + split_type + " set"
logger.info(msg)
ref_rttm = os.path.join(params["ref_rttm_dir"], "fullref_ami_" + split_type + ".rttm")
sys_rttm = out_boundaries
[MS, FA, SER, DER_vals] = DER(
    ref_rttm,
    sys_rttm,
    params["ignore_overlap"],
    params["forgiveness_collar"],
    individual_file_scores=True,
)

# Writing DER values to a file. Append tag.
der_file_name = split_type + "_DER_" + tag
out_der_file = os.path.join(params["der_dir"], der_file_name)
msg = "Writing DER file to: " + out_der_file
logger.info(msg)
diar.write_ders_file(ref_rttm, DER_vals, out_der_file)
msg = ("AMI " + split_type + " set DER = %s %%\n" % (str(round(DER_vals[-1], 2))))
logger.info(msg)
# Presumably the last entry is the aggregate DER — confirm in DER().
final_DERs[split_type] = round(DER_vals[-1], 2)