Example #1
    def _extract_embeddings(self, manifest_file: str):
        """
        This method extracts speaker embeddings from segments passed through manifest_file
        Optionally you may save the intermediate speaker embeddings for debugging or any use. 
        """
        logging.info("Extracting embeddings for Diarization")
        self._setup_spkr_test_data(manifest_file)
        self.embeddings = {}
        self._speaker_model = self._speaker_model.to(self._device)
        self._speaker_model.eval()
        self.time_stamps = {}

        all_embs = torch.empty([0])
        for test_batch in tqdm(self._speaker_model.test_dataloader()):
            test_batch = [x.to(self._device) for x in test_batch]
            audio_signal, audio_signal_len, labels, slices = test_batch
            with autocast():
                _, embs = self._speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
                emb_shape = embs.shape[-1]
                embs = embs.view(-1, emb_shape)
                all_embs = torch.cat((all_embs, embs.cpu().detach()), dim=0)
            del test_batch

        with open(manifest_file, 'r', encoding='utf-8') as manifest:
            for i, line in enumerate(manifest.readlines()):
                line = line.strip()
                dic = json.loads(line)
                uniq_name = get_uniqname_from_filepath(dic['audio_filepath'])
                if uniq_name in self.embeddings:
                    self.embeddings[uniq_name] = torch.cat((self.embeddings[uniq_name], all_embs[i].view(1, -1)))
                else:
                    self.embeddings[uniq_name] = all_embs[i].view(1, -1)
                if uniq_name not in self.time_stamps:
                    self.time_stamps[uniq_name] = []
                start = dic['offset']
                end = start + dic['duration']
                stamp = '{:.3f} {:.3f} '.format(start, end)
                self.time_stamps[uniq_name].append(stamp)

        if self._speaker_params.save_embeddings:
            embedding_dir = os.path.join(self._speaker_dir, 'embeddings')
            os.makedirs(embedding_dir, exist_ok=True)

            prefix = get_uniqname_from_filepath(manifest_file)
            name = os.path.join(embedding_dir, prefix)
            self._embeddings_file = name + '_embeddings.pkl'
            with open(self._embeddings_file, 'wb') as f:
                pkl.dump(self.embeddings, f)
            logging.info("Saved embedding files to {}".format(embedding_dir))
Example #2
    def save_VAD_labels_list(self, word_ts_list, audio_file_list):
        """
        Generate speech/non-speech labels from the word timestamps obtained from the
        run_ASR() function and save them as RTTM files.

        Args:
            word_ts_list (list):
                List that contains word timestamps.
            audio_file_list (list):
                List of audio file paths.
        """
        self.VAD_RTTM_MAP = {}
        for i, word_timestamps in enumerate(word_ts_list):
            speech_labels_float = self._get_speech_labels_from_decoded_prediction(
                word_timestamps)
            speech_labels = self.get_str_speech_labels(speech_labels_float)
            uniq_id = get_uniqname_from_filepath(audio_file_list[i])
            output_path = os.path.join(self.root_path, 'pred_rttms')
            if not os.path.exists(output_path):
                os.makedirs(output_path)
            filename = labels_to_rttmfile(speech_labels, uniq_id, output_path)
            self.VAD_RTTM_MAP[uniq_id] = {
                'audio_filepath': audio_file_list[i],
                'rttm_filepath': filename
            }
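
The speech labels come from _get_speech_labels_from_decoded_prediction(), which is not shown here. Below is a hedged sketch of what such a conversion could look like: consecutive word timestamps are merged into speech segments, and a new segment starts after a long enough gap. words_to_speech_segments and the 0.5 s threshold are illustrative assumptions, not the actual NeMo logic.

from typing import List


def words_to_speech_segments(word_ts: List[List[float]], max_gap: float = 0.5) -> List[List[float]]:
    """Merge [start, end] word timestamps into speech segments, splitting at long gaps."""
    if not word_ts:
        return []
    segments = [list(word_ts[0])]
    for stt, end in word_ts[1:]:
        if stt - segments[-1][1] <= max_gap:
            segments[-1][1] = end  # extend the current speech segment
        else:
            segments.append([stt, end])  # long silence: start a new segment
    return segments


# Example: three words with a long pause before the last one.
print(words_to_speech_segments([[0.0, 0.4], [0.5, 0.9], [2.0, 2.3]]))
# -> [[0.0, 0.9], [2.0, 2.3]]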
Example #3
    def run_ASR_CitriNet_CTC(
            self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[Dict, Dict]:
        """
        Launch the CitriNet ASR model and collect logits, timestamps and text output.

        Args:
            asr_model (class):
                The loaded NeMo ASR model.

        Returns:
            words_dict (dict):
                Dictionary of the sequence of words from hypothesis.
            word_ts_dict (dict):
                Dictionary of the timestamps of hypothesis words.
        """
        words_dict, word_ts_dict = {}, {}

        werbpe_ts = WERBPE_TS(
            tokenizer=asr_model.tokenizer,
            batch_dim_index=0,
            use_cer=asr_model._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=asr_model._cfg.get("log_prediction", False),
        )

        with torch.cuda.amp.autocast():
            transcript_logits_list = asr_model.transcribe(
                self.audio_file_list,
                batch_size=self.asr_batch_size,
                logprobs=True)
            for idx, logit_np in enumerate(transcript_logits_list):
                uniq_id = get_uniqname_from_filepath(self.audio_file_list[idx])
                if self.beam_search_decoder:
                    logging.info(
                        f"Running beam-search decoder with LM {self.ctc_decoder_params['pretrained_language_model']}"
                    )
                    hyp_words, word_ts = self.run_pyctcdecode(logit_np)
                else:
                    log_prob = torch.from_numpy(logit_np)
                    greedy_predictions = log_prob.argmax(dim=-1, keepdim=False).unsqueeze(0)
                    logits_len = torch.from_numpy(np.array([log_prob.shape[0]]))
                    text, char_ts, word_ts = werbpe_ts.ctc_decoder_predictions_tensor_with_ts(
                        self.model_stride_in_secs, greedy_predictions, predictions_len=logits_len
                    )
                    hyp_words, word_ts = text[0].split(), word_ts[0]
                word_ts = self.align_decoder_delay(word_ts,
                                                   self.decoder_delay_in_sec)
                assert len(hyp_words) == len(
                    word_ts
                ), "Words and word timestamp list length does not match."
                words_dict[uniq_id] = hyp_words
                word_ts_dict[uniq_id] = word_ts

        return words_dict, word_ts_dict
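
For reference, the greedy branch above relies on CTC decoding of per-frame argmax predictions. The toy sketch below shows the underlying idea: collapse repeated tokens, drop blanks, and convert frame indices into seconds with the model stride. The vocabulary, blank id, and stride value here are made up, and NeMo's BPE-based ctc_decoder_predictions_tensor_with_ts() does considerably more than this.

import numpy as np


def toy_greedy_ctc_decode(logits, vocab, blank_id, stride_sec):
    """Greedy CTC decode with per-token timestamps (toy version)."""
    preds = logits.argmax(axis=-1)
    tokens, token_ts = [], []
    prev = blank_id
    for t, p in enumerate(preds):
        if p != blank_id and p != prev:
            tokens.append(vocab[p])
            token_ts.append(round(t * stride_sec, 2))
        prev = p
    return tokens, token_ts


# Usage with a fake 6-frame logit matrix over the vocabulary ['a', 'b', '<blank>'].
logits = np.array([[5, 0, 1], [5, 0, 1], [0, 0, 5], [0, 5, 1], [0, 5, 1], [0, 0, 5]], dtype=float)
print(toy_greedy_ctc_decode(logits, ['a', 'b', '<blank>'], blank_id=2, stride_sec=0.04))
# -> (['a', 'b'], [0.0, 0.12])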
Example #4
    def get_transcript_with_speaker_labels(self, diar_hyp, word_hyp,
                                           word_ts_hyp):
        """
        Match the diarization result with the ASR output.
        The words and the timestamps for the corresponding words are matched
        in a for loop.

        Args:
            diar_labels (dict):
                Dictionary of the Diarization output labels in str.
            word_hyp (dict):
                Dictionary of words from ASR inference.
            word_ts_hyp (dict):
                Dictionary containing the start time and the end time of each word.

        Returns:
            total_riva_dict (dict):
                A dictionary containing word timestamps, speaker labels and words.

        """
        total_riva_dict = {}
        if self.fix_word_ts_with_VAD:
            if self.frame_VAD == {}:
                logging.info(
                    "VAD timestamps are not provided. Fixing word timestamps without VAD. Please check the hydra configurations."
                )
            word_ts_refined = self.compensate_word_ts_list(
                self.audio_file_list, word_ts_hyp, self.params)
        else:
            word_ts_refined = word_ts_hyp

        if self.realigning_lm_params['arpa_language_model']:
            if not ARPA:
                raise ImportError(
                    'LM for realigning is provided but arpa is not installed. Install arpa using PyPI: pip install arpa'
                )
            else:
                self.realigning_lm = self.load_realigning_LM()

        for k, audio_file_path in enumerate(self.audio_file_list):
            uniq_id = get_uniqname_from_filepath(audio_file_path)
            word_dict_seq_list = self.get_word_dict_seq_list(
                uniq_id, diar_hyp, word_hyp, word_ts_hyp, word_ts_refined)
            if self.realigning_lm:
                word_dict_seq_list = self.realign_words_with_lm(
                    word_dict_seq_list)
            self.make_json_output(uniq_id, diar_hyp, word_dict_seq_list,
                                  total_riva_dict)
        logging.info(
            f"Diarization with ASR output files are saved in: {self.root_path}/pred_rttms"
        )
        return total_riva_dict
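
get_word_dict_seq_list() performs the actual word-to-speaker matching, which is not shown here. A hedged sketch of the basic idea follows: each word is assigned to the diarization segment that contains its midpoint. assign_speaker_by_midpoint is a hypothetical helper, and the fallback rule is an assumption.

def assign_speaker_by_midpoint(word, word_ts, diar_labels):
    """diar_labels entries look like '0.00 3.25 speaker_0' (start, end, label)."""
    mid = (word_ts[0] + word_ts[1]) / 2
    for label in diar_labels:
        stt, end, spk = label.split()
        if float(stt) <= mid < float(end):
            return {'word': word, 'start_time': word_ts[0], 'end_time': word_ts[1], 'speaker_label': spk}
    # Fall back to the last segment's speaker if no segment contains the midpoint.
    return {'word': word, 'start_time': word_ts[0], 'end_time': word_ts[1],
            'speaker_label': diar_labels[-1].split()[-1]}


print(assign_speaker_by_midpoint('hello', [1.1, 1.4], ['0.00 3.25 speaker_0', '3.25 6.10 speaker_1']))
# -> {'word': 'hello', 'start_time': 1.1, 'end_time': 1.4, 'speaker_label': 'speaker_0'}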
Example #5
    def compensate_word_ts_list(self, audio_file_list, word_ts_list, params):
        """
        Compensate the word timestamps based on the VAD output.
        The length of each word is capped by params['max_word_ts_length_in_sec'].

        Args:
            audio_file_list (list):
                List that contains audio file paths.
            word_ts_list (list):
                Contains word_ts_stt_end lists.
                word_ts_stt_end = [stt, end]
                    stt: Start of the word in sec.
                    end: End of the word in sec.
            params (dict):
                The parameter dictionary for diarization and ASR decoding.

        Returns:
            enhanced_word_ts_list (list):
                List of the enhanced word timestamp values.
        """
        enhanced_word_ts_list = []
        for idx, word_ts_seq_list in enumerate(word_ts_list):
            uniq_id = get_uniqname_from_filepath(audio_file_list[idx])
            N = len(word_ts_seq_list)
            enhanced_word_ts_buffer = []
            for k, word_ts in enumerate(word_ts_seq_list):
                if k < N - 1:
                    word_len = round(word_ts[1] - word_ts[0], 2)
                    len_to_next_word = round(
                        word_ts_seq_list[k + 1][0] - word_ts[0] - 0.01, 2)
                    if uniq_id in self.frame_VAD:
                        vad_index_word_end = int(100 * word_ts[1])
                        closest_sil_stt = self.closest_silence_start(
                            vad_index_word_end, self.frame_VAD[uniq_id],
                            params)
                        vad_est_len = round(closest_sil_stt - word_ts[0], 2)
                    else:
                        vad_est_len = len_to_next_word
                    min_candidate = min(vad_est_len, len_to_next_word)
                    fixed_word_len = max(
                        min(params['max_word_ts_length_in_sec'],
                            min_candidate), word_len)
                    enhanced_word_ts_buffer.append(
                        [word_ts[0], word_ts[0] + fixed_word_len])
                else:
                    enhanced_word_ts_buffer.append([word_ts[0], word_ts[1]])

            enhanced_word_ts_list.append(enhanced_word_ts_buffer)
        return enhanced_word_ts_list
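
The per-word length fix above boils down to fixed = max(min(max_word_ts_length_in_sec, min(vad_est_len, len_to_next_word)), word_len). The standalone sketch below replays that formula with made-up numbers; the 0.25 s cap and the timestamps are illustrative only.

def fixed_word_length(word_ts, next_word_stt, vad_est_len, max_len=0.25):
    """Recompute the capped word length used by compensate_word_ts_list()."""
    word_len = round(word_ts[1] - word_ts[0], 2)
    len_to_next_word = round(next_word_stt - word_ts[0] - 0.01, 2)
    min_candidate = min(vad_est_len, len_to_next_word)
    return max(min(max_len, min_candidate), word_len)


# Word [10.00, 10.05] followed by a word at 10.60, VAD says silence starts 0.40 s in:
print(fixed_word_length([10.00, 10.05], next_word_stt=10.60, vad_est_len=0.40))  # -> 0.25 (capped)
# Same word, but the next word starts almost immediately:
print(fixed_word_length([10.00, 10.05], next_word_stt=10.10, vad_est_len=0.40))  # -> 0.09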
Example #6
    def run_diarization(
        self,
        diar_model_config,
        word_timestamps,
    ):
        """
        Launch the diarization process using the given VAD timestamp (oracle_manifest).

        Args:
            word_and_timestamps (list):
                List containing words and word timestamps

        Returns:
            diar_hyp (dict):
                A dictionary containing rttm results which are indexed by a unique ID.
            score Tuple[pyannote object, dict]:
                A tuple containing pyannote metric instance and mapping dictionary between
                speakers in hypotheses and speakers in reference RTTM files.
        """

        if diar_model_config.diarizer.asr.parameters.asr_based_vad:
            self.save_VAD_labels_list(word_timestamps)
            oracle_manifest = os.path.join(self.root_path,
                                           'asr_vad_manifest.json')
            oracle_manifest = write_rttm2manifest(self.VAD_RTTM_MAP,
                                                  oracle_manifest)
            diar_model_config.diarizer.vad.model_path = None
            diar_model_config.diarizer.vad.external_vad_manifest = oracle_manifest

        diar_model = ClusteringDiarizer(cfg=diar_model_config)
        score = diar_model.diarize()
        if diar_model_config.diarizer.vad.model_path is not None and not diar_model_config.diarizer.oracle_vad:
            self.get_frame_level_VAD(
                vad_processing_dir=diar_model.vad_pred_dir)

        diar_hyp = {}
        for k, audio_file_path in enumerate(self.audio_file_list):
            uniq_id = get_uniqname_from_filepath(audio_file_path)
            pred_rttm = os.path.join(self.root_path, 'pred_rttms',
                                     uniq_id + '.rttm')
            diar_hyp[uniq_id] = rttm_to_labels(pred_rttm)
        return diar_hyp, score
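
The predicted RTTM files are converted back into labels with rttm_to_labels(). A minimal sketch of such a conversion is shown below; it assumes the standard RTTM column layout (type, file, channel, start, duration, ..., speaker), and the real NeMo helper may handle more cases.

def simple_rttm_to_labels(rttm_path):
    """Read SPEAKER lines of an RTTM file into 'start end speaker' label strings."""
    labels = []
    with open(rttm_path, 'r') as f:
        for line in f:
            fields = line.strip().split()
            if not fields or fields[0] != 'SPEAKER':
                continue
            start, dur, speaker = float(fields[3]), float(fields[4]), fields[7]
            labels.append('{:.3f} {:.3f} {}'.format(start, start + dur, speaker))
    return labels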
Example #7
    def get_WDER(self, total_riva_dict, DER_result_dict):
        """
        Calculate word-level diarization error rate (WDER). WDER is calculated by
        counting the the wrongly diarized words and divided by the total number of words
        recognized by the ASR model.

        Args:
            total_riva_dict (dict):
                The dictionary that stores riva_dict(dict) indexed by uniq_id variable.
            DER_result_dict (dict):
                The dictionary that stores DER, FA, Miss, CER, mapping, the estimated
                number of speakers and speaker counting accuracy.
            ref_labels_list (list):
                List that contains the ground truth speaker labels for each segment.

        Return:
            wder_dict (dict):
                A dictionary contains WDER value for each session and total WDER.
        """
        wder_dict = {}
        grand_total_word_count, grand_correct_word_count = 0, 0
        for k, audio_file_path in enumerate(self.audio_file_list):

            uniq_id = get_uniqname_from_filepath(audio_file_path)
            ref_rttm = self.AUDIO_RTTM_MAP[uniq_id]['rttm_filepath']
            labels = rttm_to_labels(ref_rttm)
            mapping_dict = DER_result_dict[uniq_id]['mapping']
            words_list = total_riva_dict[uniq_id]['words']

            idx, correct_word_count = 0, 0
            total_word_count = len(words_list)
            ref_label_list = [[float(x.split()[0]), float(x.split()[1])] for x in labels]
            ref_label_array = np.array(ref_label_list)

            for wdict in words_list:
                speaker_label = wdict['speaker_label']
                if speaker_label in mapping_dict:
                    est_spk_label = mapping_dict[speaker_label]
                else:
                    continue
                start_point, end_point, ref_spk_label = labels[idx].split()
                word_range = np.array([wdict['start_time'], wdict['end_time']])
                word_range_tile = np.tile(word_range, (ref_label_array.shape[0], 1))
                ovl_bool = self.isOverlapArray(ref_label_array, word_range_tile)
                if not np.any(ovl_bool):
                    continue

                ovl_length = self.getOverlapRangeArray(ref_label_array, word_range_tile)

                if self.params['lenient_overlap_WDER']:
                    # Count the word as correct if the estimated speaker matches any of the
                    # reference speakers that share the maximum overlap with the word.
                    ovl_length_ovl = ovl_length[ovl_bool]
                    max_ovl_sub_idx = np.where(ovl_length_ovl == np.max(ovl_length_ovl))[0]
                    max_ovl_idx = np.where(ovl_bool)[0][max_ovl_sub_idx]
                    ref_spk_labels = [labels[i].split()[-1] for i in max_ovl_idx]
                    if est_spk_label in ref_spk_labels:
                        correct_word_count += 1
                else:
                    # Count the word as correct only if the estimated speaker matches the
                    # single reference segment with the largest overlap.
                    max_ovl_sub_idx = np.argmax(ovl_length[ovl_bool])
                    max_ovl_idx = np.where(ovl_bool)[0][max_ovl_sub_idx]
                    _, _, ref_spk_label = labels[max_ovl_idx].split()
                    correct_word_count += int(est_spk_label == ref_spk_label)

            wder = 1 - (correct_word_count / total_word_count)
            grand_total_word_count += total_word_count
            grand_correct_word_count += correct_word_count

            wder_dict[uniq_id] = wder

        wder_dict['total'] = 1 - (grand_correct_word_count / grand_total_word_count)
        return wder_dict
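
isOverlapArray() and getOverlapRangeArray() are not shown above. The sketch below illustrates the assumed semantics: two intervals overlap when each starts before the other ends, and the overlap length is min(ends) - max(starts), evaluated against all reference segments at once.

import numpy as np


def is_overlap_array(ref_ranges, word_ranges):
    """Boolean mask of reference segments that overlap the (tiled) word range."""
    return (ref_ranges[:, 0] < word_ranges[:, 1]) & (word_ranges[:, 0] < ref_ranges[:, 1])


def overlap_range_array(ref_ranges, word_ranges):
    """Overlap duration between each reference segment and the (tiled) word range."""
    return np.minimum(ref_ranges[:, 1], word_ranges[:, 1]) - np.maximum(ref_ranges[:, 0], word_ranges[:, 0])


ref = np.array([[0.0, 2.0], [2.0, 5.0]])
word = np.tile(np.array([1.8, 2.4]), (ref.shape[0], 1))
print(is_overlap_array(ref, word))     # both reference segments overlap the word
print(overlap_range_array(ref, word))  # overlap lengths in seconds: [0.2 0.4]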
Example #8
    def write_json_and_transcript(
        self, word_list, word_ts_list,
    ):
        """
        Matches the diarization result with the ASR output.
        The words and the timestamps for the corresponding words are matched
        in a for loop.

        Args:
            diar_labels (list):
                List of the Diarization output labels in str.
            word_list (list):
                List of words from ASR inference.
            word_ts_list (list):
                Contains word_ts_stt_end lists.
                word_ts_stt_end = [stt, end]
                    stt: Start of the word in sec.
                    end: End of the word in sec.

        Return:
            total_riva_dict (dict):
                A dictionary contains word timestamps, speaker labels and words.

        """
        total_riva_dict = {}
        if self.fix_word_ts_with_VAD:
            word_ts_list = self.compensate_word_ts_list(self.audio_file_list, word_ts_list, self.params)
            if self.frame_VAD == {}:
                logging.info(
                    "VAD timestamps are not provided; word timestamps are compensated without VAD information. Please check the VAD model."
                )

        for k, audio_file_path in enumerate(self.audio_file_list):
            uniq_id = get_uniqname_from_filepath(audio_file_path)
            pred_rttm = os.path.join(self.root_path, 'pred_rttms', uniq_id + '.rttm')
            labels = rttm_to_labels(pred_rttm)
            audacity_label_words = []
            n_spk = self.get_num_of_spk_from_labels(labels)
            string_out = ''
            riva_dict = od(
                {
                    'status': 'Success',
                    'session_id': uniq_id,
                    'transcription': ' '.join(word_list[k]),
                    'speaker_count': n_spk,
                    'words': [],
                }
            )

            start_point, end_point, speaker = labels[0].split()
            words = word_list[k]

            logging.info(f"Creating results for Session: {uniq_id} n_spk: {n_spk} ")
            string_out = self.print_time(string_out, speaker, start_point, end_point, self.params)
            word_pos, idx = 0, 0
            for j, word_ts_stt_end in enumerate(word_ts_list[k]):

                word_pos = (word_ts_stt_end[0] + word_ts_stt_end[1]) / 2
                if word_pos < float(end_point):
                    string_out = self.print_word(string_out, words[j], self.params)
                else:
                    idx += 1
                    idx = min(idx, len(labels) - 1)
                    start_point, end_point, speaker = labels[idx].split()
                    string_out = self.print_time(string_out, speaker, start_point, end_point, self.params)
                    string_out = self.print_word(string_out, words[j], self.params)

                stt_sec, end_sec = round(word_ts_stt_end[0], 2), round(word_ts_stt_end[1], 2)
                riva_dict = self.add_json_to_dict(riva_dict, words[j], stt_sec, end_sec, speaker)

                total_riva_dict[uniq_id] = riva_dict
                audacity_label_words = self.get_audacity_label(
                    words[j], stt_sec, end_sec, speaker, audacity_label_words
                )

            self.write_and_log(uniq_id, riva_dict, string_out, audacity_label_words)

        return total_riva_dict
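
The word placement loop above can be hard to follow inside the formatting calls. The toy below reproduces only its core rule, advancing to the next diarization segment once a word midpoint passes the current end_point; the bracketed speaker-turn formatting is hypothetical, since print_time()/print_word() control the real output.

def toy_transcript(labels, words, word_ts):
    """Group words into speaker turns by comparing word midpoints with segment ends."""
    idx = 0
    start_point, end_point, speaker = labels[idx].split()
    lines = ['[{} - {}] {}:'.format(start_point, end_point, speaker)]
    for word, (stt, end) in zip(words, word_ts):
        word_pos = (stt + end) / 2
        if word_pos >= float(end_point):
            idx = min(idx + 1, len(labels) - 1)
            start_point, end_point, speaker = labels[idx].split()
            lines.append('[{} - {}] {}:'.format(start_point, end_point, speaker))
        lines[-1] += ' ' + word
    return '\n'.join(lines)


labels = ['0.00 1.50 speaker_0', '1.50 3.00 speaker_1']
print(toy_transcript(labels, ['hi', 'there', 'hello'], [[0.1, 0.3], [0.4, 0.7], [1.6, 1.9]]))
# [0.00 - 1.50] speaker_0: hi there
# [1.50 - 3.00] speaker_1: hello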
Example #9
    def get_WDER(self, total_riva_dict, DER_result_dict):
        """
        Calculate word-level diarization error rate (WDER). WDER is calculated by
        counting the wrongly diarized words and divided by the total number of words
        recognized by the ASR model.

        Args:
            total_riva_dict (dict):
                Dictionary that stores riva_dict(dict) which is indexed by uniq_id variable.
            DER_result_dict (dict):
                Dictionary that stores DER, FA, Miss, CER, mapping, the estimated
                number of speakers and speaker counting accuracy.

        Returns:
            wder_dict (dict):
                A dictionary containing  WDER value for each session and total WDER.
        """
        wder_dict, count_dict = {'session_level': {}}, {}
        asr_eval_dict = {'hypotheses_list': [], 'references_list': []}
        align_error_list = []

        count_dict['total_ctm_wder_count'], count_dict['total_asr_and_spk_correct_words'] = 0, 0
        (
            count_dict['grand_total_ctm_word_count'],
            count_dict['grand_total_pred_word_count'],
            count_dict['grand_total_correct_word_count'],
        ) = (0, 0, 0)

        if any(self.AUDIO_RTTM_MAP[uniq_id]['ctm_filepath'] is not None for uniq_id in self.AUDIO_RTTM_MAP):
            if not DIFF_MATCH_PATCH:
                raise ImportError(
                    'CTM file is provided but diff_match_patch is not installed. Install diff_match_patch using PyPI: pip install diff_match_patch'
                )

        for k, audio_file_path in enumerate(self.audio_file_list):

            uniq_id = get_uniqname_from_filepath(audio_file_path)
            error_dict = {'uniq_id': uniq_id}
            ref_rttm = self.AUDIO_RTTM_MAP[uniq_id]['rttm_filepath']
            ref_labels = rttm_to_labels(ref_rttm)
            mapping_dict = DER_result_dict[uniq_id]['mapping']
            hyp_w_dict_list = total_riva_dict[uniq_id]['words']
            hyp_w_dict_list, word_seq_list, correct_word_count, rttm_wder = self.calculate_WDER_from_RTTM(
                hyp_w_dict_list, ref_labels, mapping_dict
            )
            error_dict['rttm_based_wder'] = rttm_wder
            error_dict.update(DER_result_dict[uniq_id])

            # If CTM files are provided, evaluate word-level diarization and WER with the CTM files.
            if self.AUDIO_RTTM_MAP[uniq_id]['ctm_filepath']:
                self.ctm_exists[uniq_id] = True
                with open(self.AUDIO_RTTM_MAP[uniq_id]['ctm_filepath']) as ctm_file:
                    ctm_content = ctm_file.readlines()
                self.get_ctm_based_eval(ctm_content, error_dict, count_dict, hyp_w_dict_list, mapping_dict)
            else:
                self.ctm_exists[uniq_id] = False

            wder_dict['session_level'][uniq_id] = error_dict
            asr_eval_dict['hypotheses_list'].append(' '.join(word_seq_list))
            asr_eval_dict['references_list'].append(self.AUDIO_RTTM_MAP[uniq_id]['text'])

            count_dict['grand_total_pred_word_count'] += len(hyp_w_dict_list)
            count_dict['grand_total_correct_word_count'] += correct_word_count

        wder_dict = self.get_wder_dict_values(asr_eval_dict, wder_dict, count_dict, align_error_list)
        return wder_dict
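
The CTM-based evaluation path depends on the optional diff_match_patch package. The snippet below is only a minimal illustration of aligning a reference transcript with an ASR hypothesis using that library; how get_ctm_based_eval() actually consumes the alignment is not reproduced here, and the example strings are made up.

import diff_match_patch

dmp = diff_match_patch.diff_match_patch()
ref_text = 'hello world how are you'       # e.g. words read from a CTM file
hyp_text = 'hello word how are you today'  # e.g. words from total_riva_dict
diffs = dmp.diff_main(ref_text, hyp_text)
dmp.diff_cleanupSemantic(diffs)  # merge character-level fragments into larger edits
for op, text in diffs:
    # op: 0 = equal, -1 = only in ref_text (deletion), 1 = only in hyp_text (insertion)
    print(op, repr(text))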
Example #10
    def run_ASR_BPE_CTC(
            self, asr_model: Type[EncDecCTCModelBPE]) -> Tuple[Dict, Dict]:
        """
        Launch a CTC-BPE based ASR model and collect logits, timestamps and text output.

        Args:
            asr_model (class):
                The loaded NeMo ASR model.

        Returns:
            words_dict (dict):
                Dictionary of the sequence of words from hypothesis.
            word_ts_dict (dict):
                Dictionary of the timestamps of words.
        """
        torch.manual_seed(0)
        torch.set_grad_enabled(False)
        words_dict, word_ts_dict = {}, {}

        werbpe_ts = WERBPE_TS(
            tokenizer=asr_model.tokenizer,
            batch_dim_index=0,
            use_cer=asr_model._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=asr_model._cfg.get("log_prediction", False),
        )

        frame_asr = FrameBatchASR_Logits(
            asr_model=asr_model,
            frame_len=self.chunk_len_in_sec,
            total_buffer=self.total_buffer_in_secs,
            batch_size=self.asr_batch_size,
        )

        onset_delay, mid_delay, tokens_per_chunk = self.set_buffered_infer_params(
            asr_model)
        onset_delay_in_sec = round(onset_delay * self.model_stride_in_secs, 2)

        with torch.cuda.amp.autocast():
            logging.info(f"Running ASR model {self.ASR_model_name}")

            for idx, audio_file_path in enumerate(self.audio_file_list):
                uniq_id = get_uniqname_from_filepath(audio_file_path)
                logging.info(
                    f"[{idx+1}/{len(self.audio_file_list)}] FrameBatchASR: {audio_file_path}"
                )
                frame_asr.clear_buffer()

                hyp, greedy_predictions_list, log_prob = get_wer_feat_logit(
                    audio_file_path,
                    frame_asr,
                    self.chunk_len_in_sec,
                    tokens_per_chunk,
                    mid_delay,
                    self.model_stride_in_secs,
                )
                if self.beam_search_decoder:
                    logging.info(
                        f"Running beam-search decoder with LM {self.ctc_decoder_params['pretrained_language_model']}"
                    )
                    log_prob = log_prob.unsqueeze(0).cpu().numpy()[0]
                    hyp_words, word_ts = self.run_pyctcdecode(
                        log_prob, onset_delay_in_sec=onset_delay_in_sec)
                else:
                    logits_len = torch.from_numpy(
                        np.array([len(greedy_predictions_list)]))
                    greedy_predictions_list = greedy_predictions_list[
                        onset_delay:]
                    greedy_predictions = torch.from_numpy(
                        np.array(greedy_predictions_list)).unsqueeze(0)
                    text, char_ts, word_ts = werbpe_ts.ctc_decoder_predictions_tensor_with_ts(
                        self.model_stride_in_secs,
                        greedy_predictions,
                        predictions_len=logits_len)
                    hyp_words, word_ts = text[0].split(), word_ts[0]

                word_ts = self.align_decoder_delay(word_ts,
                                                   self.decoder_delay_in_sec)
                assert len(hyp_words) == len(
                    word_ts
                ), "Words and word timestamp list length does not match."
                words_dict[uniq_id] = hyp_words
                word_ts_dict[uniq_id] = word_ts

        return words_dict, word_ts_dict
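
set_buffered_infer_params() is not shown above. The sketch below gives one plausible derivation of the three values it returns, based on the chunk length, the total buffer length, and the model stride; the formulas and the example values are assumptions, not a copy of the NeMo helper.

import math


def buffered_infer_params(chunk_len_in_sec, total_buffer_in_sec, model_stride_in_sec):
    # Decoder tokens contributed by the chunk itself.
    tokens_per_chunk = math.ceil(chunk_len_in_sec / model_stride_in_sec)
    # Context padding on each side of the chunk inside the buffer.
    context_len_in_sec = (total_buffer_in_sec - chunk_len_in_sec) / 2
    onset_delay = math.ceil(context_len_in_sec / model_stride_in_sec)
    mid_delay = math.ceil((chunk_len_in_sec + context_len_in_sec) / model_stride_in_sec)
    return onset_delay, mid_delay, tokens_per_chunk


# e.g. 1.6 s chunks inside a 4.0 s buffer with a 0.08 s model stride:
print(buffered_infer_params(1.6, 4.0, 0.08))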
Example #11
    def _run_vad(self, manifest_file):
        """
        Run voice activity detection. 
        Get log probability of voice activity detection and smoothes using the post processing parameters. 
        Using generated frame level predictions generated manifest file for later speaker embedding extraction.
        input:
        manifest_file (str) : Manifest file containing path to audio file and label as infer

        """

        shutil.rmtree(self._vad_dir, ignore_errors=True)
        os.makedirs(self._vad_dir)

        self._vad_model = self._vad_model.to(self._device)
        self._vad_model.eval()

        time_unit = int(self._vad_window_length_in_sec /
                        self._vad_shift_length_in_sec)
        trunc = int(time_unit / 2)
        trunc_l = time_unit - trunc
        all_len = 0
        data = []
        for line in open(manifest_file, 'r'):
            file = json.loads(line)['audio_filepath']
            data.append(get_uniqname_from_filepath(file))

        status = get_vad_stream_status(data)
        for i, test_batch in enumerate(tqdm(
                self._vad_model.test_dataloader())):
            test_batch = [x.to(self._device) for x in test_batch]
            with autocast():
                log_probs = self._vad_model(input_signal=test_batch[0],
                                            input_signal_length=test_batch[1])
                probs = torch.softmax(log_probs, dim=-1)
                pred = probs[:, 1]
                if status[i] == 'start':
                    to_save = pred[:-trunc]
                elif status[i] == 'next':
                    to_save = pred[trunc:-trunc_l]
                elif status[i] == 'end':
                    to_save = pred[trunc_l:]
                else:
                    to_save = pred
                all_len += len(to_save)
                outpath = os.path.join(self._vad_dir, data[i] + ".frame")
                with open(outpath, "a") as fout:
                    for f in range(len(to_save)):
                        fout.write('{0:0.4f}\n'.format(to_save[f]))
            del test_batch
            if status[i] == 'end' or status[i] == 'single':
                all_len = 0

        if not self._vad_params.smoothing:
            # Shift the window by 10 ms to generate frames and use the prediction of each window
            # to represent the label of the corresponding frame.
            self.vad_pred_dir = self._vad_dir
        else:
            # Generate predictions with overlapping input segments, then apply a smoothing filter
            # to decide the label for a frame spanned by multiple segments.
            # smoothing_method is either majority vote (median) or average (mean).
            logging.info(
                "Generating predictions with overlapping input segments")
            smoothing_pred_dir = generate_overlap_vad_seq(
                frame_pred_dir=self._vad_dir,
                smoothing_method=self._vad_params.smoothing,
                overlap=self._vad_params.overlap,
                seg_len=self._vad_window_length_in_sec,
                shift_len=self._vad_shift_length_in_sec,
                num_workers=self._cfg.num_workers,
            )
            self.vad_pred_dir = smoothing_pred_dir

        logging.info(
            "Converting frame level prediction to speech/no-speech segment in start and end times format."
        )

        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=self.vad_pred_dir,
            postprocessing_params=self._vad_params,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )
        AUDIO_VAD_RTTM_MAP = deepcopy(self.AUDIO_RTTM_MAP)
        for key in AUDIO_VAD_RTTM_MAP:
            AUDIO_VAD_RTTM_MAP[key]['rttm_filepath'] = os.path.join(
                table_out_dir, key + ".txt")

        write_rttm2manifest(AUDIO_VAD_RTTM_MAP, self._vad_out_file)
        self._speaker_manifest_path = self._vad_out_file
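
The start/next/end cropping above keeps each overlapping window from contributing duplicate frames. The toy below replays only that cropping rule on made-up per-window predictions; the probability values and time_unit are illustrative and much smaller than in practice.

def crop_window_pred(pred, status, time_unit):
    """Mirror the truncation rule used in _run_vad() for one window of predictions."""
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc
    if status == 'start':
        return pred[:-trunc]
    elif status == 'next':
        return pred[trunc:-trunc_l]
    elif status == 'end':
        return pred[trunc_l:]
    return pred  # 'single': the whole file fits in one window


time_unit = 2
windows = [[0.9, 0.8, 0.7, 0.6, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5]]
statuses = ['start', 'next', 'end']
stitched = []
for pred, status in zip(windows, statuses):
    stitched.extend(crop_window_pred(pred, status, time_unit))
print(stitched)  # [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.2, 0.3, 0.4, 0.5]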