Exemplo n.º 1
0
    def gather_eval_results(self, metric, mapping_dict):
        """
        Collects per-session and aggregated diarization scores from a pyannote
        DiarizationErrorRate metric object.
        Inputs
        metric (DiarizationErrorRate metric): pyannote metric object; its `results_` attribute is
            iterated as (unique name, score components) pairs
        mapping_dict (dict): Speaker label mapping for each audio file, keyed by unique name
        Returns
        DER_result_dict (dict): Scores for each audio file keyed by unique name, plus the
            aggregated results under the "total" key
        """
        DER_result_dict = {}
        n_sessions_with_correct_count = 0
        for key, score in metric.results_:
            # Estimated speaker count comes from the predicted RTTM written under root_path.
            hyp_rttm_path = os.path.join(self.root_path, 'pred_rttms', key + '.rttm')
            est_n_spk = self.get_num_of_spk_from_labels(rttm_to_labels(hyp_rttm_path))

            # Reference speaker count comes from the RTTM registered for this session.
            ref_rttm_path = self.AUDIO_RTTM_MAP[key]['rttm_filepath']
            ref_n_spk = self.get_num_of_spk_from_labels(rttm_to_labels(ref_rttm_path))

            is_count_correct = est_n_spk == ref_n_spk
            session_result = {
                "DER": score['diarization error rate'],
                "CER": score['confusion'],
                "FA": score['false alarm'],
                "MISS": score['missed detection'],
                "n_spk": est_n_spk,
                "mapping": mapping_dict[key],
                "spk_counting": is_count_correct,
            }
            DER_result_dict[key] = session_result
            logging.info("score for session {}: {}".format(key, session_result))
            n_sessions_with_correct_count += int(is_count_correct)

        # Aggregated values: abs(metric) is the overall rate, the remaining
        # components are normalised by the accumulated total.
        total = {"DER": abs(metric)}
        total["CER"] = metric['confusion'] / metric['total']
        total["FA"] = metric['false alarm'] / metric['total']
        total["MISS"] = metric['missed detection'] / metric['total']
        total["spk_counting_acc"] = n_sessions_with_correct_count / len(metric.results_)
        DER_result_dict["total"] = total

        return DER_result_dict
Exemplo n.º 2
0
    def get_diarization_labels(self, audio_file_list):
        """
        Load the predicted diarization labels of every audio file into a list.

        For each path, the predicted RTTM is looked up under
        `<self.oracle_vad_dir>/pred_rttms/<uniq_id>.rttm`, converted to labels, and the
        estimated number of speakers is logged.

        Arg:
            audio_file_list (list):
                The list of audio file paths.

        Returns:
            diar_labels (list):
                One entry of predicted labels per audio file, in input order.
        """
        diar_labels = []
        for audio_file_path in audio_file_list:
            uniq_id = get_uniq_id_from_audio_path(audio_file_path)
            rttm_name = uniq_id + '.rttm'
            session_labels = rttm_to_labels(os.path.join(self.oracle_vad_dir, 'pred_rttms', rttm_name))
            diar_labels.append(session_labels)

            est_n_spk = self.get_num_of_spk_from_labels(session_labels)
            logging.info(f"Estimated n_spk [{uniq_id}]: {est_n_spk}")

        return diar_labels
