Example #1
# Imports and device setup assumed by this example; the vad_utils module path
# is an assumption and may differ across NeMo versions.
from argparse import ArgumentParser
import json
import os

import torch
from torch.cuda.amp import autocast

from nemo.collections.asr.models import EncDecClassificationModel
from nemo.collections.asr.parts.utils.vad_utils import get_vad_stream_status, prepare_manifest
from nemo.utils import logging

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main():
    parser = ArgumentParser()
    parser.add_argument("--vad_model",
                        type=str,
                        default="MatchboxNet-VAD-3x2",
                        required=False,
                        help="Pass: '******'")
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Path to the JSON manifest of evaluation data. Audio files should have unique names.",
    )
    parser.add_argument("--out_dir",
                        type=str,
                        default="vad_frame",
                        help="Dir of your vad outputs")
    parser.add_argument("--time_length", type=float, default=0.63)
    parser.add_argument("--shift_length", type=float, default=0.01)
    parser.add_argument("--normalize_audio", type=bool, default=False)
    parser.add_argument("--num_workers", type=float, default=20)
    parser.add_argument("--split_duration", type=float, default=400)
    parser.add_argument(
        "--dont_auto_split",
        default=False,
        action='store_true',
        help="Disable automatic splitting of manifest entries by split_duration (splitting avoids potential CUDA out-of-memory issues).",
    )

    args = parser.parse_args()

    torch.set_grad_enabled(False)

    if args.vad_model.endswith('.nemo'):
        logging.info(f"Using local VAD model from {args.vad_model}")
        vad_model = EncDecClassificationModel.restore_from(
            restore_path=args.vad_model)
    else:
        logging.info(f"Using NGC cloud VAD model {args.vad_model}")
        vad_model = EncDecClassificationModel.from_pretrained(
            model_name=args.vad_model)

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)

    # Prepare manifest for streaming VAD
    manifest_vad_input = args.dataset
    if not args.dont_auto_split:
        logging.info("Split long audio file to avoid CUDA memory issue")
        logging.debug(
            "Try smaller split_duration if you still have CUDA memory issue")
        config = {
            'manifest_filepath': manifest_vad_input,
            'time_length': args.time_length,
            'split_duration': args.split_duration,
            'num_workers': args.num_workers,
        }
        manifest_vad_input = prepare_manifest(config)
    else:
        logging.warning(
            "If you encounter CUDA memory issues, try splitting manifest entries by split_duration."
        )

    # Set up the streaming test data loader (vad_stream=True) with the
    # sliding-window time_length and shift_length parameters.
    vad_model.setup_test_data(
        test_data_config={
            'vad_stream': True,
            'sample_rate': 16000,
            'manifest_filepath': manifest_vad_input,
            'labels': [
                'infer',
            ],
            'num_workers': args.num_workers,
            'shuffle': False,
            'time_length': args.time_length,
            'shift_length': args.shift_length,
            'trim_silence': False,
            'normalize_audio': args.normalize_audio,
        })

    vad_model = vad_model.to(device)
    vad_model.eval()

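    # time_unit is the number of frames covered by one window; trunc/trunc_l
    # frames are trimmed at the boundaries of split segments so that the
    # per-file predictions can be concatenated without duplicates.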
    time_unit = int(args.time_length / args.shift_length)
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc
    all_len = 0

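    # Collect the unique base name of each audio file; it is used below to
    # name the per-file .frame output.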
    data = []
    for line in open(args.dataset, 'r'):
        file = json.loads(line)['audio_filepath'].split("/")[-1]
        data.append(file.split(".wav")[0])
    logging.info(f"Inference on {len(data)} audio files/json lines!")

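    # get_vad_stream_status marks each manifest entry as 'start', 'next',
    # 'end', or 'single' so that predictions from the split segments of one
    # audio file are appended to a single .frame file.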
    status = get_vad_stream_status(data)
    for i, test_batch in enumerate(vad_model.test_dataloader()):
        test_batch = [x.to(device) for x in test_batch]
        with autocast():
            log_probs = vad_model(input_signal=test_batch[0],
                                  input_signal_length=test_batch[1])
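            # Convert the model output to probabilities; column 1 is the
            # per-frame speech probability.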
            probs = torch.softmax(log_probs, dim=-1)
            pred = probs[:, 1]

            if status[i] == 'start':
                to_save = pred[:-trunc]
            elif status[i] == 'next':
                to_save = pred[trunc:-trunc_l]
            elif status[i] == 'end':
                to_save = pred[trunc_l:]
            else:
                to_save = pred

            all_len += len(to_save)
            outpath = os.path.join(args.out_dir, data[i] + ".frame")
            with open(outpath, "a") as fout:
                for f in range(len(to_save)):
                    fout.write('{0:0.4f}\n'.format(to_save[f]))
        del test_batch
        if status[i] == 'end' or status[i] == 'single':
            logging.debug(
                f"Overall length of prediction of {data[i]} is {all_len}!")
            all_len = 0
Example #2
    def _run_vad(self, manifest_file):
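        # Snippet from a diarization pipeline class: assumes self._vad_model,
        # self._device, self._vad_dir, the window/shift lengths, and self._cfg
        # are set elsewhere, and that torch, json, os, tqdm, autocast, and the
        # NeMo VAD utilities are imported at module level.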
        self._vad_model = self._vad_model.to(self._device)
        self._vad_model.eval()

        time_unit = int(self._vad_window_length_in_sec /
                        self._vad_shift_length_in_sec)
        trunc = int(time_unit / 2)
        trunc_l = time_unit - trunc
        all_len = 0
        data = []
        for line in open(manifest_file, 'r'):
            file = os.path.basename(json.loads(line)['audio_filepath'])
            data.append(os.path.splitext(file)[0])

        status = get_vad_stream_status(data)
        for i, test_batch in enumerate(tqdm(
                self._vad_model.test_dataloader())):
            test_batch = [x.to(self._device) for x in test_batch]
            with autocast():
                log_probs = self._vad_model(input_signal=test_batch[0],
                                            input_signal_length=test_batch[1])
                probs = torch.softmax(log_probs, dim=-1)
                pred = probs[:, 1]
                if status[i] == 'start':
                    to_save = pred[:-trunc]
                elif status[i] == 'next':
                    to_save = pred[trunc:-trunc_l]
                elif status[i] == 'end':
                    to_save = pred[trunc_l:]
                else:
                    to_save = pred
                all_len += len(to_save)
                outpath = os.path.join(self._vad_dir, data[i] + ".frame")
                with open(outpath, "a") as fout:
                    for f in range(len(to_save)):
                        fout.write('{0:0.4f}\n'.format(to_save[f]))
            del test_batch
            if status[i] == 'end' or status[i] == 'single':
                all_len = 0

        if not self._cfg.diarizer.vad.vad_decision_smoothing:
            # The window is shifted by shift_length (10 ms) per frame, and each
            # frame takes the prediction of its window as its label.
            self.vad_pred_dir = self._vad_dir

        else:
            # Generate predictions with overlapping input segments, then apply
            # a smoothing filter to decide the label of a frame spanned by
            # multiple segments. The smoothing method is either majority vote
            # (median) or average (mean).
            logging.info(
                "Generating predictions with overlapping input segments")
            smoothing_pred_dir = generate_overlap_vad_seq(
                frame_pred_dir=self._vad_dir,
                smoothing_method=self._cfg.diarizer.vad.smoothing_params.method,
                overlap=self._cfg.diarizer.vad.smoothing_params.overlap,
                seg_len=self._vad_window_length_in_sec,
                shift_len=self._vad_shift_length_in_sec,
                num_workers=self._cfg.num_workers,
            )
            self.vad_pred_dir = smoothing_pred_dir

        logging.info(
            "Converting frame-level predictions to speech/non-speech segments with start and end times."
        )
        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=self.vad_pred_dir,
            threshold=self._cfg.diarizer.vad.threshold,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )

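        # Collect the per-file segment tables and convert them into a speaker
        # manifest for the downstream embedding extraction.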
        vad_table_list = [
            os.path.join(table_out_dir, key + ".txt")
            for key in self.AUDIO_RTTM_MAP
        ]
        write_rttm2manifest(self._cfg.diarizer.paths2audio_files,
                            vad_table_list, self._vad_out_file)
        self._speaker_manifest_path = self._vad_out_file
Example #3
    def _run_vad(self, manifest_file):
        """
        Run voice activity detection.
        Get the frame-level speech probabilities from the VAD model and smooth them using the post-processing parameters.
        Use the generated frame-level predictions to build a manifest file for later speaker embedding extraction.

        input:
            manifest_file (str): Manifest file containing the path to each audio file and the label 'infer'.

        """

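        # Assumes the class has set up self._vad_model, self._device,
        # self._vad_dir, self._vad_params, and the window/shift lengths, and
        # that the required modules (torch, json, os, shutil, tqdm, autocast,
        # NeMo VAD utilities) are imported at module level.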
        shutil.rmtree(self._vad_dir, ignore_errors=True)
        os.makedirs(self._vad_dir)

        self._vad_model = self._vad_model.to(self._device)
        self._vad_model.eval()

        time_unit = int(self._vad_window_length_in_sec /
                        self._vad_shift_length_in_sec)
        trunc = int(time_unit / 2)
        trunc_l = time_unit - trunc
        all_len = 0
        data = []
        for line in open(manifest_file, 'r'):
            file = json.loads(line)['audio_filepath']
            data.append(get_uniqname_from_filepath(file))

        status = get_vad_stream_status(data)
        for i, test_batch in enumerate(tqdm(
                self._vad_model.test_dataloader())):
            test_batch = [x.to(self._device) for x in test_batch]
            with autocast():
                log_probs = self._vad_model(input_signal=test_batch[0],
                                            input_signal_length=test_batch[1])
                probs = torch.softmax(log_probs, dim=-1)
                pred = probs[:, 1]
                if status[i] == 'start':
                    to_save = pred[:-trunc]
                elif status[i] == 'next':
                    to_save = pred[trunc:-trunc_l]
                elif status[i] == 'end':
                    to_save = pred[trunc_l:]
                else:
                    to_save = pred
                all_len += len(to_save)
                outpath = os.path.join(self._vad_dir, data[i] + ".frame")
                with open(outpath, "a") as fout:
                    for f in range(len(to_save)):
                        fout.write('{0:0.4f}\n'.format(to_save[f]))
            del test_batch
            if status[i] == 'end' or status[i] == 'single':
                all_len = 0

        if not self._vad_params.smoothing:
            # The window is shifted by shift_length (10 ms) per frame, and each
            # frame takes the prediction of its window as its label.
            self.vad_pred_dir = self._vad_dir
        else:
            # Generate predictions with overlapping input segments, then apply
            # a smoothing filter to decide the label of a frame spanned by
            # multiple segments. The smoothing method is either majority vote
            # (median) or average (mean).
            logging.info(
                "Generating predictions with overlapping input segments")
            smoothing_pred_dir = generate_overlap_vad_seq(
                frame_pred_dir=self._vad_dir,
                smoothing_method=self._vad_params.smoothing,
                overlap=self._vad_params.overlap,
                seg_len=self._vad_window_length_in_sec,
                shift_len=self._vad_shift_length_in_sec,
                num_workers=self._cfg.num_workers,
            )
            self.vad_pred_dir = smoothing_pred_dir

        logging.info(
            "Converting frame-level predictions to speech/non-speech segments with start and end times."
        )

        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=self.vad_pred_dir,
            postprocessing_params=self._vad_params,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )
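        # Point each entry of the audio-to-RTTM map at its generated VAD
        # segment table before writing the speaker manifest.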
        AUDIO_VAD_RTTM_MAP = deepcopy(self.AUDIO_RTTM_MAP)
        for key in AUDIO_VAD_RTTM_MAP:
            AUDIO_VAD_RTTM_MAP[key]['rttm_filepath'] = os.path.join(
                table_out_dir, key + ".txt")

        write_rttm2manifest(AUDIO_VAD_RTTM_MAP, self._vad_out_file)
        self._speaker_manifest_path = self._vad_out_file