def gather_eval_results(self, metric, mapping_dict):
    """
    Collect per-session and aggregated diarization scores from a pyannote metric object.

    Args:
        metric (DiarizationErrorRate metric): pyannote DiarizationErrorRate object. Its
            `results_` attribute is iterated as (session key, score dict) pairs, and the
            object itself is queried for the accumulated components.
        mapping_dict (dict): Speaker mapping labels for each audio file, keyed by the
            unique session name.

    Returns:
        DER_result_dict (dict): One entry per session (DER, CER, FA, MISS, estimated number
            of speakers, speaker mapping, and whether the estimated speaker count equals
            the reference count) plus a "total" entry holding the aggregated rates and the
            speaker counting accuracy.
    """
    session_scores = metric.results_
    DER_result_dict = {}
    n_sessions_with_correct_spk_count = 0

    for key, score in session_scores:
        # Speaker count of the hypothesis, read back from the predicted RTTM on disk.
        hyp_labels = rttm_to_labels(os.path.join(self.root_path, 'pred_rttms', key + '.rttm'))
        est_n_spk = self.get_num_of_spk_from_labels(hyp_labels)

        # Speaker count of the reference RTTM registered for this session.
        ref_labels = rttm_to_labels(self.AUDIO_RTTM_MAP[key]['rttm_filepath'])
        ref_n_spk = self.get_num_of_spk_from_labels(ref_labels)

        spk_count_matches = est_n_spk == ref_n_spk
        DER_result_dict[key] = {
            "DER": score['diarization error rate'],
            "CER": score['confusion'],
            "FA": score['false alarm'],
            "MISS": score['missed detection'],
            "n_spk": est_n_spk,
            "mapping": mapping_dict[key],
            "spk_counting": spk_count_matches,
        }
        logging.info("score for session {}: {}".format(key, DER_result_dict[key]))
        n_sessions_with_correct_spk_count += int(spk_count_matches)

    # Aggregated rates: each accumulated component is normalized by the accumulated total.
    DER_result_dict["total"] = {
        "DER": abs(metric),
        "CER": metric['confusion'] / metric['total'],
        "FA": metric['false alarm'] / metric['total'],
        "MISS": metric['missed detection'] / metric['total'],
        "spk_counting_acc": n_sessions_with_correct_spk_count / len(session_scores),
    }
    return DER_result_dict
def get_diarization_labels(self, audio_file_list):
    """
    Load the predicted diarization labels of every audio file into a list.

    Args:
        audio_file_list (list): The list of audio file paths.

    Returns:
        diar_labels (list): One entry per audio file, in input order, each being the labels
            read from `<oracle_vad_dir>/pred_rttms/<uniq_id>.rttm`.
    """
    diar_labels = []
    for audio_file_path in audio_file_list:
        uniq_id = get_uniq_id_from_audio_path(audio_file_path)
        session_labels = rttm_to_labels(os.path.join(self.oracle_vad_dir, 'pred_rttms', uniq_id + '.rttm'))
        diar_labels.append(session_labels)

        # The speaker count is computed for logging only; it is not returned.
        est_n_spk = self.get_num_of_spk_from_labels(session_labels)
        logging.info(f"Estimated n_spk [{uniq_id}]: {est_n_spk}")
    return diar_labels