Exemplo n.º 3
0
    def run_diarization(
        self,
        diar_model_config,
        word_timestamps,
    ):
        """
        Run the clustering diarizer and collect the predicted labels per session.

        When ASR-based VAD is enabled in the config, the word timestamps are first
        turned into VAD labels and written out as a manifest, which is then set as
        the external VAD manifest (the VAD model path is cleared in that case).

        Args:
            diar_model_config:
                Diarization config. This method reads
                `diarizer.asr.parameters.asr_based_vad`, `diarizer.vad.model_path` and
                `diarizer.oracle_vad`, and may overwrite `diarizer.vad.model_path` and
                `diarizer.vad.external_vad_manifest`.
            word_timestamps (list):
                Word timestamps passed to `save_VAD_labels_list` when ASR-based VAD is used.

        Returns:
            diar_hyp (dict):
                Predicted labels read from the predicted RTTM files, indexed by unique ID.
            score:
                The value returned by `ClusteringDiarizer.diarize()`
                (NOTE(review): presumably a pyannote metric plus speaker mapping dict - confirm).
        """
        use_asr_vad = diar_model_config.diarizer.asr.parameters.asr_based_vad
        if use_asr_vad:
            self.save_VAD_labels_list(word_timestamps)
            manifest_target = os.path.join(self.root_path, 'asr_vad_manifest.json')
            asr_vad_manifest = write_rttm2manifest(self.VAD_RTTM_MAP, manifest_target)
            # Switch the diarizer from a VAD model to the ASR-derived manifest.
            diar_model_config.diarizer.vad.model_path = None
            diar_model_config.diarizer.vad.external_vad_manifest = asr_vad_manifest

        diar_model = ClusteringDiarizer(cfg=diar_model_config)
        score = diar_model.diarize()

        # Frame-level VAD is only collected when a VAD model was configured
        # and oracle VAD was not requested.
        vad_model_configured = diar_model_config.diarizer.vad.model_path is not None
        if vad_model_configured and not diar_model_config.diarizer.oracle_vad:
            self.get_frame_level_VAD(vad_processing_dir=diar_model.vad_pred_dir)

        diar_hyp = {}
        for audio_file_path in self.audio_file_list:
            uniq_id = get_uniqname_from_filepath(audio_file_path)
            rttm_path = os.path.join(self.root_path, 'pred_rttms', uniq_id + '.rttm')
            diar_hyp[uniq_id] = rttm_to_labels(rttm_path)

        return diar_hyp, score
Exemplo n.º 4
0
def main(wav_path,
         text_path=None,
         rttm_path=None,
         uem_path=None,
         ctm_path=None,
         manifest_filepath=None):
    """
    Build a diarization manifest with one entry per audio file.

    The wav list is read from `wav_path` and turned into a dict keyed by unique
    ID; text/RTTM/UEM/CTM paths are matched to the same IDs through
    `get_path_dict`. Entries are emitted in sorted order of the unique IDs.

    Args:
        wav_path: Path to the file listing the audio files.
        text_path: Optional path to the file listing transcript files.
        rttm_path: Optional path to the file listing reference RTTM files.
        uem_path: Optional path to the file listing UEM files.
        ctm_path: Optional path to the file listing CTM files.
        manifest_filepath: Output manifest path. An existing file at this path is
            removed first. Despite the `None` default, a real path is required:
            `os.path.exists` is called on it unconditionally.

    Notes:
        - The "text" field is the first line of the transcript file, or "-" when
          no transcript is given. An empty transcript file yields an empty string.
        - "num_speakers" is the number of distinct last fields among the RTTM
          labels, or None when no RTTM is given.
    """
    if os.path.exists(manifest_filepath):
        os.remove(manifest_filepath)

    wav_pathlist = read_file(wav_path)
    wav_pathdict = get_dict_from_wavlist(wav_pathlist)
    len_wavs = len(wav_pathlist)
    uniqids = sorted(wav_pathdict.keys())

    # NOTE(review): get_path_dict presumably maps every uid to a path or None - confirm.
    text_pathdict = get_path_dict(text_path, uniqids, len_wavs)
    rttm_pathdict = get_path_dict(rttm_path, uniqids, len_wavs)
    uem_pathdict = get_path_dict(uem_path, uniqids, len_wavs)
    ctm_pathdict = get_path_dict(ctm_path, uniqids, len_wavs)

    lines = []
    for uid in uniqids:
        wav, text, rttm, uem, ctm = (
            wav_pathdict[uid],
            text_pathdict[uid],
            rttm_pathdict[uid],
            uem_pathdict[uid],
            ctm_pathdict[uid],
        )

        audio_line = wav.strip()
        if rttm is not None:
            rttm = rttm.strip()
            labels = rttm_to_labels(rttm)
            # The speaker is the last whitespace-separated field of each label.
            num_speakers = len({label.split()[-1] for label in labels})
        else:
            num_speakers = None

        if uem is not None:
            uem = uem.strip()

        if text is not None:
            # Read only the first line and close the handle deterministically.
            with open(text.strip()) as text_file:
                text = text_file.readline().strip()
        else:
            text = "-"

        if ctm is not None:
            ctm = ctm.strip()

        lines.append({
            "audio_filepath": audio_line,
            "offset": 0,
            "duration": None,
            "label": "infer",
            "text": text,
            "num_speakers": num_speakers,
            "rttm_filepath": rttm,
            "uem_filepath": uem,
            "ctm_filepath": ctm,
        })

    write_file(manifest_filepath, lines, range(len(lines)))
