# Standalone NeMo VAD streaming-inference script. Import paths follow NeMo 1.x;
# the vad_utils module moved between minor releases, so adjust to your version.
from argparse import ArgumentParser
import json
import os

import torch
from torch.cuda.amp import autocast

from nemo.collections.asr.models import EncDecClassificationModel
from nemo.collections.asr.parts.utils.vad_utils import get_vad_stream_status, prepare_manifest
from nemo.utils import logging

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--vad_model", type=str, default="MatchboxNet-VAD-3x2", required=False, help="Pass: '******'"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Path of JSON file of evaluation data. Audio files should have unique names.",
    )
    parser.add_argument("--out_dir", type=str, default="vad_frame", help="Directory for your VAD outputs")
    parser.add_argument("--time_length", type=float, default=0.63)
    parser.add_argument("--shift_length", type=float, default=0.01)
    # 'type=bool' would treat any non-empty string (including "False") as True, so use a flag instead.
    parser.add_argument("--normalize_audio", default=False, action='store_true')
    parser.add_argument("--num_workers", type=int, default=20)  # was type=float; DataLoader expects an int
    parser.add_argument("--split_duration", type=float, default=400)
    parser.add_argument(
        "--dont_auto_split",
        default=False,
        action='store_true',
        help="Disable automatic splitting of manifest entries by split_duration "
        "(splitting avoids potential CUDA out-of-memory issues).",
    )
    args = parser.parse_args()

    torch.set_grad_enabled(False)

    if args.vad_model.endswith('.nemo'):
        logging.info(f"Using local VAD model from {args.vad_model}")
        vad_model = EncDecClassificationModel.restore_from(restore_path=args.vad_model)
    else:
        logging.info(f"Using NGC cloud VAD model {args.vad_model}")
        vad_model = EncDecClassificationModel.from_pretrained(model_name=args.vad_model)

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)

    # Prepare manifest for streaming VAD
    manifest_vad_input = args.dataset
    if not args.dont_auto_split:
        logging.info("Splitting long audio files to avoid CUDA memory issues")
        logging.debug("Try a smaller split_duration if you still hit CUDA memory issues")
        config = {
            'manifest_filepath': manifest_vad_input,
            'time_length': args.time_length,
            'split_duration': args.split_duration,
            'num_workers': args.num_workers,
        }
        manifest_vad_input = prepare_manifest(config)
    else:
        logging.warning(
            "If you encounter CUDA memory issues, try splitting manifest entries by split_duration."
        )

    vad_model.setup_test_data(
        test_data_config={
            'vad_stream': True,
            'sample_rate': 16000,
            'manifest_filepath': manifest_vad_input,
            'labels': ['infer'],
            'num_workers': args.num_workers,
            'shuffle': False,
            'time_length': args.time_length,
            'shift_length': args.shift_length,
            'trim_silence': False,
            'normalize_audio': args.normalize_audio,
        }
    )

    vad_model = vad_model.to(device)
    vad_model.eval()

    # One window covers time_length seconds and advances by shift_length seconds,
    # so it spans time_unit frames; the overlapping halves are trimmed below.
    time_unit = int(args.time_length / args.shift_length)
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc

    all_len = 0
    data = []
    for line in open(args.dataset, 'r'):
        file = json.loads(line)['audio_filepath'].split("/")[-1]
        data.append(file.split(".wav")[0])
    logging.info(f"Inference on {len(data)} audio files/json lines!")

    status = get_vad_stream_status(data)
    for i, test_batch in enumerate(vad_model.test_dataloader()):
        test_batch = [x.to(device) for x in test_batch]
        with autocast():
            log_probs = vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
            probs = torch.softmax(log_probs, dim=-1)
            pred = probs[:, 1]
            # Drop the overlapping halves of consecutive split segments so that
            # the concatenated predictions line up frame-for-frame with the audio.
            if status[i] == 'start':
                to_save = pred[:-trunc]
            elif status[i] == 'next':
                to_save = pred[trunc:-trunc_l]
            elif status[i] == 'end':
                to_save = pred[trunc_l:]
            else:
                to_save = pred
            all_len += len(to_save)
            outpath = os.path.join(args.out_dir, data[i] + ".frame")
            with open(outpath, "a") as fout:
                for f in range(len(to_save)):
                    fout.write('{0:0.4f}\n'.format(to_save[f]))
        del test_batch
        if status[i] == 'end' or status[i] == 'single':
            logging.debug(f"Overall length of prediction of {data[i]} is {all_len}!")
            all_len = 0
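
# Hypothetical invocation of the script above. The file name 'vad_infer.py' and
# the manifest layout (one JSON object per line, with at least an
# 'audio_filepath' field) are assumptions for illustration, not from the source:
#
#   python vad_infer.py --vad_model MatchboxNet-VAD-3x2 \
#       --dataset /path/to/manifest.json --out_dir vad_frame
#
if __name__ == '__main__':
    main()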
def _run_vad(self, manifest_file):
    self._vad_model = self._vad_model.to(self._device)
    self._vad_model.eval()

    time_unit = int(self._vad_window_length_in_sec / self._vad_shift_length_in_sec)
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc
    all_len = 0

    data = []
    for line in open(manifest_file, 'r'):
        file = os.path.basename(json.loads(line)['audio_filepath'])
        data.append(os.path.splitext(file)[0])

    status = get_vad_stream_status(data)
    for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader())):
        test_batch = [x.to(self._device) for x in test_batch]
        with autocast():
            log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
            probs = torch.softmax(log_probs, dim=-1)
            pred = probs[:, 1]
            # Trim the overlapping halves of consecutive split segments so the
            # concatenated predictions stay contiguous across a long file.
            if status[i] == 'start':
                to_save = pred[:-trunc]
            elif status[i] == 'next':
                to_save = pred[trunc:-trunc_l]
            elif status[i] == 'end':
                to_save = pred[trunc_l:]
            else:
                to_save = pred
            all_len += len(to_save)
            outpath = os.path.join(self._vad_dir, data[i] + ".frame")
            with open(outpath, "a") as fout:
                for f in range(len(to_save)):
                    fout.write('{0:0.4f}\n'.format(to_save[f]))
        del test_batch
        if status[i] == 'end' or status[i] == 'single':
            all_len = 0

    if not self._cfg.diarizer.vad.vad_decision_smoothing:
        # Use the frame-level predictions directly: each 10 ms shift of the window
        # yields one frame, labeled with the prediction of the window covering it.
        self.vad_pred_dir = self._vad_dir
    else:
        # Generate predictions with overlapping input segments, then apply a
        # smoothing filter to decide the label of a frame spanned by multiple
        # segments; smoothing_method is either majority vote ('median') or average ('mean').
        logging.info("Generating predictions with overlapping input segments")
        smoothing_pred_dir = generate_overlap_vad_seq(
            frame_pred_dir=self._vad_dir,
            smoothing_method=self._cfg.diarizer.vad.smoothing_params.method,
            overlap=self._cfg.diarizer.vad.smoothing_params.overlap,
            seg_len=self._vad_window_length_in_sec,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )
        self.vad_pred_dir = smoothing_pred_dir

    logging.info("Converting frame-level predictions to speech/no-speech segments in start/end-time format.")
    table_out_dir = generate_vad_segment_table(
        vad_pred_dir=self.vad_pred_dir,
        threshold=self._cfg.diarizer.vad.threshold,
        shift_len=self._vad_shift_length_in_sec,
        num_workers=self._cfg.num_workers,
    )
    vad_table_list = [os.path.join(table_out_dir, key + ".txt") for key in self.AUDIO_RTTM_MAP]
    write_rttm2manifest(self._cfg.diarizer.paths2audio_files, vad_table_list, self._vad_out_file)
    self._speaker_manifest_path = self._vad_out_file
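
# A minimal sketch of the segment-table idea behind generate_vad_segment_table
# as called above with a single threshold: per-frame speech probabilities are
# binarized and consecutive speech frames are merged into (start, end) segments.
# This is an illustration under those assumptions, not NeMo's implementation,
# and the function name is hypothetical.
def frames_to_segments(scores, threshold=0.5, shift_len=0.01):
    """Convert per-frame speech probabilities into (start_sec, end_sec) segments."""
    segments, start = [], None
    for i, p in enumerate(scores):
        is_speech = p >= threshold
        if is_speech and start is None:
            start = i * shift_len                    # segment opens on first speech frame
        elif not is_speech and start is not None:
            segments.append((start, i * shift_len))  # segment closes on first non-speech frame
            start = None
    if start is not None:                            # file ends while still in speech
        segments.append((start, len(scores) * shift_len))
    return segments

# Example: frames_to_segments([0.1, 0.9, 0.8, 0.2], threshold=0.5) -> [(0.01, 0.03)]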
def _run_vad(self, manifest_file):
    """
    Run voice activity detection: compute frame-level speech probabilities,
    smooth them with the configured post-processing parameters, and use the
    resulting predictions to generate a manifest file for later speaker
    embedding extraction.

    Args:
        manifest_file (str): Manifest file containing paths to the audio files, with label 'infer'.
    """
    shutil.rmtree(self._vad_dir, ignore_errors=True)
    os.makedirs(self._vad_dir)

    self._vad_model = self._vad_model.to(self._device)
    self._vad_model.eval()

    time_unit = int(self._vad_window_length_in_sec / self._vad_shift_length_in_sec)
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc
    all_len = 0

    data = []
    for line in open(manifest_file, 'r'):
        file = json.loads(line)['audio_filepath']
        data.append(get_uniqname_from_filepath(file))

    status = get_vad_stream_status(data)
    for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader())):
        test_batch = [x.to(self._device) for x in test_batch]
        with autocast():
            log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
            probs = torch.softmax(log_probs, dim=-1)
            pred = probs[:, 1]
            # Trim the overlapping halves of consecutive split segments so the
            # concatenated predictions stay contiguous across a long file.
            if status[i] == 'start':
                to_save = pred[:-trunc]
            elif status[i] == 'next':
                to_save = pred[trunc:-trunc_l]
            elif status[i] == 'end':
                to_save = pred[trunc_l:]
            else:
                to_save = pred
            all_len += len(to_save)
            outpath = os.path.join(self._vad_dir, data[i] + ".frame")
            with open(outpath, "a") as fout:
                for f in range(len(to_save)):
                    fout.write('{0:0.4f}\n'.format(to_save[f]))
        del test_batch
        if status[i] == 'end' or status[i] == 'single':
            all_len = 0

    if not self._vad_params.smoothing:
        # Use the frame-level predictions directly: each shift of the window
        # yields one frame, labeled with the prediction of the window covering it.
        self.vad_pred_dir = self._vad_dir
    else:
        # Generate predictions with overlapping input segments, then apply a
        # smoothing filter to decide the label of a frame spanned by multiple
        # segments; the method is either majority vote ('median') or average ('mean').
        logging.info("Generating predictions with overlapping input segments")
        smoothing_pred_dir = generate_overlap_vad_seq(
            frame_pred_dir=self._vad_dir,
            smoothing_method=self._vad_params.smoothing,
            overlap=self._vad_params.overlap,
            seg_len=self._vad_window_length_in_sec,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )
        self.vad_pred_dir = smoothing_pred_dir

    logging.info("Converting frame-level predictions to speech/no-speech segments in start/end-time format.")
    table_out_dir = generate_vad_segment_table(
        vad_pred_dir=self.vad_pred_dir,
        postprocessing_params=self._vad_params,
        shift_len=self._vad_shift_length_in_sec,
        num_workers=self._cfg.num_workers,
    )

    AUDIO_VAD_RTTM_MAP = deepcopy(self.AUDIO_RTTM_MAP)  # .copy() before deepcopy was redundant
    for key in AUDIO_VAD_RTTM_MAP:
        AUDIO_VAD_RTTM_MAP[key]['rttm_filepath'] = os.path.join(table_out_dir, key + ".txt")

    write_rttm2manifest(AUDIO_VAD_RTTM_MAP, self._vad_out_file)
    self._speaker_manifest_path = self._vad_out_file
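
# A minimal sketch of the idea behind generate_overlap_vad_seq used above: every
# frame is covered by several overlapping windows, and its final score is the
# median (majority-vote style) or mean of those windows' predictions. This is an
# illustration, not NeMo's implementation; all names here are hypothetical, and
# the defaults mirror the 0.63 s window / 0.01 s shift used elsewhere in this file.
import numpy as np

def smooth_overlapping_preds(window_scores, seg_len=0.63, shift_len=0.01, method='median'):
    """Combine per-window speech scores into one smoothed score per frame."""
    time_unit = int(seg_len / shift_len)           # frames covered by one window (63 at the defaults)
    n_frames = len(window_scores)
    per_frame = [[] for _ in range(n_frames)]
    for start, score in enumerate(window_scores):  # window i begins at frame i
        for f in range(start, min(start + time_unit, n_frames)):
            per_frame[f].append(score)
    reducer = np.median if method == 'median' else np.mean
    return np.array([reducer(scores) for scores in per_frame])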