def run_diarization(
    self, diar_model_config, word_timestamps,
):
    """
    Run the clustering diarizer and collect the predicted labels for every session.

    When `diar_model_config.diarizer.asr.parameters.asr_based_vad` is set, the word
    timestamps are first saved as VAD labels, written to an `asr_vad_manifest.json`
    manifest under `self.root_path`, and the config is switched to that external VAD
    manifest (the VAD model path is cleared). Note that `diar_model_config` is modified
    in place in that case.

    Args:
        diar_model_config: Diarization configuration passed to `ClusteringDiarizer`.
        word_timestamps (list): List containing words and word timestamps.

    Returns:
        diar_hyp (dict): Labels read from the predicted RTTM files, indexed by unique ID.
        score: The value returned by `ClusteringDiarizer.diarize()`. According to the
            original documentation this is a tuple of the pyannote metric instance and the
            speaker mapping dictionary (not verifiable from this function alone).
    """
    if diar_model_config.diarizer.asr.parameters.asr_based_vad:
        self.save_VAD_labels_list(word_timestamps)
        manifest_target = os.path.join(self.root_path, 'asr_vad_manifest.json')
        asr_vad_manifest = write_rttm2manifest(self.VAD_RTTM_MAP, manifest_target)
        # Use the ASR-derived VAD manifest instead of a VAD model.
        diar_model_config.diarizer.vad.model_path = None
        diar_model_config.diarizer.vad.external_vad_manifest = asr_vad_manifest

    diar_model = ClusteringDiarizer(cfg=diar_model_config)
    score = diar_model.diarize()

    # Frame-level VAD output is only collected when a VAD model actually ran.
    if not (diar_model_config.diarizer.vad.model_path is None or diar_model_config.diarizer.oracle_vad):
        self.get_frame_level_VAD(vad_processing_dir=diar_model.vad_pred_dir)

    diar_hyp = {}
    for audio_file_path in self.audio_file_list:
        session_id = get_uniqname_from_filepath(audio_file_path)
        diar_hyp[session_id] = rttm_to_labels(os.path.join(self.root_path, 'pred_rttms', session_id + '.rttm'))
    return diar_hyp, score
def main(wav_path, text_path=None, rttm_path=None, uem_path=None, ctm_path=None, manifest_filepath=None):
    """
    Assemble one manifest entry per audio file and write them to `manifest_filepath`.

    An existing file at `manifest_filepath` is removed first. Entries are emitted in
    sorted order of the unique IDs derived from the wav list.

    Args:
        wav_path: Path handed to `read_file` to obtain the list of audio file paths.
        text_path: Optional list file of transcript paths. For each session the first line
            of the transcript file becomes the entry's "text"; "-" is used when no
            transcript is available.
        rttm_path: Optional list file of RTTM paths. When an RTTM is available the number
            of distinct speakers (last field of each label) is stored as "num_speakers".
        uem_path: Optional list file of UEM paths.
        ctm_path: Optional list file of CTM paths.
        manifest_filepath: Output manifest path. NOTE(review): the default of None is not
            usable, `os.path.exists(None)` raises TypeError, so callers must pass a path.

    NOTE(review): the optional lists are resolved by `get_path_dict`; the code below
    assumes it returns a dict keyed by unique ID whose values may be None. Confirm against
    that helper.
    """
    if os.path.exists(manifest_filepath):
        os.remove(manifest_filepath)

    wav_pathlist = read_file(wav_path)
    wav_pathdict = get_dict_from_wavlist(wav_pathlist)
    len_wavs = len(wav_pathlist)
    uniqids = sorted(wav_pathdict)

    text_pathdict = get_path_dict(text_path, uniqids, len_wavs)
    rttm_pathdict = get_path_dict(rttm_path, uniqids, len_wavs)
    uem_pathdict = get_path_dict(uem_path, uniqids, len_wavs)
    ctm_pathdict = get_path_dict(ctm_path, uniqids, len_wavs)

    lines = []
    for uid in uniqids:
        wav, text, rttm, uem, ctm = (
            wav_pathdict[uid],
            text_pathdict[uid],
            rttm_pathdict[uid],
            uem_pathdict[uid],
            ctm_pathdict[uid],
        )

        audio_line = wav.strip()
        if rttm is not None:
            rttm = rttm.strip()
            labels = rttm_to_labels(rttm)
            # The speaker label is the last whitespace-separated field of each label.
            num_speakers = len({label.split()[-1] for label in labels})
        else:
            num_speakers = None

        if uem is not None:
            uem = uem.strip()

        if text is not None:
            # Only the first line is used; the context manager closes the handle, which
            # the previous bare open() never did. An empty file yields an empty string.
            with open(text.strip()) as text_file:
                text = text_file.readline().strip()
        else:
            text = "-"

        if ctm is not None:
            ctm = ctm.strip()

        lines.append(
            {
                "audio_filepath": audio_line,
                "offset": 0,
                "duration": None,
                "label": "infer",
                "text": text,
                "num_speakers": num_speakers,
                "rttm_filepath": rttm,
                "uem_filepath": uem,
                "ctm_filepath": ctm,
            }
        )

    write_file(manifest_filepath, lines, range(len(lines)))