Exemplo n.º 5
0
    def get_WDER(self, total_riva_dict, DER_result_dict):
        """
        Calculate word-level diarization error rate (WDER). WDER is calculated by
        counting the wrongly diarized words and dividing by the total number of words
        recognized by the ASR model.

        For every word, the reference segment(s) with the longest overlap are looked up
        and the mapped hypothesis speaker is compared with the reference speaker. Words
        whose speaker label has no entry in the mapping, or that overlap no reference
        segment, are never counted as correct but still count towards the total.

        Args:
            total_riva_dict (dict):
                The dictionary that stores riva_dict(dict) indexed by uniq_id variable.
                Each entry's 'words' list holds dicts with 'speaker_label',
                'start_time' and 'end_time'.
            DER_result_dict (dict):
                The dictionary that stores DER, FA, Miss, CER, mapping, the estimated
                number of speakers and speaker counting accuracy. Only the 'mapping'
                entry of each session is used here.

        Return:
            wder_dict (dict):
                A dictionary contains WDER value for each session and total WDER
                (under the 'total' key).

        Raises:
            ZeroDivisionError: If a session (or the whole evaluation set) has no words.
        """
        wder_dict = {}
        grand_total_word_count, grand_correct_word_count = 0, 0
        for audio_file_path in self.audio_file_list:

            uniq_id = get_uniqname_from_filepath(audio_file_path)
            ref_rttm = self.AUDIO_RTTM_MAP[uniq_id]['rttm_filepath']
            labels = rttm_to_labels(ref_rttm)
            mapping_dict = DER_result_dict[uniq_id]['mapping']
            words_list = total_riva_dict[uniq_id]['words']

            correct_word_count = 0
            total_word_count = len(words_list)
            # [start, end] of every reference segment, one row per label.
            ref_label_list = [[float(x.split()[0]), float(x.split()[1])] for x in labels]
            ref_label_array = np.array(ref_label_list)

            for wdict in words_list:
                speaker_label = wdict['speaker_label']
                if speaker_label not in mapping_dict:
                    continue
                est_spk_label = mapping_dict[speaker_label]

                # Repeat the word range once per reference segment so both arrays align row-wise.
                word_range = np.array([wdict['start_time'], wdict['end_time']])
                word_range_tile = np.tile(word_range, (ref_label_array.shape[0], 1))
                # ovl_bool is used as a boolean mask over the reference segments.
                ovl_bool = self.isOverlapArray(ref_label_array, word_range_tile)
                if not np.any(ovl_bool):
                    continue

                ovl_length = self.getOverlapRangeArray(ref_label_array, word_range_tile)
                ovl_idx = np.where(ovl_bool)[0]
                ovl_lengths = ovl_length[ovl_bool]

                if self.params['lenient_overlap_WDER']:
                    # Lenient mode: every segment tied for the longest overlap is an acceptable match.
                    max_ovl_sub_idx = np.where(ovl_lengths == np.max(ovl_lengths))[0]
                    ref_spk_labels = [labels[i].split()[-1] for i in ovl_idx[max_ovl_sub_idx]]
                    if est_spk_label in ref_spk_labels:
                        correct_word_count += 1
                else:
                    # Strict mode: only the first segment with the longest overlap counts.
                    max_ovl_idx = ovl_idx[np.argmax(ovl_lengths)]
                    _, _, ref_spk_label = labels[max_ovl_idx].split()
                    correct_word_count += int(est_spk_label == ref_spk_label)

            wder = 1 - (correct_word_count / total_word_count)
            grand_total_word_count += total_word_count
            grand_correct_word_count += correct_word_count

            wder_dict[uniq_id] = wder

        wder_dict['total'] = 1 - (grand_correct_word_count / grand_total_word_count)
        return wder_dict
