def _run_vad(self, manifest_file):
        self._vad_model = self._vad_model.to(self._device)
        self._vad_model.eval()

        time_unit = int(self._vad_window_length_in_sec /
                        self._vad_shift_length_in_sec)
        trunc = int(time_unit / 2)
        trunc_l = time_unit - trunc
        all_len = 0
        data = []
        with open(manifest_file, 'r') as fin:
            for line in fin:
                file = os.path.basename(json.loads(line)['audio_filepath'])
                data.append(os.path.splitext(file)[0])

        status = get_vad_stream_status(data)
        for i, test_batch in enumerate(tqdm(
                self._vad_model.test_dataloader())):
            test_batch = [x.to(self._device) for x in test_batch]
            with autocast():
                log_probs = self._vad_model(input_signal=test_batch[0],
                                            input_signal_length=test_batch[1])
                probs = torch.softmax(log_probs, dim=-1)
                pred = probs[:, 1]
                if status[i] == 'start':
                    to_save = pred[:-trunc]
                elif status[i] == 'next':
                    to_save = pred[trunc:-trunc_l]
                elif status[i] == 'end':
                    to_save = pred[trunc_l:]
                else:
                    to_save = pred
                all_len += len(to_save)
                outpath = os.path.join(self._vad_dir, data[i] + ".frame")
                with open(outpath, "a") as fout:
                    for f in range(len(to_save)):
                        fout.write('{0:0.4f}\n'.format(to_save[f]))
            del test_batch
            if status[i] == 'end' or status[i] == 'single':
                all_len = 0

        if not self._cfg.diarizer.vad.vad_decision_smoothing:
            # Shift the window by 10 ms to generate each frame and use the window's prediction as that frame's label.
            self.vad_pred_dir = self._vad_dir

        else:
            # Generate predictions with overlapping input segments, then apply a smoothing filter to decide the label of each frame spanned by multiple segments.
            # smoothing_method is either majority vote (median) or average (mean).
            logging.info(
                "Generating predictions with overlapping input segments")
            smoothing_pred_dir = generate_overlap_vad_seq(
                frame_pred_dir=self._vad_dir,
                smoothing_method=self._cfg.diarizer.vad.smoothing_params.method,
                overlap=self._cfg.diarizer.vad.smoothing_params.overlap,
                seg_len=self._vad_window_length_in_sec,
                shift_len=self._vad_shift_length_in_sec,
                num_workers=self._cfg.num_workers,
            )
            self.vad_pred_dir = smoothing_pred_dir

        logging.info(
            "Converting frame-level predictions to speech/no-speech segments in start/end time format."
        )
        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=self.vad_pred_dir,
            threshold=self._cfg.diarizer.vad.threshold,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )

        vad_table_list = [
            os.path.join(table_out_dir, key + ".txt")
            for key in self.AUDIO_RTTM_MAP
        ]
        write_rttm2manifest(self._cfg.diarizer.paths2audio_files,
                            vad_table_list, self._vad_out_file)
        self._speaker_manifest_path = self._vad_out_file
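
The start/next/end branches above let predictions from overlapping windowed chunks be stitched into one frame-level sequence per file; a later revision of the same method follows the sketch below. Here is a minimal, self-contained sketch of the stitching arithmetic with hypothetical window and shift values (round() is added to guard the float division that the listings perform with a plain int()):

# Minimal sketch of the chunk-stitching arithmetic above (hypothetical values).
window_length_in_sec = 0.63   # assumed example: 0.63 s analysis window
shift_length_in_sec = 0.01    # assumed example: 10 ms shift per frame

# round() guards against float division landing just below an integer.
time_unit = int(round(window_length_in_sec / shift_length_in_sec))
trunc = int(time_unit / 2)     # frames trimmed from one overlapped edge
trunc_l = time_unit - trunc    # frames trimmed from the other edge

def stitch(pred, status):
    """Trim one chunk's frame predictions according to its stream position."""
    if status == 'start':            # first chunk: drop the overlapped tail
        return pred[:-trunc]
    elif status == 'next':           # middle chunk: drop both overlapped edges
        return pred[trunc:-trunc_l]
    elif status == 'end':            # last chunk: drop the overlapped head
        return pred[trunc_l:]
    return pred                      # 'single': whole file fits in one chunk
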
    def _run_vad(self, manifest_file):
        """
        Run voice activity detection. 
        Get log probability of voice activity detection and smoothes using the post processing parameters. 
        Using generated frame level predictions generated manifest file for later speaker embedding extraction.
        input:
        manifest_file (str) : Manifest file containing path to audio file and label as infer

        """

        shutil.rmtree(self._vad_dir, ignore_errors=True)
        os.makedirs(self._vad_dir)

        self._vad_model = self._vad_model.to(self._device)
        self._vad_model.eval()

        time_unit = int(self._vad_window_length_in_sec /
                        self._vad_shift_length_in_sec)
        trunc = int(time_unit / 2)
        trunc_l = time_unit - trunc
        all_len = 0
        data = []
        with open(manifest_file, 'r') as fin:
            for line in fin:
                file = json.loads(line)['audio_filepath']
                data.append(get_uniqname_from_filepath(file))

        status = get_vad_stream_status(data)
        for i, test_batch in enumerate(tqdm(
                self._vad_model.test_dataloader())):
            test_batch = [x.to(self._device) for x in test_batch]
            with autocast():
                log_probs = self._vad_model(input_signal=test_batch[0],
                                            input_signal_length=test_batch[1])
                probs = torch.softmax(log_probs, dim=-1)
                pred = probs[:, 1]
                if status[i] == 'start':
                    to_save = pred[:-trunc]
                elif status[i] == 'next':
                    to_save = pred[trunc:-trunc_l]
                elif status[i] == 'end':
                    to_save = pred[trunc_l:]
                else:
                    to_save = pred
                all_len += len(to_save)
                outpath = os.path.join(self._vad_dir, data[i] + ".frame")
                with open(outpath, "a") as fout:
                    for f in range(len(to_save)):
                        fout.write('{0:0.4f}\n'.format(to_save[f]))
            del test_batch
            if status[i] == 'end' or status[i] == 'single':
                all_len = 0

        if not self._vad_params.smoothing:
            # Shift the window by 10 ms to generate each frame and use the window's prediction as that frame's label.
            self.vad_pred_dir = self._vad_dir
        else:
            # Generate predictions with overlapping input segments, then apply a smoothing filter to decide the label of each frame spanned by multiple segments.
            # smoothing_method is either majority vote (median) or average (mean).
            logging.info(
                "Generating predictions with overlapping input segments")
            smoothing_pred_dir = generate_overlap_vad_seq(
                frame_pred_dir=self._vad_dir,
                smoothing_method=self._vad_params.smoothing,
                overlap=self._vad_params.overlap,
                seg_len=self._vad_window_length_in_sec,
                shift_len=self._vad_shift_length_in_sec,
                num_workers=self._cfg.num_workers,
            )
            self.vad_pred_dir = smoothing_pred_dir

        logging.info(
            "Converting frame-level predictions to speech/no-speech segments in start/end time format."
        )

        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=self.vad_pred_dir,
            postprocessing_params=self._vad_params,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )
        AUDIO_VAD_RTTM_MAP = deepcopy(self.AUDIO_RTTM_MAP)
        for key in AUDIO_VAD_RTTM_MAP:
            AUDIO_VAD_RTTM_MAP[key]['rttm_filepath'] = os.path.join(
                table_out_dir, key + ".txt")

        write_rttm2manifest(AUDIO_VAD_RTTM_MAP, self._vad_out_file)
        self._speaker_manifest_path = self._vad_out_file
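
For intuition, the smoothing step in both listings can be pictured as pooling, for each frame, the predictions of every overlapping segment that covers it, then reducing with a median (majority vote) or mean. The following is a toy sketch of that idea only, not NeMo's generate_overlap_vad_seq implementation:

import numpy as np

def smooth_overlap(segment_preds, seg_len, shift, method='median'):
    """Toy illustration of overlap smoothing (not NeMo's implementation).

    segment_preds: one array of per-frame probabilities per input segment;
    segment k starts at frame k * shift and spans seg_len frames.
    """
    n_frames = shift * (len(segment_preds) - 1) + seg_len
    votes = [[] for _ in range(n_frames)]
    for k, seg in enumerate(segment_preds):
        for f, p in enumerate(seg):
            votes[k * shift + f].append(p)  # every segment covering this frame votes
    reduce = np.median if method == 'median' else np.mean
    return np.array([reduce(v) for v in votes])

# Example: three overlapping 4-frame segments, shifted by 2 frames.
preds = [np.array([0.9, 0.8, 0.7, 0.6]),
         np.array([0.5, 0.4, 0.9, 0.8]),
         np.array([0.1, 0.2, 0.3, 0.4])]
print(smooth_overlap(preds, seg_len=4, shift=2))
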
Example #3
def main(cfg):
    if not cfg.dataset:
        raise ValueError("You must input the path of json file of evaluation data")

    # Each line of the dataset should have a different audio_filepath and a unique name, to simplify edge cases and conditions
    key_meta_map = {}
    with open(cfg.dataset, 'r') as manifest:
        for line in manifest.readlines():
            audio_filepath = json.loads(line.strip())['audio_filepath']
            uniq_audio_name = audio_filepath.split('/')[-1].rsplit('.', 1)[0]
            if uniq_audio_name in key_meta_map:
                raise ValueError("Please make sure each line is with different audio_filepath! ")
            key_meta_map[uniq_audio_name] = {'audio_filepath': audio_filepath}

    # Prepare manifest for streaming VAD
    manifest_vad_input = cfg.dataset
    if cfg.prepare_manifest.auto_split:
        logging.info("Split long audio file to avoid CUDA memory issue")
        logging.debug("Try smaller split_duration if you still have CUDA memory issue")
        config = {
            'input': manifest_vad_input,
            'window_length_in_sec': cfg.vad.parameters.window_length_in_sec,
            'split_duration': cfg.prepare_manifest.split_duration,
            'num_workers': cfg.num_workers,
            'prepared_manfiest_vad_input': cfg.prepared_manfiest_vad_input,
        }
        manifest_vad_input = prepare_manifest(config)
    else:
        logging.warning(
            "If you encounter CUDA memory issues, try splitting long manifest entries by split_duration."
        )

    torch.set_grad_enabled(False)
    vad_model = init_vad_model(cfg.vad.model_path)

    # setup_test_data
    vad_model.setup_test_data(
        test_data_config={
            'vad_stream': True,
            'sample_rate': 16000,
            'manifest_filepath': manifest_vad_input,
            'labels': ['infer',],
            'num_workers': cfg.num_workers,
            'shuffle': False,
            'window_length_in_sec': cfg.vad.parameters.window_length_in_sec,
            'shift_length_in_sec': cfg.vad.parameters.shift_length_in_sec,
            'trim_silence': False,
            'normalize_audio': cfg.vad.parameters.normalize_audio,
        }
    )

    vad_model = vad_model.to(device)
    vad_model.eval()

    if not os.path.exists(cfg.frame_out_dir):
        os.mkdir(cfg.frame_out_dir)
    else:
        logging.warning(
            "frame_out_dir already exists. If a new file has the same name as a file inside the existing folder, results will be appended to the existing file, which may cause errors in later steps."
        )

    logging.info("Generating frame level prediction ")
    pred_dir = generate_vad_frame_pred(
        vad_model=vad_model,
        window_length_in_sec=cfg.vad.parameters.window_length_in_sec,
        shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
        manifest_vad_input=manifest_vad_input,
        out_dir=cfg.frame_out_dir,
    )
    logging.info(
        f"Finished generating VAD frame-level predictions with window_length_in_sec={cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={cfg.vad.parameters.shift_length_in_sec}"
    )

    # overlap smoothing filter
    if cfg.gen_overlap_seq:
        # Generate predictions with overlapping input segments, then apply a smoothing filter to decide the label of each frame spanned by multiple segments.
        # smoothing_method is either majority vote (median) or average (mean).
        logging.info("Generating predictions with overlapping input segments")
        smoothing_pred_dir = generate_overlap_vad_seq(
            frame_pred_dir=pred_dir,
            smoothing_method=cfg.vad.parameters.smoothing,
            overlap=cfg.vad.parameters.overlap,
            window_length_in_sec=cfg.vad.parameters.window_length_in_sec,
            shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
            num_workers=cfg.num_workers,
            out_dir=cfg.smoothing_out_dir,
        )
        logging.info(
            f"Finished generating predictions with overlapping input segments with smoothing_method={cfg.vad.parameters.smoothing} and overlap={cfg.vad.parameters.overlap}"
        )
        pred_dir = smoothing_pred_dir

    # postprocessing and generate speech segments
    if cfg.gen_seg_table:
        logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=pred_dir,
            postprocessing_params=cfg.vad.parameters.postprocessing,
            shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
            num_workers=cfg.num_workers,
            out_dir=cfg.table_out_dir,
        )
        logging.info(
            f"Finished generating speech segments table with postprocessing_params: {cfg.vad.parameters.postprocessing}"
        )

    if cfg.write_to_manifest:
        for i in key_meta_map:
            key_meta_map[i]['rttm_filepath'] = os.path.join(table_out_dir, i + ".txt")

        if not cfg.out_manifest_filepath:
            out_manifest_filepath = "vad_out.json"
        else:
            out_manifest_filepath = cfg.out_manifest_filepath
        out_manifest_filepath = write_rttm2manifest(key_meta_map, out_manifest_filepath)
        logging.info(f"Writing VAD output to manifest: {out_manifest_filepath}")
Example #4
            f"Finish generating predictions with overlapping input segments with smoothing_method={args.method} and overlap={args.overlap}"
        )
        end = time.time()
        logging.info(f"Generate overlapped prediction takes {end-start:.2f} seconds!\n Save to {overlap_out_dir}")

    if args.gen_seg_table:
        start = time.time()
        logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")

        if args.gen_overlap_seq:
            logging.info("Use overlap prediction. Change if you want to use basic frame level prediction")
            vad_pred_dir = overlap_out_dir
            shift_length_in_sec = 0.01
        else:
            logging.info("Use basic frame level prediction")
            vad_pred_dir = args.frame_folder
            shift_length_in_sec = args.shift_length_in_sec

        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=vad_pred_dir,
            postprocessing_params=postprocessing_params,
            shift_length_in_sec=shift_length_in_sec,  # use the value chosen above (0.01 for overlap predictions)
            num_workers=args.num_workers,
            out_dir=args.table_out_dir,
        )
        logging.info(f"Finish generating speech semgents table with postprocessing_params: {postprocessing_params}")
        end = time.time()
        logging.info(
            f"Generating rttm-like tables for {vad_pred_dir} took {end-start:.2f} seconds!\n Saved to {table_out_dir}"
        )
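
Downstream steps read the per-file tables written to table_out_dir. Below is a small hedged reader, assuming each line holds whitespace-separated start and end times in seconds with an optional trailing label; the exact column layout is an assumption and may differ between NeMo versions:

import os

def read_segment_table(path):
    """Parse one VAD segment table into (start_sec, end_sec) tuples.

    Assumes whitespace-separated columns beginning with start and end times;
    this layout is an assumption and may vary across NeMo versions.
    """
    segments = []
    with open(path) as fin:
        for line in fin:
            parts = line.split()
            if len(parts) >= 2:
                segments.append((float(parts[0]), float(parts[1])))
    return segments

# Example usage against hypothetical output:
# for start, end in read_segment_table(os.path.join(table_out_dir, "meeting_01.txt")):
#     print(f"speech from {start:.2f}s to {end:.2f}s")
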