def get_WDER(self, total_riva_dict, DER_result_dict):
    """
    Calculate word-level diarization error rate (WDER).

    WDER is one minus the fraction of words whose mapped speaker label agrees with the
    reference segment that overlaps the word the most. Words whose speaker label has no
    entry in the session's mapping, and words that overlap no reference segment, count as
    wrongly diarized (they stay in the denominator).

    Args:
        total_riva_dict (dict): The dictionary that stores riva_dict(dict) indexed by uniq_id variable.
        DER_result_dict (dict): The dictionary that stores DER, FA, Miss, CER, mapping, the estimated
            number of speakers and speaker counting accuracy.

    Return:
        wder_dict (dict): A dictionary contains WDER value for each session and total WDER
            (under the key 'total').

    Raises:
        ZeroDivisionError: If a session has an empty word list, or if no words exist at all.
    """
    wder_dict = {}
    grand_total_word_count, grand_correct_word_count = 0, 0

    for audio_file_path in self.audio_file_list:
        uniq_id = get_uniqname_from_filepath(audio_file_path)
        ref_rttm = self.AUDIO_RTTM_MAP[uniq_id]['rttm_filepath']
        labels = rttm_to_labels(ref_rttm)
        mapping_dict = DER_result_dict[uniq_id]['mapping']
        words_list = total_riva_dict[uniq_id]['words']

        correct_word_count = 0
        total_word_count = len(words_list)

        # [start, end] of every reference segment: the first two fields of each label.
        ref_label_array = np.array(
            [[float(fields[0]), float(fields[1])] for fields in (label.split() for label in labels)]
        )

        for wdict in words_list:
            speaker_label = wdict['speaker_label']
            if speaker_label not in mapping_dict:
                # No mapping to a reference speaker: the word stays counted as wrong.
                continue
            est_spk_label = mapping_dict[speaker_label]

            word_range = np.array([wdict['start_time'], wdict['end_time']])
            word_range_tile = np.tile(word_range, (ref_label_array.shape[0], 1))

            # ovl_bool is used both as a mask on the overlap lengths and, through its
            # non-zero positions, as an index into `labels`.
            # NOTE(review): this assumes a 1-D boolean array aligned with `labels`.
            ovl_bool = self.isOverlapArray(ref_label_array, word_range_tile)
            if not np.any(ovl_bool):
                continue

            ovl_length = self.getOverlapRangeArray(ref_label_array, word_range_tile)
            overlapping_idx = np.where(ovl_bool)[0]
            overlapping_lengths = ovl_length[ovl_bool]

            if self.params['lenient_overlap_WDER']:
                # Lenient mode: every segment tied for the longest overlap is acceptable.
                best_idx = overlapping_idx[overlapping_lengths == np.max(overlapping_lengths)]
                ref_spk_labels = [labels[i].split()[-1] for i in best_idx]
                if est_spk_label in ref_spk_labels:
                    correct_word_count += 1
            else:
                # Strict mode: only the first segment with the longest overlap counts.
                best_idx = overlapping_idx[np.argmax(overlapping_lengths)]
                _, _, ref_spk_label = labels[best_idx].split()
                correct_word_count += int(est_spk_label == ref_spk_label)

        wder_dict[uniq_id] = 1 - (correct_word_count / total_word_count)
        grand_total_word_count += total_word_count
        grand_correct_word_count += correct_word_count

    wder_dict['total'] = 1 - (grand_correct_word_count / grand_total_word_count)
    return wder_dict