Exemplo n.º 6
0
    def write_json_and_transcript(
        self, word_list, word_ts_list,
    ):
        """
        Matches the diarization result with the ASR output.
        The words and the timestamps for the corresponding words are matched
        in a for loop: each word is assigned to the current diarization segment
        while its midpoint lies before the segment end, otherwise the next
        segment (at most one step ahead, clamped to the last one) becomes current.

        Args:
            word_list (list):
                List of words from ASR inference, one list per audio file.
            word_ts_list (list):
                Contains word_ts_stt_end lists, one list per audio file.
                word_ts_stt_end = [stt, end]
                    stt: Start of the word in sec.
                    end: End of the word in sec.

        Return:
            total_riva_dict (dict):
                A dictionary contains word timestamps, speaker labels and words,
                indexed by the unique ID of every session in self.audio_file_list
                (including sessions without any word).

        """
        total_riva_dict = {}
        if self.fix_word_ts_with_VAD:
            word_ts_list = self.compensate_word_ts_list(self.audio_file_list, word_ts_list, self.params)
            if self.frame_VAD == {}:
                logging.info(
                    "VAD timestamps are not provided and skipping word timestamp fix. Please check the VAD model."
                )

        for k, audio_file_path in enumerate(self.audio_file_list):
            uniq_id = get_uniqname_from_filepath(audio_file_path)
            pred_rttm = os.path.join(self.root_path, 'pred_rttms', uniq_id + '.rttm')
            labels = rttm_to_labels(pred_rttm)
            audacity_label_words = []
            n_spk = self.get_num_of_spk_from_labels(labels)
            string_out = ''
            riva_dict = od(
                {
                    'status': 'Success',
                    'session_id': uniq_id,
                    'transcription': ' '.join(word_list[k]),
                    'speaker_count': n_spk,
                    'words': [],
                }
            )

            # Start from the first diarization segment.
            start_point, end_point, speaker = labels[0].split()
            words = word_list[k]

            logging.info(f"Creating results for Session: {uniq_id} n_spk: {n_spk} ")
            string_out = self.print_time(string_out, speaker, start_point, end_point, self.params)
            idx = 0
            for j, word_ts_stt_end in enumerate(word_ts_list[k]):

                # A word belongs to the current segment while its midpoint is before the segment end.
                word_pos = (word_ts_stt_end[0] + word_ts_stt_end[1]) / 2
                if word_pos < float(end_point):
                    string_out = self.print_word(string_out, words[j], self.params)
                else:
                    # Advance one segment, staying on the last one once the labels are exhausted.
                    idx += 1
                    idx = min(idx, len(labels) - 1)
                    start_point, end_point, speaker = labels[idx].split()
                    string_out = self.print_time(string_out, speaker, start_point, end_point, self.params)
                    string_out = self.print_word(string_out, words[j], self.params)

                stt_sec, end_sec = round(word_ts_stt_end[0], 2), round(word_ts_stt_end[1], 2)
                riva_dict = self.add_json_to_dict(riva_dict, words[j], stt_sec, end_sec, speaker)

                audacity_label_words = self.get_audacity_label(
                    words[j], stt_sec, end_sec, speaker, audacity_label_words
                )

            # Register the session after the word loop so that sessions without
            # any word still appear in the returned dictionary.
            total_riva_dict[uniq_id] = riva_dict
            self.write_and_log(uniq_id, riva_dict, string_out, audacity_label_words)

        return total_riva_dict
