def _str2bool(value):
    """Parse a command-line boolean ("true"/"false", "yes"/"no", "1"/"0", ...).

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so an explicit parser is required.

    Raises:
        ValueError: for unrecognised values; argparse reports it as an
            "invalid ... value" usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def _read_stream_names(manifest_path):
    """Return the extension-less basename of each entry's ``audio_filepath``.

    The manifest is read as JSON lines, one entry per line, in file order.
    """
    names = []
    with open(manifest_path, 'r', encoding='utf-8') as manifest:
        for line in manifest:
            audio_filepath = json.loads(line)['audio_filepath']
            names.append(os.path.splitext(os.path.basename(audio_filepath))[0])
    return names


def _trim_stream_pred(pred, stream_status, trunc, trunc_l):
    """Cut the frames of one segment that overlap its neighbouring segments.

    ``stream_status`` is the label produced by ``get_vad_stream_status``:
    'start' drops ``trunc`` trailing frames, 'next' drops ``trunc`` leading
    and ``trunc_l`` trailing frames, 'end' drops ``trunc_l`` leading frames,
    and anything else keeps everything. End indices are computed explicitly
    because ``pred[:-0]`` would be empty instead of "keep all".
    """
    n_frames = len(pred)
    if stream_status == 'start':
        return pred[: max(n_frames - trunc, 0)]
    if stream_status == 'next':
        return pred[trunc : max(n_frames - trunc_l, 0)]
    if stream_status == 'end':
        return pred[trunc_l:]
    return pred


def main():
    """Run frame-level streaming VAD and write per-frame scores to ``<out_dir>/<name>.frame``.

    Each output line is ``probs[:, 1]`` of one frame formatted with 4 decimals
    (presumably the speech-class probability — confirm against the model's labels).
    """
    parser = ArgumentParser()
    parser.add_argument("--vad_model", type=str, default="MatchboxNet-VAD-3x2", required=False, help="Pass: '******'")
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Path of json file of evaluation data. Audio files should have unique names.",
    )
    parser.add_argument("--out_dir", type=str, default="vad_frame", help="Dir of your vad outputs")
    parser.add_argument("--time_length", type=float, default=0.63)
    parser.add_argument("--shift_length", type=float, default=0.01)
    parser.add_argument(
        "--normalize_audio",
        type=_str2bool,
        default=False,
        help="Whether to normalize audio (true/false, yes/no, 1/0).",
    )
    # DataLoader / multiprocessing workers must be an integer count.
    parser.add_argument("--num_workers", type=int, default=20)
    parser.add_argument("--split_duration", type=float, default=400)
    parser.add_argument(
        "--dont_auto_split",
        default=False,
        action='store_true',
        help="Disable automatically splitting manifest entries by split_duration "
        "(splitting helps avoid potential CUDA out of memory issues).",
    )
    args = parser.parse_args()

    torch.set_grad_enabled(False)

    if args.vad_model.endswith('.nemo'):
        logging.info(f"Using local VAD model from {args.vad_model}")
        vad_model = EncDecClassificationModel.restore_from(restore_path=args.vad_model)
    else:
        logging.info(f"Using NGC cloud VAD model {args.vad_model}")
        vad_model = EncDecClassificationModel.from_pretrained(model_name=args.vad_model)

    os.makedirs(args.out_dir, exist_ok=True)

    # Prepare manifest for streaming VAD
    manifest_vad_input = args.dataset
    if not args.dont_auto_split:
        logging.info("Split long audio file to avoid CUDA memory issue")
        logging.debug("Try smaller split_duration if you still have CUDA memory issue")
        config = {
            'manifest_filepath': manifest_vad_input,
            'time_length': args.time_length,
            'split_duration': args.split_duration,
            'num_workers': args.num_workers,
        }
        manifest_vad_input = prepare_manifest(config)
    else:
        logging.warning(
            "If you encounter CUDA memory issue, try splitting manifest entry by split_duration to avoid it."
        )

    # setup_test_data
    vad_model.setup_test_data(
        test_data_config={
            'vad_stream': True,
            'sample_rate': 16000,
            'manifest_filepath': manifest_vad_input,
            'labels': ['infer'],
            'num_workers': args.num_workers,
            'shuffle': False,
            'time_length': args.time_length,
            'shift_length': args.shift_length,
            'trim_silence': False,
            'normalize_audio': args.normalize_audio,
        }
    )
    vad_model = vad_model.to(device)
    vad_model.eval()

    time_unit = int(args.time_length / args.shift_length)
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc

    # Names/statuses must come from the manifest the dataloader actually uses
    # (the split one when auto-split is on) so that batch i matches entry i.
    names = _read_stream_names(manifest_vad_input)
    logging.info(f"Inference on {len(names)} audio files/json lines!")
    status = get_vad_stream_status(names)

    all_len = 0
    for i, test_batch in enumerate(vad_model.test_dataloader()):
        test_batch = [x.to(device) for x in test_batch]
        with autocast():
            log_probs = vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
            probs = torch.softmax(log_probs, dim=-1)
            pred = probs[:, 1]
            to_save = _trim_stream_pred(pred, status[i], trunc, trunc_l)
        all_len += len(to_save)
        # Append mode: segments of one audio file arrive in consecutive batches.
        # NOTE: a .frame file left over from an earlier run is appended to as well.
        outpath = os.path.join(args.out_dir, names[i] + ".frame")
        with open(outpath, "a") as fout:
            # One host transfer per batch instead of one .item() per frame.
            fout.writelines(f'{p:0.4f}\n' for p in to_save.tolist())
        del test_batch
        if status[i] in ('end', 'single'):
            logging.debug(f"Overall length of prediction of {names[i]} is {all_len}!")
            all_len = 0
def _run_vad(self, manifest_file):
    """Run frame-level VAD on ``manifest_file`` and build the speech-segment manifest.

    Steps:
      1. Append per-frame scores (``probs[:, 1]``, 4 decimals) for each audio
         name to ``<self._vad_dir>/<name>.frame``.
      2. If ``diarizer.vad.vad_decision_smoothing`` is set, smooth them with
         ``generate_overlap_vad_seq``.
      3. Threshold them into segment tables with ``generate_vad_segment_table``.
      4. Write those tables as a manifest to ``self._vad_out_file`` and use it
         as ``self._speaker_manifest_path``.

    Args:
        manifest_file: JSON-lines manifest that the VAD test dataloader was set up
            with; entry i must correspond to dataloader batch i. Names are grouped
            by ``get_vad_stream_status`` (presumably segments of one file must be
            consecutive — verify against that helper).
    """
    self._vad_model = self._vad_model.to(self._device)
    self._vad_model.eval()

    time_unit = int(self._vad_window_length_in_sec / self._vad_shift_length_in_sec)
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc

    names = []
    with open(manifest_file, 'r', encoding='utf-8') as manifest:
        for line in manifest:
            audio_name = os.path.basename(json.loads(line)['audio_filepath'])
            names.append(os.path.splitext(audio_name)[0])

    status = get_vad_stream_status(names)
    for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader())):
        test_batch = [x.to(self._device) for x in test_batch]
        with autocast():
            log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
            probs = torch.softmax(log_probs, dim=-1)
            pred = probs[:, 1]
            # Explicit end indices: pred[:-0] would be empty when trunc == 0.
            n_frames = len(pred)
            if status[i] == 'start':
                to_save = pred[: max(n_frames - trunc, 0)]
            elif status[i] == 'next':
                to_save = pred[trunc : max(n_frames - trunc_l, 0)]
            elif status[i] == 'end':
                to_save = pred[trunc_l:]
            else:
                to_save = pred
        # Append mode: segments of one audio file arrive in consecutive batches.
        outpath = os.path.join(self._vad_dir, names[i] + ".frame")
        with open(outpath, "a") as fout:
            # One host transfer per batch instead of one .item() per frame.
            fout.writelines(f'{p:0.4f}\n' for p in to_save.tolist())
        del test_batch

    if not self._cfg.diarizer.vad.vad_decision_smoothing:
        # Use the frame-level predictions as-is: the window is shifted by
        # self._vad_shift_length_in_sec and each window's prediction labels its frame.
        self.vad_pred_dir = self._vad_dir
    else:
        # Generate predictions with overlapping input segments, then apply a smoothing
        # filter to decide the label of a frame spanned by multiple segments.
        # smoothing_params.method is passed through as the smoothing method
        # (e.g. majority vote (median) or average (mean)).
        logging.info("Generating predictions with overlapping input segments")
        smoothing_pred_dir = generate_overlap_vad_seq(
            frame_pred_dir=self._vad_dir,
            smoothing_method=self._cfg.diarizer.vad.smoothing_params.method,
            overlap=self._cfg.diarizer.vad.smoothing_params.overlap,
            seg_len=self._vad_window_length_in_sec,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )
        self.vad_pred_dir = smoothing_pred_dir

    logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")

    table_out_dir = generate_vad_segment_table(
        vad_pred_dir=self.vad_pred_dir,
        threshold=self._cfg.diarizer.vad.threshold,
        shift_len=self._vad_shift_length_in_sec,
        num_workers=self._cfg.num_workers,
    )

    vad_table_list = [os.path.join(table_out_dir, key + ".txt") for key in self.AUDIO_RTTM_MAP]
    write_rttm2manifest(self._cfg.diarizer.paths2audio_files, vad_table_list, self._vad_out_file)
    self._speaker_manifest_path = self._vad_out_file