def write_json_and_transcript(
    self, word_list, word_ts_list,
):
    """
    Matches the diarization result with the ASR output.

    For each session the predicted RTTM under `<root_path>/pred_rttms` is loaded and the
    words are walked in order. A word is attributed to the current diarization segment
    while the midpoint of its timestamps lies before the segment end; otherwise the next
    segment (clamped to the last one) becomes current.

    Args:
        word_list (list): List of words from ASR inference, one list per session.
        word_ts_list (list): Contains word_ts_stt_end lists.
            word_ts_stt_end = [stt, end]
                stt: Start of the word in sec.
                end: End of the word in sec.

    Return:
        total_riva_dict (dict): A dictionary contains word timestamps, speaker labels and words,
            indexed by unique ID. Every session is present, including sessions without words.
    """
    total_riva_dict = {}
    if self.fix_word_ts_with_VAD:
        word_ts_list = self.compensate_word_ts_list(self.audio_file_list, word_ts_list, self.params)
        # NOTE(review): the message says the fix is skipped, but compensate_word_ts_list
        # has already been called above. Confirm how that helper treats an empty frame_VAD.
        if self.frame_VAD == {}:
            logging.info(
                "VAD timestamps are not provided and skipping word timestamp fix. Please check the VAD model."
            )

    for k, audio_file_path in enumerate(self.audio_file_list):
        uniq_id = get_uniqname_from_filepath(audio_file_path)
        pred_rttm = os.path.join(self.root_path, 'pred_rttms', uniq_id + '.rttm')
        labels = rttm_to_labels(pred_rttm)
        audacity_label_words = []
        n_spk = self.get_num_of_spk_from_labels(labels)
        words = word_list[k]
        riva_dict = od(
            {
                'status': 'Success',
                'session_id': uniq_id,
                'transcription': ' '.join(words),
                'speaker_count': n_spk,
                'words': [],
            }
        )

        # Start with the first diarization segment.
        start_point, end_point, speaker = labels[0].split()
        logging.info(f"Creating results for Session: {uniq_id} n_spk: {n_spk} ")
        string_out = self.print_time('', speaker, start_point, end_point, self.params)

        idx = 0
        for j, word_ts_stt_end in enumerate(word_ts_list[k]):
            word_pos = (word_ts_stt_end[0] + word_ts_stt_end[1]) / 2
            if word_pos >= float(end_point):
                # The word midpoint is past the current segment: advance one segment,
                # staying on the last one when the labels are exhausted.
                idx = min(idx + 1, len(labels) - 1)
                start_point, end_point, speaker = labels[idx].split()
                string_out = self.print_time(string_out, speaker, start_point, end_point, self.params)
            string_out = self.print_word(string_out, words[j], self.params)

            stt_sec, end_sec = round(word_ts_stt_end[0], 2), round(word_ts_stt_end[1], 2)
            riva_dict = self.add_json_to_dict(riva_dict, words[j], stt_sec, end_sec, speaker)
            audacity_label_words = self.get_audacity_label(
                words[j], stt_sec, end_sec, speaker, audacity_label_words
            )

        # Registered once per session, after the word loop, so that a session with no
        # words still gets an entry (it was previously only set inside the word loop).
        total_riva_dict[uniq_id] = riva_dict
        self.write_and_log(uniq_id, riva_dict, string_out, audacity_label_words)

    return total_riva_dict