Exemplo n.º 7
0
    def get_WDER(self, total_riva_dict, DER_result_dict):
        """
        Calculate word-level diarization error rate (WDER). WDER is calculated by
        counting the wrongly diarized words and divided by the total number of words
        recognized by the ASR model.

        The RTTM-based WDER is computed for every session through
        `calculate_WDER_from_RTTM`. When a session has a CTM file, a CTM-based
        evaluation is additionally run through `get_ctm_based_eval`, and
        `self.ctm_exists[uniq_id]` records whether a CTM file was available.

        Args:
            total_riva_dict (dict):
                Dictionary that stores riva_dict(dict) which is indexed by uniq_id variable.
            DER_result_dict (dict):
                Dictionary that stores DER, FA, Miss, CER, mapping, the estimated
                number of speakers and speaker counting accuracy.

        Returns:
            wder_dict (dict):
                A dictionary containing WDER value for each session and total WDER,
                as produced by `get_wder_dict_values`.

        Raises:
            ImportError: If any session provides a CTM file but diff_match_patch
                is not installed.
        """
        wder_dict, count_dict = {'session_level': {}}, {}
        asr_eval_dict = {'hypotheses_list': [], 'references_list': []}
        align_error_list = []

        count_dict['total_ctm_wder_count'], count_dict['total_asr_and_spk_correct_words'] = 0, 0
        (
            count_dict['grand_total_ctm_word_count'],
            count_dict['grand_total_pred_word_count'],
            count_dict['grand_total_correct_word_count'],
        ) = (0, 0, 0)

        # Fail early, before any per-session work, when CTM evaluation cannot run.
        if any(meta['ctm_filepath'] is not None for meta in self.AUDIO_RTTM_MAP.values()):
            if not DIFF_MATCH_PATCH:
                raise ImportError(
                    'CTM file is provided but diff_match_patch is not installed. Install diff_match_patch using PyPI: pip install diff_match_patch'
                )

        for audio_file_path in self.audio_file_list:

            uniq_id = get_uniqname_from_filepath(audio_file_path)
            error_dict = {'uniq_id': uniq_id}
            ref_rttm = self.AUDIO_RTTM_MAP[uniq_id]['rttm_filepath']
            ref_labels = rttm_to_labels(ref_rttm)
            mapping_dict = DER_result_dict[uniq_id]['mapping']
            hyp_w_dict_list = total_riva_dict[uniq_id]['words']
            hyp_w_dict_list, word_seq_list, correct_word_count, rttm_wder = self.calculate_WDER_from_RTTM(
                hyp_w_dict_list, ref_labels, mapping_dict
            )
            error_dict['rttm_based_wder'] = rttm_wder
            error_dict.update(DER_result_dict[uniq_id])

            # If CTM files are provided, evaluate word-level diarization and WER with the CTM files.
            ctm_filepath = self.AUDIO_RTTM_MAP[uniq_id]['ctm_filepath']
            if ctm_filepath:
                self.ctm_exists[uniq_id] = True
                # Close the CTM file as soon as its lines have been read.
                with open(ctm_filepath) as ctm_file:
                    ctm_content = ctm_file.readlines()
                self.get_ctm_based_eval(ctm_content, error_dict, count_dict, hyp_w_dict_list, mapping_dict)
            else:
                self.ctm_exists[uniq_id] = False

            wder_dict['session_level'][uniq_id] = error_dict
            asr_eval_dict['hypotheses_list'].append(' '.join(word_seq_list))
            asr_eval_dict['references_list'].append(self.AUDIO_RTTM_MAP[uniq_id]['text'])

            count_dict['grand_total_pred_word_count'] += len(hyp_w_dict_list)
            count_dict['grand_total_correct_word_count'] += correct_word_count

        wder_dict = self.get_wder_dict_values(asr_eval_dict, wder_dict, count_dict, align_error_list)
        return wder_dict
