Example #1
# Imports assumed by this script; the NeMo module paths below may differ slightly
# across NeMo releases.
import json
import os
from argparse import ArgumentParser

import torch

from nemo.collections.asr.models import EncDecClassificationModel
from nemo.collections.asr.parts.utils.vad_utils import get_vad_stream_status, prepare_manifest
from nemo.utils import logging

try:
    from torch.cuda.amp import autocast
except ImportError:
    from contextlib import contextmanager

    # Fallback no-op autocast when AMP is unavailable.
    @contextmanager
    def autocast(enabled=None):
        yield

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def main():
    parser = ArgumentParser()
    parser.add_argument("--vad_model",
                        type=str,
                        default="MatchboxNet-VAD-3x2",
                        required=False,
                        help="Pass: '******'")
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Path to a JSON manifest of evaluation data. Audio files should have unique names.",
    )
    parser.add_argument("--out_dir",
                        type=str,
                        default="vad_frame",
                        help="Dir of your vad outputs")
    parser.add_argument("--time_length", type=float, default=0.63)
    parser.add_argument("--shift_length", type=float, default=0.01)
    parser.add_argument("--normalize_audio", type=bool, default=False)
    parser.add_argument("--num_workers", type=float, default=20)
    parser.add_argument("--split_duration", type=float, default=400)
    parser.add_argument(
        "--dont_auto_split",
        default=False,
        action='store_true',
        help="Disable automatic splitting of manifest entries by split_duration; splitting helps avoid potential CUDA out-of-memory issues.",
    )

    args = parser.parse_args()

    torch.set_grad_enabled(False)

    if args.vad_model.endswith('.nemo'):
        logging.info(f"Using local VAD model from {args.vad_model}")
        vad_model = EncDecClassificationModel.restore_from(
            restore_path=args.vad_model)
    else:
        logging.info(f"Using NGC cloud VAD model {args.vad_model}")
        vad_model = EncDecClassificationModel.from_pretrained(
            model_name=args.vad_model)

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)

    # Prepare manifest for streaming VAD
    manifest_vad_input = args.dataset
    if not args.dont_auto_split:
        logging.info("Split long audio file to avoid CUDA memory issue")
        logging.debug(
            "Try smaller split_duration if you still have CUDA memory issue")
        config = {
            'manifest_filepath': manifest_vad_input,
            'time_length': args.time_length,
            'split_duration': args.split_duration,
            'num_workers': args.num_workers,
        }
        manifest_vad_input = prepare_manifest(config)
    else:
        logging.warning(
            "If you encounter CUDA memory issues, try splitting manifest entries by split_duration."
        )

    # setup_test_data
    vad_model.setup_test_data(
        test_data_config={
            'vad_stream': True,
            'sample_rate': 16000,
            'manifest_filepath': manifest_vad_input,
            'labels': [
                'infer',
            ],
            'num_workers': args.num_workers,
            'shuffle': False,
            'time_length': args.time_length,
            'shift_length': args.shift_length,
            'trim_silence': False,
            'normalize_audio': args.normalize_audio,
        })

    vad_model = vad_model.to(device)
    vad_model.eval()

    # A window of time_length seconds shifted by shift_length seconds covers
    # time_unit frames; overlapping halves of adjacent windows are trimmed below so
    # that each frame of a file is written exactly once.
    time_unit = int(args.time_length / args.shift_length)
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc
    all_len = 0

    data = []
    for line in open(args.dataset, 'r'):
        file = json.loads(line)['audio_filepath'].split("/")[-1]
        data.append(file.split(".wav")[0])
    logging.info(f"Inference on {len(data)} audio files/json lines!")

    status = get_vad_stream_status(data)
    for i, test_batch in enumerate(vad_model.test_dataloader()):
        test_batch = [x.to(device) for x in test_batch]
        with autocast():
            log_probs = vad_model(input_signal=test_batch[0],
                                  input_signal_length=test_batch[1])
            probs = torch.softmax(log_probs, dim=-1)
            pred = probs[:, 1]  # per-frame probability of the speech class

            # Trim window overlap depending on the segment's position within the file.
            if status[i] == 'start':
                to_save = pred[:-trunc]
            elif status[i] == 'next':
                to_save = pred[trunc:-trunc_l]
            elif status[i] == 'end':
                to_save = pred[trunc_l:]
            else:
                to_save = pred

            all_len += len(to_save)
            outpath = os.path.join(args.out_dir, data[i] + ".frame")
            with open(outpath, "a") as fout:
                for f in range(len(to_save)):
                    fout.write('{0:0.4f}\n'.format(to_save[f]))
        del test_batch
        if status[i] == 'end' or status[i] == 'single':
            logging.debug(
                f"Overall length of prediction of {data[i]} is {all_len}!")
            all_len = 0
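
if __name__ == '__main__':
    # Minimal entry point; a typical invocation (the file name vad_infer.py is an
    # assumption, not fixed by this excerpt):
    #   python vad_infer.py --dataset /path/to/eval_manifest.json --out_dir vad_frame
    main()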
Example #2
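    # Context assumed by this excerpt (not shown here): the enclosing class is a
    # NeMo-style diarizer that defines self._vad_model (an EncDecClassificationModel),
    # self._device, self._vad_dir, self._vad_out_file, self._cfg, self.AUDIO_RTTM_MAP,
    # and the VAD window/shift lengths, and that imports json, os, torch, tqdm,
    # autocast, and helpers such as get_vad_stream_status, generate_overlap_vad_seq,
    # generate_vad_segment_table, and write_rttm2manifest from NeMo's VAD/speaker
    # utility modules (exact module paths vary across NeMo versions).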
    def _run_vad(self, manifest_file):
        self._vad_model = self._vad_model.to(self._device)
        self._vad_model.eval()

        time_unit = int(self._vad_window_length_in_sec / self._vad_shift_length_in_sec)
        trunc = int(time_unit / 2)
        trunc_l = time_unit - trunc
        all_len = 0
        data = []
        for line in open(manifest_file, 'r'):
            file = os.path.basename(json.loads(line)['audio_filepath'])
            data.append(os.path.splitext(file)[0])

        status = get_vad_stream_status(data)
        for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader())):
            test_batch = [x.to(self._device) for x in test_batch]
            with autocast():
                log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
                probs = torch.softmax(log_probs, dim=-1)
                pred = probs[:, 1]
                if status[i] == 'start':
                    to_save = pred[:-trunc]
                elif status[i] == 'next':
                    to_save = pred[trunc:-trunc_l]
                elif status[i] == 'end':
                    to_save = pred[trunc_l:]
                else:
                    to_save = pred
                all_len += len(to_save)
                outpath = os.path.join(self._vad_dir, data[i] + ".frame")
                with open(outpath, "a") as fout:
                    for f in range(len(to_save)):
                        fout.write('{0:0.4f}\n'.format(to_save[f]))
            del test_batch
            if status[i] == 'end' or status[i] == 'single':
                all_len = 0

        if not self._cfg.diarizer.vad.vad_decision_smoothing:
            # Use the raw frame-level predictions: each window, shifted by the shift
            # length, directly provides the label for its frame.
            self.vad_pred_dir = self._vad_dir

        else:
            # Generate predictions with overlapping input segments, then apply a smoothing
            # filter to decide the label for each frame spanned by multiple segments.
            # smoothing_method is either majority vote (median) or average (mean).
            logging.info("Generating predictions with overlapping input segments")
            smoothing_pred_dir = generate_overlap_vad_seq(
                frame_pred_dir=self._vad_dir,
                smoothing_method=self._cfg.diarizer.vad.smoothing_params.method,
                overlap=self._cfg.diarizer.vad.smoothing_params.overlap,
                seg_len=self._vad_window_length_in_sec,
                shift_len=self._vad_shift_length_in_sec,
                num_workers=self._cfg.num_workers,
            )
            self.vad_pred_dir = smoothing_pred_dir

        logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=self.vad_pred_dir,
            threshold=self._cfg.diarizer.vad.threshold,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )

        vad_table_list = [os.path.join(table_out_dir, key + ".txt") for key in self.AUDIO_RTTM_MAP]
        write_rttm2manifest(self._cfg.diarizer.paths2audio_files, vad_table_list, self._vad_out_file)
        self._speaker_manifest_path = self._vad_out_file
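
# Minimal usage sketch (illustrative assumptions: the enclosing class is a
# ClusteringDiarizer-style object built from a DictConfig, and the VAD manifest
# has already been written to manifest_file; names below are not taken from the excerpt):
#
#   diarizer = ClusteringDiarizer(cfg=cfg)
#   diarizer._run_vad(manifest_file=os.path.join(diarizer._vad_dir, 'vad_manifest.json'))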