def get_WDER(self, total_riva_dict, DER_result_dict):
    """
    Calculate word-level diarization error rate (WDER).

    For every session the RTTM-based WDER is computed through `calculate_WDER_from_RTTM`.
    When a CTM file is registered for the session, the CTM-based evaluation is run as well
    through `get_ctm_based_eval`. `self.ctm_exists[uniq_id]` records which sessions had a
    CTM file. The final values are assembled by `get_wder_dict_values`.

    Args:
        total_riva_dict (dict): Dictionary that stores riva_dict(dict) which is indexed by uniq_id variable.
        DER_result_dict (dict): Dictionary that stores DER, FA, Miss, CER, mapping, the estimated number
            of speakers and speaker counting accuracy.

    Returns:
        wder_dict (dict): A dictionary containing WDER value for each session and total WDER.

    Raises:
        ImportError: If any session has a CTM file but diff_match_patch is not installed.
    """
    wder_dict = {'session_level': {}}
    count_dict = {
        'total_ctm_wder_count': 0,
        'total_asr_and_spk_correct_words': 0,
        'grand_total_ctm_word_count': 0,
        'grand_total_pred_word_count': 0,
        'grand_total_correct_word_count': 0,
    }
    asr_eval_dict = {'hypotheses_list': [], 'references_list': []}
    align_error_list = []

    # The dependency is checked up front, before any per-session work is done.
    ctm_provided = any(meta['ctm_filepath'] is not None for meta in self.AUDIO_RTTM_MAP.values())
    if ctm_provided and not DIFF_MATCH_PATCH:
        raise ImportError(
            'CTM file is provided but diff_match_patch is not installed. Install diff_match_patch using PyPI: pip install diff_match_patch'
        )

    for audio_file_path in self.audio_file_list:
        uniq_id = get_uniqname_from_filepath(audio_file_path)
        session_meta = self.AUDIO_RTTM_MAP[uniq_id]
        error_dict = {'uniq_id': uniq_id}
        ref_labels = rttm_to_labels(session_meta['rttm_filepath'])
        mapping_dict = DER_result_dict[uniq_id]['mapping']
        hyp_w_dict_list = total_riva_dict[uniq_id]['words']
        hyp_w_dict_list, word_seq_list, correct_word_count, rttm_wder = self.calculate_WDER_from_RTTM(
            hyp_w_dict_list, ref_labels, mapping_dict
        )
        error_dict['rttm_based_wder'] = rttm_wder
        error_dict.update(DER_result_dict[uniq_id])

        # If CTM files are provided, evaluate word-level diarization and WER with the CTM files.
        ctm_filepath = session_meta['ctm_filepath']
        if ctm_filepath:
            self.ctm_exists[uniq_id] = True
            # The context manager closes the handle, which the previous bare open() did not.
            with open(ctm_filepath) as ctm_file:
                ctm_content = ctm_file.readlines()
            self.get_ctm_based_eval(ctm_content, error_dict, count_dict, hyp_w_dict_list, mapping_dict)
        else:
            self.ctm_exists[uniq_id] = False

        wder_dict['session_level'][uniq_id] = error_dict
        asr_eval_dict['hypotheses_list'].append(' '.join(word_seq_list))
        asr_eval_dict['references_list'].append(session_meta['text'])

        count_dict['grand_total_pred_word_count'] += len(hyp_w_dict_list)
        count_dict['grand_total_correct_word_count'] += correct_word_count

    wder_dict = self.get_wder_dict_values(asr_eval_dict, wder_dict, count_dict, align_error_list)
    return wder_dict