Exemplo n.º 8
0
    def gather_eval_results(self, metric, mapping_dict, total_riva_dict):
        """
        Collect per-session and aggregated diarization scores from a pyannote
        DiarizationErrorRate metric object.

        With oracle VAD enabled (`self.cfg_diarizer['oracle_vad']`), the
        'missed detection' and 'false alarm' components of each session score are
        set to zero in place before the per-session rates are computed.

        Args:
            metric (DiarizationErrorRate metric): DiarizationErrorRate metric pyannote object
            mapping_dict (dict): A dictionary containing speaker mapping labels for each audio file with key as unique name
            total_riva_dict (dict): Not used by this method.

        Returns:
            DER_result_dict (dict): A dictionary containing scores for each audio file along with aggregated results
                under the "total" key
        """
        DER_result_dict = {}
        n_correct_counts = 0
        for key, score in metric.results_:
            hyp_rttm_path = os.path.join(self.root_path, 'pred_rttms', key + '.rttm')
            est_n_spk = self.get_num_of_spk_from_labels(rttm_to_labels(hyp_rttm_path))

            ref_rttm_path = self.AUDIO_RTTM_MAP[key]['rttm_filepath']
            ref_n_spk = self.get_num_of_spk_from_labels(rttm_to_labels(ref_rttm_path))

            if self.cfg_diarizer['oracle_vad']:
                # Detection errors are zeroed out when the VAD is the oracle one.
                score['missed detection'] = 0
                score['false alarm'] = 0

            # Per-session rates, each normalised by the session total.
            session_der = (score['confusion'] + score['false alarm'] + score['missed detection']) / score['total']
            session_cer = score['confusion'] / score['total']
            session_fa = score['false alarm'] / score['total']
            session_miss = score['missed detection'] / score['total']

            count_matches = est_n_spk == ref_n_spk
            DER_result_dict[key] = {
                "DER": round(session_der, 4),
                "CER": round(session_cer, 4),
                "FA": round(session_fa, 4),
                "MISS": round(session_miss, 4),
                "est_n_spk": est_n_spk,
                "mapping": mapping_dict[key],
                "is_spk_count_correct": count_matches,
            }
            n_correct_counts += int(count_matches)

        # Aggregated values come from the metric object itself.
        total = {"DER": abs(metric)}
        total["CER"] = metric['confusion'] / metric['total']
        total["FA"] = metric['false alarm'] / metric['total']
        total["MISS"] = metric['missed detection'] / metric['total']
        total["spk_counting_acc"] = n_correct_counts / len(metric.results_)
        DER_result_dict["total"] = total

        return DER_result_dict
Exemplo n.º 9
0
def main(wav_scp,
         text_scp=None,
         rttm_scp=None,
         uem_scp=None,
         ctm_scp=None,
         manifest_filepath=None):
    """
    Build a diarization manifest from line-aligned list files.

    Each optional list file, when given, must contain exactly as many lines as
    the wav list; the i-th line of every list describes the same audio file.

    Args:
        wav_scp: Path to the file listing the audio files.
        text_scp: Optional path to the file listing transcript files.
        rttm_scp: Optional path to the file listing reference RTTM files.
        uem_scp: Optional path to the file listing UEM files.
        ctm_scp: Optional path to the file listing CTM files.
        manifest_filepath: Output manifest path. An existing file at this path is
            removed first. Despite the `None` default, a real path is required:
            `os.path.exists` is called on it unconditionally.

    Raises:
        AssertionError: If an optional list does not have the same number of
            lines as the wav list.

    Notes:
        - The "text" field is the first line of the transcript file, or "-" when
          no transcript is given. An empty transcript file yields an empty string.
        - "num_speakers" is the number of distinct last fields among the RTTM
          labels, or None when no RTTM is given.
    """
    if os.path.exists(manifest_filepath):
        os.remove(manifest_filepath)

    wav_scp = read_file(wav_scp)
    len_wavs = len(wav_scp)

    if text_scp is not None:
        text_scp = read_file(text_scp)
        assert len(text_scp) == len_wavs
    else:
        text_scp = len_wavs * [None]

    if rttm_scp is not None:
        rttm_scp = read_file(rttm_scp)
        assert len(rttm_scp) == len_wavs
    else:
        rttm_scp = len_wavs * [None]

    if uem_scp is not None:
        uem_scp = read_file(uem_scp)
        assert len(uem_scp) == len_wavs
    else:
        uem_scp = len_wavs * [None]

    if ctm_scp is not None:
        ctm_scp = read_file(ctm_scp)
        # Same check as the other lists: a shorter CTM list would otherwise
        # make zip() silently drop entries from the manifest.
        assert len(ctm_scp) == len_wavs
    else:
        ctm_scp = len_wavs * [None]

    lines = []
    for wav, text, rttm, uem, ctm in zip(wav_scp, text_scp, rttm_scp, uem_scp,
                                         ctm_scp):
        audio_line = wav.strip()
        if rttm is not None:
            rttm = rttm.strip()
            labels = rttm_to_labels(rttm)
            # The speaker is the last whitespace-separated field of each label.
            num_speakers = len({label.split()[-1] for label in labels})
        else:
            num_speakers = None

        if uem is not None:
            uem = uem.strip()

        if text is not None:
            # Read only the first line and close the handle deterministically.
            with open(text.strip()) as text_file:
                text = text_file.readline().strip()
        else:
            text = "-"

        if ctm is not None:
            ctm = ctm.strip()

        lines.append({
            "audio_filepath": audio_line,
            "offset": 0,
            "duration": None,
            "label": "infer",
            "text": text,
            "num_speakers": num_speakers,
            "rttm_filepath": rttm,
            "uem_filepath": uem,
            "ctm_filepath": ctm,
        })

    write_file(manifest_filepath, lines, range(len(lines)))
Exemplo n.º 10
0
    def eval_diarization(self, audio_file_list, ref_rttm_file_list):
        """
        Score the predicted speaker labels (pred_rttm) against the reference RTTM files.
        DER components are computed per session and over all sessions, together with
        the speaker counting accuracy.

        Args:
            audio_file_list (list):
                The list of audio file paths.
            ref_rttm_file_list (list):
                The list of reference rttm paths.

        Returns:
            diar_labels (list):
                Predicted labels of every session, in input order.
            ref_labels_list (list):
                Reference labels of every session, in input order.
            DER_result_dict (dict):
                Per-session scores keyed by unique ID plus the cumulative scores
                under the 'total' key.

        Raises:
            ValueError: If the reference RTTM file of a session does not exist.
        """
        diar_labels, ref_labels_list = [], []
        all_hypotheses, all_references = [], []
        DER_result_dict = {}
        n_correct_spk_counts = 0

        audio_rttm_map = get_audio_rttm_map(audio_file_list, ref_rttm_file_list)
        for audio_file_path in audio_file_list:
            uniq_id = get_uniq_id_from_audio_path(audio_file_path)
            rttm_file = audio_rttm_map[uniq_id]['rttm_path']
            if not os.path.exists(rttm_file):
                raise ValueError("No reference RTTM file provided.")

            ref_labels = rttm_to_labels(rttm_file)
            ref_labels_list.append(ref_labels)
            reference = labels_to_pyannote_object(ref_labels)
            all_references.append(reference)

            hyp_rttm_path = os.path.join(self.oracle_vad_dir, 'pred_rttms', uniq_id + '.rttm')
            pred_labels = rttm_to_labels(hyp_rttm_path)
            diar_labels.append(pred_labels)

            est_n_spk = self.get_num_of_spk_from_labels(pred_labels)
            ref_n_spk = self.get_num_of_spk_from_labels(ref_labels)
            hypothesis = labels_to_pyannote_object(pred_labels)
            all_hypotheses.append(hypothesis)

            # Score this session on its own; the mapping list then holds a single entry.
            sess_der, sess_cer, sess_fa, sess_miss, sess_mapping = get_DER([reference], [hypothesis])
            count_matches = est_n_spk == ref_n_spk
            DER_result_dict[uniq_id] = {
                "DER": sess_der,
                "CER": sess_cer,
                "FA": sess_fa,
                "MISS": sess_miss,
                "n_spk": est_n_spk,
                "mapping": sess_mapping[0],
                "spk_counting": count_matches,
            }
            n_correct_spk_counts += int(count_matches)

        # Cumulative scores over every session.
        DER, CER, FA, MISS, _ = get_DER(all_references, all_hypotheses)
        summary_fmt = "Cumulative results of all the files:  \n FA: {:.4f}\t MISS {:.4f}\t\
                Diarization ER: {:.4f}\t, Confusion ER:{:.4f}"
        logging.info(summary_fmt.format(FA, MISS, DER, CER))

        spk_counting_acc = n_correct_spk_counts / len(audio_file_list)
        DER_result_dict['total'] = {
            "DER": DER,
            "CER": CER,
            "FA": FA,
            "MISS": MISS,
            "spk_counting_acc": spk_counting_acc,
        }
        return diar_labels, ref_labels_list, DER_result_dict