def gather_eval_results(self, metric, mapping_dict, total_riva_dict):
    """
    Collect per-session and aggregated diarization scores from a pyannote metric object.

    Per-session rates are recomputed from the raw score components, each divided by the
    session's 'total', and rounded to four decimals. When `self.cfg_diarizer['oracle_vad']`
    is set, the session's 'missed detection' and 'false alarm' components are overwritten
    with 0 (in place, on the score dict itself) before the rates are computed. The
    aggregated "total" entry is computed from the metric object and is not rounded.

    Args:
        metric (DiarizationErrorRate metric): DiarizationErrorRate metric pyannote object
        mapping_dict (dict): A dictionary containing speaker mapping labels for each audio file
            with key as unique name
        total_riva_dict (dict): Not used by this method.

    Returns:
        DER_result_dict (dict): A dictionary containing scores for each audio file along with
            aggregated results
    """
    session_scores = metric.results_
    DER_result_dict = {}
    n_sessions_with_correct_spk_count = 0

    for key, score in session_scores:
        hyp_labels = rttm_to_labels(os.path.join(self.root_path, 'pred_rttms', key + '.rttm'))
        est_n_spk = self.get_num_of_spk_from_labels(hyp_labels)

        ref_labels = rttm_to_labels(self.AUDIO_RTTM_MAP[key]['rttm_filepath'])
        ref_n_spk = self.get_num_of_spk_from_labels(ref_labels)

        if self.cfg_diarizer['oracle_vad']:
            # Detection errors are zeroed for the per-session report under oracle VAD.
            for component in ('missed detection', 'false alarm'):
                score[component] = 0

        confusion = score['confusion']
        false_alarm = score['false alarm']
        missed = score['missed detection']
        total = score['total']
        session_der = (confusion + false_alarm + missed) / total
        session_cer = confusion / total
        session_fa = false_alarm / total
        session_miss = missed / total

        spk_count_matches = est_n_spk == ref_n_spk
        DER_result_dict[key] = {
            "DER": round(session_der, 4),
            "CER": round(session_cer, 4),
            "FA": round(session_fa, 4),
            "MISS": round(session_miss, 4),
            "est_n_spk": est_n_spk,
            "mapping": mapping_dict[key],
            "is_spk_count_correct": spk_count_matches,
        }
        n_sessions_with_correct_spk_count += int(spk_count_matches)

    DER_result_dict["total"] = {
        "DER": abs(metric),
        "CER": metric['confusion'] / metric['total'],
        "FA": metric['false alarm'] / metric['total'],
        "MISS": metric['missed detection'] / metric['total'],
        "spk_counting_acc": n_sessions_with_correct_spk_count / len(session_scores),
    }
    return DER_result_dict
def main(wav_scp, text_scp=None, rttm_scp=None, uem_scp=None, ctm_scp=None, manifest_filepath=None):
    """
    Assemble one manifest entry per audio file and write them to `manifest_filepath`.

    An existing file at `manifest_filepath` is removed first. The list files are consumed
    in lockstep: the i-th line of every provided list belongs to the i-th audio file.

    Args:
        wav_scp: Path handed to `read_file` to obtain the list of audio file paths.
        text_scp: Optional list file of transcript paths. The first line of each transcript
            file becomes the entry's "text"; "-" is used when no list is given.
        rttm_scp: Optional list file of RTTM paths. When given, the number of distinct
            speakers (last field of each label) is stored as "num_speakers".
        uem_scp: Optional list file of UEM paths.
        ctm_scp: Optional list file of CTM paths.
        manifest_filepath: Output manifest path. NOTE(review): the default of None is not
            usable, `os.path.exists(None)` raises TypeError, so callers must pass a path.

    Raises:
        AssertionError: If a provided list does not have as many entries as the wav list.
            This now includes `ctm_scp`, which was previously unchecked and would make
            `zip` silently drop trailing audio files.
    """
    if os.path.exists(manifest_filepath):
        os.remove(manifest_filepath)

    wav_scp = read_file(wav_scp)
    len_wavs = len(wav_scp)

    def load_optional_scp(scp_path, kind):
        """Read an optional list file, or return a None placeholder per audio file."""
        if scp_path is None:
            return len_wavs * [None]
        entries = read_file(scp_path)
        assert len(entries) == len_wavs, f"{kind} list has {len(entries)} entries, expected {len_wavs}"
        return entries

    text_scp = load_optional_scp(text_scp, "text")
    rttm_scp = load_optional_scp(rttm_scp, "rttm")
    uem_scp = load_optional_scp(uem_scp, "uem")
    ctm_scp = load_optional_scp(ctm_scp, "ctm")

    lines = []
    for wav, text, rttm, uem, ctm in zip(wav_scp, text_scp, rttm_scp, uem_scp, ctm_scp):
        audio_line = wav.strip()
        if rttm is not None:
            rttm = rttm.strip()
            labels = rttm_to_labels(rttm)
            # The speaker label is the last whitespace-separated field of each label.
            num_speakers = len({label.split()[-1] for label in labels})
        else:
            num_speakers = None

        if uem is not None:
            uem = uem.strip()

        if text is not None:
            # Only the first line is used; the context manager closes the handle, which
            # the previous bare open() never did. An empty file yields an empty string.
            with open(text.strip()) as text_file:
                text = text_file.readline().strip()
        else:
            text = "-"

        if ctm is not None:
            ctm = ctm.strip()

        lines.append(
            {
                "audio_filepath": audio_line,
                "offset": 0,
                "duration": None,
                "label": "infer",
                "text": text,
                "num_speakers": num_speakers,
                "rttm_filepath": rttm,
                "uem_filepath": uem,
                "ctm_filepath": ctm,
            }
        )

    write_file(manifest_filepath, lines, range(len(lines)))
def eval_diarization(self, audio_file_list, ref_rttm_file_list):
    """
    Evaluate the predicted speaker labels (pred_rttm) using ref_rttm_file_list.
    DER and speaker counting accuracy are calculated, per session and over all sessions.

    Args:
        audio_file_list (list): The list of audio file paths.
        ref_rttm_file_list (list): The list of reference rttm paths.

    Returns:
        diar_labels (list): Predicted labels of every session, in input order.
        ref_labels_list (list): Reference labels of every session, in input order.
        DER_result_dict (dict): One entry per session (DER, CER, FA, MISS, estimated number
            of speakers, speaker mapping, and whether the speaker count matched the
            reference) plus a 'total' entry with the cumulative rates and the speaker
            counting accuracy.

    Raises:
        ValueError: If the reference RTTM file of a session does not exist.
    """
    diar_labels, ref_labels_list = [], []
    all_hypotheses, all_references = [], []
    DER_result_dict = {}
    count_correct_spk_counting = 0

    audio_rttm_map = get_audio_rttm_map(audio_file_list, ref_rttm_file_list)
    for audio_file_path in audio_file_list:
        uniq_id = get_uniq_id_from_audio_path(audio_file_path)
        rttm_file = audio_rttm_map[uniq_id]['rttm_path']
        if not os.path.exists(rttm_file):
            raise ValueError("No reference RTTM file provided.")

        ref_labels = rttm_to_labels(rttm_file)
        ref_labels_list.append(ref_labels)
        reference = labels_to_pyannote_object(ref_labels)
        all_references.append(reference)

        pred_rttm = os.path.join(self.oracle_vad_dir, 'pred_rttms', uniq_id + '.rttm')
        pred_labels = rttm_to_labels(pred_rttm)
        diar_labels.append(pred_labels)

        est_n_spk = self.get_num_of_spk_from_labels(pred_labels)
        ref_n_spk = self.get_num_of_spk_from_labels(ref_labels)
        hypothesis = labels_to_pyannote_object(pred_labels)
        all_hypotheses.append(hypothesis)

        # Per-session scores: the metric is evaluated on this single pair only.
        DER, CER, FA, MISS, mapping = get_DER([reference], [hypothesis])
        DER_result_dict[uniq_id] = {
            "DER": DER,
            "CER": CER,
            "FA": FA,
            "MISS": MISS,
            "n_spk": est_n_spk,
            "mapping": mapping[0],
            "spk_counting": (est_n_spk == ref_n_spk),
        }
        count_correct_spk_counting += int(est_n_spk == ref_n_spk)

    # Cumulative scores over all sessions; the mapping is not needed here.
    DER, CER, FA, MISS, _ = get_DER(all_references, all_hypotheses)
    # The message previously carried a stray backslash (a leftover line continuation
    # inside the literal), which leaked into the log output.
    logging.info(
        "Cumulative results of all the files: \n FA: {:.4f}\t MISS {:.4f}\t "
        "Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format(FA, MISS, DER, CER)
    )
    DER_result_dict['total'] = {
        "DER": DER,
        "CER": CER,
        "FA": FA,
        "MISS": MISS,
        "spk_counting_acc": count_correct_spk_counting / len(audio_file_list),
    }
    return diar_labels, ref_labels_list, DER_result_dict