def add_transcript_to_manifest(
    manifest_original: str, manifest_updated: str, asr_model: nemo_asr.models.EncDecCTCModel, batch_size: int
) -> None:
    """
    Adds transcripts generated by the asr_model to the manifest_original.

    Args:
        manifest_original: path to the manifest
        manifest_updated: path to the updated manifest with transcript included
        asr_model: CTC-based ASR model, for example, QuartzNet15x5Base-En
        batch_size: batch size for asr_model inference
    """
    transcripts = get_transcript(manifest_original, asr_model, batch_size)
    with open(manifest_original, 'r', encoding='utf8') as f, open(manifest_updated, 'w', encoding='utf8') as f_updated:
        for i, line in enumerate(f):
            info = json.loads(line)
            info['pred_text'] = transcripts[i].strip()
            info['WER'] = round(word_error_rate([info['pred_text']], [info['text']]) * 100, 2)
            info['CER'] = round(word_error_rate([info['pred_text']], [info['text']], use_cer=True) * 100, 2)
            json.dump(info, f_updated, ensure_ascii=False)
            f_updated.write('\n')
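# A minimal sketch of the manifest line written by add_transcript_to_manifest
# above (path and values are made up for illustration): each input line gains
# 'pred_text', 'WER', and 'CER' fields, with both rates stored as percentages.
example_in = {"audio_filepath": "audio/utt_0.wav", "duration": 1.4, "text": "hello world"}
example_out = {**example_in, "pred_text": "hello word", "WER": 50.0, "CER": 9.09}
assert set(example_out) - set(example_in) == {"pred_text", "WER", "CER"}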
def test_wer_metric_randomized(self, test_wer_bpe):
    """This test relies on correctness of word_error_rate function."""

    def __random_string(length):
        return ''.join(random.choice(''.join(self.vocabulary)) for _ in range(length))

    if test_wer_bpe:
        wer = WERBPE(deepcopy(self.char_tokenizer), batch_dim_index=0, use_cer=False, ctc_decode=True)
    else:
        wer = WER(vocabulary=self.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)
    for test_id in range(256):
        n1 = random.randint(1, 512)
        n2 = random.randint(1, 512)
        s1 = __random_string(n1)
        s2 = __random_string(n2)
        # Skip empty strings as reference.
        if s2.strip():
            assert (
                abs(
                    self.get_wer(wer, prediction=s1, reference=s2, use_tokenizer=test_wer_bpe)
                    - word_error_rate(hypotheses=[s1], references=[s2])
                )
                < 1e-6
            )
def get_wer_feat(mfst, asr, tokens_per_chunk, delay, model_stride_in_secs, batch_size):
    hyps = []
    refs = []
    audio_filepaths = []

    with open(mfst, "r") as mfst_f:
        print("Parsing manifest files...")
        for l in mfst_f:
            row = json.loads(l.strip())
            audio_filepaths.append(row['audio_filepath'])
            refs.append(row['text'])

    with torch.inference_mode():
        with torch.cuda.amp.autocast():
            batch = []
            asr.sample_offset = 0
            for idx in tqdm.tqdm(range(len(audio_filepaths)), desc='Sample:', total=len(audio_filepaths)):
                batch.append((audio_filepaths[idx], refs[idx]))
                if len(batch) == batch_size:
                    audio_files = [sample[0] for sample in batch]
                    asr.reset()
                    asr.read_audio_file(audio_files, delay, model_stride_in_secs)
                    hyp_list = asr.transcribe(tokens_per_chunk, delay)
                    hyps.extend(hyp_list)
                    batch.clear()
                    asr.sample_offset += batch_size

            if len(batch) > 0:
                asr.batch_size = len(batch)
                asr.frame_bufferer.batch_size = len(batch)
                asr.reset()
                audio_files = [sample[0] for sample in batch]
                asr.read_audio_file(audio_files, delay, model_stride_in_secs)
                hyp_list = asr.transcribe(tokens_per_chunk, delay)
                hyps.extend(hyp_list)
                # Advance the offset before clearing the batch; the original cleared first,
                # which made len(batch) == 0 and the increment a no-op.
                asr.sample_offset += len(batch)
                batch.clear()

    if os.environ.get('DEBUG', '0') in ('1', 'y', 't'):
        for hyp, ref in zip(hyps, refs):
            print("hyp:", hyp)
            print("ref:", ref)

    wer = word_error_rate(hypotheses=hyps, references=refs)
    return hyps, refs, wer
def test_wer_function(self):
    assert word_error_rate(hypotheses=['cat'], references=['cot']) == 1.0
    assert word_error_rate(hypotheses=['GPU'], references=['G P U']) == 1.0
    assert word_error_rate(hypotheses=['G P U'], references=['GPU']) == 3.0
    assert word_error_rate(hypotheses=['ducati motorcycle'], references=['motorcycle']) == 1.0
    assert word_error_rate(hypotheses=['ducati motorcycle'], references=['ducuti motorcycle']) == 0.5
    assert word_error_rate(hypotheses=['a B c'], references=['a b c']) == 1.0 / 3.0
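# A minimal sketch (not NeMo's implementation) of the metric the tests above
# exercise: WER is the word-level Levenshtein distance divided by the number of
# reference words, so it can exceed 1.0 when the hypothesis inserts words,
# e.g. 'G P U' vs. 'GPU' gives 3 edits over 1 reference word == 3.0.
def wer_sketch(hypothesis: str, reference: str) -> float:
    hyp, ref = hypothesis.split(), reference.split()
    # Classic dynamic-programming edit distance over words.
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        curr = [i] + [0] * len(ref)
        for j, r in enumerate(ref, start=1):
            curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (h != r))
        prev = curr
    return prev[len(ref)] / len(ref)

assert wer_sketch('G P U', 'GPU') == 3.0
assert wer_sketch('a B c', 'a b c') == 1.0 / 3.0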
def main(): parser = ArgumentParser() parser.add_argument( "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: '******'", ) parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data") parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test") parser.add_argument( "--normalize_text", default=True, type=bool, help="Normalize transcripts or not. Set to False for non-English." ) args = parser.parse_args() torch.set_grad_enabled(False) if args.asr_model.endswith('.nemo'): logging.info(f"Using local ASR model from {args.asr_model}") asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model) else: logging.info(f"Using NGC cloud ASR model {args.asr_model}") asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model) asr_model.setup_test_data( test_data_config={ 'sample_rate': 16000, 'manifest_filepath': args.dataset, 'labels': asr_model.decoder.vocabulary, 'batch_size': args.batch_size, 'normalize_transcripts': args.normalize_text, } ) if can_gpu: asr_model = asr_model.cuda() asr_model.eval() labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))]) wer = WER(vocabulary=asr_model.decoder.vocabulary) hypotheses = [] references = [] for test_batch in asr_model.test_dataloader(): if can_gpu: test_batch = [x.cuda() for x in test_batch] with autocast(): log_probs, encoded_len, greedy_predictions = asr_model( input_signal=test_batch[0], input_signal_length=test_batch[1] ) hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions) for batch_ind in range(greedy_predictions.shape[0]): reference = ''.join([labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy()]) references.append(reference) del test_batch wer_value = word_error_rate(hypotheses=hypotheses, references=references) if wer_value > args.wer_tolerance: raise ValueError(f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}") logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
def evaluate(asr_model, asr_onnx, labels_map, wer, qat):
    # Eval the model
    hypotheses = []
    references = []
    stream = cuda.Stream()
    vocabulary_size = len(labels_map) + 1
    engine_file_path = build_trt_engine(asr_model, asr_onnx, qat)

    with open(engine_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        trt_engine = runtime.deserialize_cuda_engine(f.read())
        trt_ctx = trt_engine.create_execution_context()

        profile_shape = trt_engine.get_profile_shape(profile_index=0, binding=0)
        print("profile shape min:{}, opt:{}, max:{}".format(profile_shape[0], profile_shape[1], profile_shape[2]))
        max_input_shape = profile_shape[2]
        input_nbytes = trt.volume(max_input_shape) * trt.float32.itemsize
        d_input = cuda.mem_alloc(input_nbytes)
        max_output_shape = [max_input_shape[0], vocabulary_size, (max_input_shape[-1] + 1) // 2]
        output_nbytes = trt.volume(max_output_shape) * trt.float32.itemsize
        d_output = cuda.mem_alloc(output_nbytes)

        for test_batch in asr_model.test_dataloader():
            if can_gpu:
                test_batch = [x.cuda() for x in test_batch]
            processed_signal, processed_signal_length = asr_model.preprocessor(
                input_signal=test_batch[0], length=test_batch[1]
            )
            greedy_predictions = trt_inference(
                stream,
                trt_ctx,
                d_input,
                d_output,
                input_signal=processed_signal,
                input_signal_length=processed_signal_length,
            )
            hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
            for batch_ind in range(greedy_predictions.shape[0]):
                seq_len = test_batch[3][batch_ind].cpu().detach().numpy()
                seq_ids = test_batch[2][batch_ind].cpu().detach().numpy()
                reference = ''.join([labels_map[c] for c in seq_ids[0:seq_len]])
                references.append(reference)
            del test_batch

    wer_value = word_error_rate(hypotheses=hypotheses, references=references, use_cer=wer.use_cer)
    return wer_value
def test_wer_metric_randomized(self):
    """This test relies on correctness of word_error_rate function."""

    def __randomString(N):
        return ''.join(random.choice(''.join(self.vocabulary)) for i in range(N))

    wer = WER(vocabulary=self.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)
    for test_id in range(256):
        n1 = random.randint(1, 512)
        n2 = random.randint(1, 512)
        s1 = __randomString(n1)
        s2 = __randomString(n2)
        # Floating-point math doesn't seem to be an issue here. Leaving as ==
        assert self.get_wer(wer, prediction=s1, reference=s2) == word_error_rate(hypotheses=[s1], references=[s2])
def evaluate(asr_model, labels_map, wer):
    # Eval the model
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1]
            )
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy()])
            references.append(reference)
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    return wer_value
def calculate_cer(normalized_texts: List[str], transcript: str, remove_punct=False) -> List[Tuple[str, float]]:
    """
    Calculates character error rate (CER)

    Args:
        normalized_texts: normalized text options
        transcript: ASR model output
        remove_punct: if True, strip punctuation from the candidates before scoring

    Returns:
        normalized options with corresponding CER
    """
    normalized_options = []
    for text in normalized_texts:
        text_clean = text.replace('-', ' ').lower()
        if remove_punct:
            for punct in "!?:;,.-()*+-/<=>@^_":
                text_clean = text_clean.replace(punct, "")
        cer = round(word_error_rate([transcript], [text_clean], use_cer=True) * 100, 2)
        normalized_options.append((text, cer))
    return normalized_options
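# Hypothetical usage of calculate_cer above: score several normalization
# candidates against an ASR transcript and keep the closest one. The candidate
# strings are made up for illustration.
candidates = ["twenty twenty-three", "two zero two three"]
options = calculate_cer(candidates, transcript="twenty twenty three", remove_punct=True)
best_text, best_cer = min(options, key=lambda option: option[1])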
def test_rnnt_wer_metric_randomized(self, test_wer_bpe):
    """This test relies on correctness of word_error_rate function."""

    def __random_string(length):
        return ''.join(random.choice(''.join(self.vocabulary)) for _ in range(length))

    for test_id in range(256):
        n1 = random.randint(1, 512)
        n2 = random.randint(1, 512)
        s1 = __random_string(n1)
        s2 = __random_string(n2)
        # Skip empty strings as reference.
        if s2.strip():
            assert (
                abs(
                    self.get_wer_rnnt(prediction=s1, reference=s2, batch_dim_index=0, test_wer_bpe=test_wer_bpe)
                    - word_error_rate(hypotheses=[s1], references=[s2])
                )
                < 1e-6
            )
def get_wder_dict_values(self, asr_eval_dict, wder_dict, count_dict, align_error_list):
    """
    Calculate the total error rates for WDER, WER and alignment error.
    """
    if '-' in asr_eval_dict['references_list'] or None in asr_eval_dict['references_list']:
        wer = -1
    else:
        wer = word_error_rate(
            hypotheses=asr_eval_dict['hypotheses_list'], references=asr_eval_dict['references_list']
        )
    wder_dict['total_WER'] = wer
    wder_dict['total_wder_rttm'] = 1 - (
        count_dict['grand_total_correct_word_count'] / count_dict['grand_total_pred_word_count']
    )
    if all(self.ctm_exists.values()):
        wder_dict['total_wder_ctm_ref_trans'] = (
            count_dict['total_ctm_wder_count'] / count_dict['grand_total_ctm_word_count']
            if count_dict['grand_total_ctm_word_count'] > 0
            else -1
        )
        wder_dict['total_wder_ctm_pred_asr'] = (
            count_dict['total_ctm_wder_count'] / count_dict['grand_total_pred_word_count']
            if count_dict['grand_total_pred_word_count'] > 0
            else -1
        )
        wder_dict['total_diar_trans_acc'] = (
            count_dict['total_asr_and_spk_correct_words'] / count_dict['grand_total_ctm_word_count']
            if count_dict['grand_total_ctm_word_count'] > 0
            else -1
        )
    wder_dict['total_alignment_error_mean'] = (
        np.mean(self.align_error_list).round(4) if self.align_error_list != [] else -1
    )
    wder_dict['total_alignment_error_std'] = (
        np.std(self.align_error_list).round(4) if self.align_error_list != [] else -1
    )
    return wder_dict
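# Toy numbers (made up) for the WDER computed above: WDER counts predicted
# words carrying the wrong speaker label, so 90 correctly attributed words out
# of 100 predicted words give WDER = 1 - 90/100 = 0.1.
count_dict = {'grand_total_correct_word_count': 90, 'grand_total_pred_word_count': 100}
wder = 1 - count_dict['grand_total_correct_word_count'] / count_dict['grand_total_pred_word_count']
assert abs(wder - 0.1) < 1e-9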
def batch_inference(args: argparse.Namespace):
    torch.set_grad_enabled(False)

    if args.asr_model.endswith(".nemo"):
        print(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        print(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    manifest = prepare_manifest(args.corpora_dir, args.limit)
    asr_model.setup_test_data(
        test_data_config={
            "sample_rate": 16000,
            "manifest_filepath": manifest,
            "labels": asr_model.decoder.vocabulary,
            "batch_size": args.batch_size,
            "normalize_transcripts": args.normalize_text,
        }
    )

    refs_hyps = list(tqdm(generate_ref_hyps(asr_model, args.search, args.arpa)))
    references, hypotheses = [list(k) for k in zip(*refs_hyps)]

    os.makedirs(args.results_dir, exist_ok=True)
    data_io.write_lines(f"{args.results_dir}/refs.txt.gz", references)
    data_io.write_lines(f"{args.results_dir}/hyps.txt.gz", hypotheses)

    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    sys.stdout.flush()
    stats = {
        "wer": wer_value,
        "args": args.__dict__,
    }
    data_io.write_json(f"{args.results_dir}/stats.txt", stats)
    print(f"Got WER of {wer_value}")
    return stats
def main():
    args = parse_arguments()

    # Instantiate the PyTorch model
    nemo_model = args.nemo_model
    nemo_model = ASRModel.restore_from(nemo_model, map_location='cpu')  # type: ASRModel
    nemo_model.freeze()

    if torch.cuda.is_available():
        nemo_model = nemo_model.to('cuda')

    export_model_if_required(args, nemo_model)

    # Instantiate the RNNT decoding loop
    encoder_model = args.onnx_encoder
    decoder_model = args.onnx_decoder
    max_symbols_per_step = args.max_symbold_per_step
    decoding = ONNXGreedyBatchedRNNTInfer(encoder_model, decoder_model, max_symbols_per_step)

    audio_filepath = resolve_audio_filepaths(args)

    # Evaluate the PyTorch model (CPU/GPU)
    actual_transcripts = nemo_model.transcribe(audio_filepath, batch_size=args.batch_size)[0]

    # Evaluate the ONNX model (on CPU)
    with tempfile.TemporaryDirectory() as tmpdir:
        with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp:
            for audio_file in audio_filepath:
                entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'}
                fp.write(json.dumps(entry) + '\n')

        config = {'paths2audio_files': audio_filepath, 'batch_size': args.batch_size, 'temp_dir': tmpdir}

        # Push the nemo model to CPU
        nemo_model = nemo_model.to('cpu')
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0

        temporary_datalayer = nemo_model._setup_transcribe_dataloader(config)
        all_hypothesis = []
        for test_batch in tqdm(temporary_datalayer, desc="ONNX Transcribing"):
            input_signal, input_signal_length = test_batch[0], test_batch[1]

            # Acoustic features
            processed_audio, processed_audio_len = nemo_model.preprocessor(
                input_signal=input_signal, length=input_signal_length
            )
            # RNNT decoding loop
            hypotheses = decoding(audio_signal=processed_audio, length=processed_audio_len)

            # Process hypotheses (map char/subword token ids to text)
            hypotheses = nemo_model.decoding.decode_hypothesis(hypotheses)

            # Extract text from each hypothesis object
            texts = [h.text for h in hypotheses]

            all_hypothesis += texts
            del processed_audio, processed_audio_len
            del test_batch

    if args.log:
        for pt_transcript, onnx_transcript in zip(actual_transcripts, all_hypothesis):
            print(f"Pytorch Transcripts : {pt_transcript}")
            print(f"ONNX Transcripts : {onnx_transcript}")
            print()

    # Measure the error rate between the ONNX and PyTorch transcripts
    pt_onnx_cer = word_error_rate(all_hypothesis, actual_transcripts, use_cer=True)
    assert pt_onnx_cer < args.threshold, "Threshold violation !"

    print("Character error rate between Pytorch and ONNX :", pt_onnx_cer)
def main(cfg: EvaluationConfig):
    torch.set_grad_enabled(False)

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.audio_dir is not None:
        raise RuntimeError(
            "Evaluation script requires ground truth labels to be passed via a manifest file. "
            "If manifest file is available, submit it via `dataset_manifest` argument."
        )

    if not os.path.exists(cfg.dataset_manifest):
        raise FileNotFoundError(f"The dataset manifest file could not be found at path : {cfg.dataset_manifest}")

    if not cfg.only_score_manifest:
        # Transcribe speech into an output directory
        transcription_cfg = transcribe_speech.main(cfg)  # type: EvaluationConfig

        # Release GPU memory if it was used during transcription
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logging.info("Finished transcribing speech dataset. Computing ASR metrics...")
    else:
        cfg.output_filename = cfg.dataset_manifest
        transcription_cfg = cfg

    ground_truth_text = []
    predicted_text = []
    invalid_manifest = False
    with open(transcription_cfg.output_filename, 'r') as f:
        for line in f:
            data = json.loads(line)

            if 'pred_text' not in data:
                invalid_manifest = True
                break

            ground_truth_text.append(data['text'])
            predicted_text.append(data['pred_text'])

    # Test for invalid manifest supplied
    if invalid_manifest:
        raise ValueError(
            f"Invalid manifest provided: {transcription_cfg.output_filename} does not "
            f"contain value for `pred_text`."
        )

    # Compute the WER
    metric_name = 'CER' if cfg.use_cer else 'WER'
    metric_value = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=cfg.use_cer)

    if cfg.tolerance is not None:
        if metric_value > cfg.tolerance:
            raise ValueError(f"Got {metric_name} of {metric_value}, which was higher than tolerance={cfg.tolerance}")
        logging.info(f'Got {metric_name} of {metric_value}. Tolerance was {cfg.tolerance}')
    else:
        logging.info(f'Got {metric_name} of {metric_value}')

    # Inject the metric name and score into the config, and return the entire config
    with open_dict(cfg):
        cfg.metric_name = metric_name
        cfg.metric_value = metric_value

    return cfg
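# Illustrative manifest line (values made up): the scoring loop above reads one
# JSON object per line and only needs the 'text' and 'pred_text' fields.
import json

example_line = '{"audio_filepath": "audio/sample_0.wav", "duration": 3.2, "text": "hello world", "pred_text": "hello word"}'
assert 'pred_text' in json.loads(example_line)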
def ASR_Grade(dataset, id, key):
    try:
        from torch.cuda.amp import autocast
    except ImportError:
        from contextlib import contextmanager

        @contextmanager
        def autocast(enabled=None):
            yield

    can_gpu = torch.cuda.is_available()

    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model", type=str, default=model_Selected, required=True, help=f'Pass: {model_Selected}',
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument(
        "--normalize_text",
        default=False,  # False <- we're using phonetic references
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.",
    )
    args = parser.parse_args(["--dataset", dataset, "--asr_model", model_Selected])
    torch.set_grad_enabled(False)

    # Instantiate Jasper/QuartzNet models with the EncDecCTCModel class.
    asr_model = EncDecCTCModel.restore_from(model_Path)
    asr_model.setup_test_data(
        test_data_config={
            "sample_rate": 16000,
            "manifest_filepath": args.dataset,
            "labels": asr_model.decoder.vocabulary,
            "batch_size": args.batch_size,
            "normalize_transcripts": args.normalize_text,
        }
    )
    if can_gpu:  # noqa
        asr_model = asr_model.cuda()
    asr_model.eval()
    labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1]
            )
        # Accumulate across batches; the original assignment (=) overwrote all but the last batch.
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = key
            # reference = "".join([labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy()])  # debug
            # print(reference)  # debug
            references.append(reference)
        del test_batch

    wer_value = word_error_rate(hypotheses=hypotheses, references=references)  # cer=True
    REC = '.'
    REF = '.'
    for h, r in zip(hypotheses, references):
        print("Recognized:\t{}\nReference:\t{}\n".format(h, r))
        REC = h
        REF = r
    logging.info(f"Got PER of {wer_value}. Tolerance was {args.wer_tolerance}")

    # Score calculation, phoneme conversion:
    # divide wer_value by wer_tolerance to get the ratio of correctness (and round it),
    # then multiply by 100 to get a value above 0;
    # since this gives the "% wrong", subtract from 100 to get "% correct".
    # This gives a positive grade to show to the user.
    score = 100.00 - (round((wer_value / args.wer_tolerance), 4) * 100)
    if score < 0.0:
        score = 0.0
    print(score)

    # Result file creation, to be accessed by JS via 'app.py'
    Results = open(datasetPath + id + '_graded.txt', 'w')
    Results.write(REC + '\n' + REF + '\n' + str(score))
    Results.close()
    return score
def main(): parser = ArgumentParser() parser.add_argument( "--asr_model", type=str, default="QuartzNet15x5Base-En", choices=[ x.pretrained_model_name for x in EncDecCTCModel.list_available_models() ], ) parser.add_argument( "--tts_model_spec", type=str, default="Tacotron2-22050Hz", choices=[ x.pretrained_model_name for x in SpectrogramGenerator.list_available_models() ], ) parser.add_argument( "--tts_model_vocoder", type=str, default="WaveGlow-22050Hz", choices=[ x.pretrained_model_name for x in Vocoder.list_available_models() ], ) parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test") parser.add_argument("--trim", action="store_true") parser.add_argument("--debug", action="store_true") args = parser.parse_args() torch.set_grad_enabled(False) if args.debug: logging.set_verbosity(logging.DEBUG) logging.info(f"Using NGC cloud ASR model {args.asr_model}") asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model) logging.info( f"Using NGC cloud TTS Spectrogram Generator model {args.tts_model_spec}" ) tts_model_spec = SpectrogramGenerator.from_pretrained( model_name=args.tts_model_spec) logging.info(f"Using NGC cloud TTS Vocoder model {args.tts_model_vocoder}") tts_model_vocoder = Vocoder.from_pretrained( model_name=args.tts_model_vocoder) models = [asr_model, tts_model_spec, tts_model_vocoder] if torch.cuda.is_available(): for i, m in enumerate(models): models[i] = m.cuda() for m in models: m.eval() asr_model, tts_model_spec, tts_model_vocoder = models parser = parsers.make_parser( labels=asr_model.decoder.vocabulary, name="en", unk_id=-1, blank_id=-1, do_normalize=True, ) labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))]) tts_input = [] asr_references = [] longest_tts_input = 0 for test_str in LIST_OF_TEST_STRINGS: tts_parsed_input = tts_model_spec.parse(test_str) if len(tts_parsed_input[0]) > longest_tts_input: longest_tts_input = len(tts_parsed_input[0]) tts_input.append(tts_parsed_input.squeeze()) asr_parsed = parser(test_str) asr_parsed = ''.join([labels_map[c] for c in asr_parsed]) asr_references.append(asr_parsed) # Pad TTS Inputs for i, text in enumerate(tts_input): pad = (0, longest_tts_input - len(text)) tts_input[i] = torch.nn.functional.pad(text, pad, value=68) logging.debug(tts_input) # Do TTS tts_input = torch.stack(tts_input) if torch.cuda.is_available(): tts_input = tts_input.cuda() specs = tts_model_spec.generate_spectrogram(tokens=tts_input) audio = [] step = ceil(len(specs) / 4) for i in range(4): audio.append( tts_model_vocoder.convert_spectrogram_to_audio( spec=specs[i * step:i * step + step])) audio = [item for sublist in audio for item in sublist] audio_file_paths = [] # Save audio logging.debug(f"args.trim: {args.trim}") for i, aud in enumerate(audio): aud = aud.cpu().numpy() if args.trim: aud = librosa.effects.trim(aud, top_db=40)[0] librosa.output.write_wav(f"{i}.wav", aud, sr=22050) audio_file_paths.append(str(Path(f"{i}.wav"))) # Do ASR hypotheses = asr_model.transcribe(audio_file_paths) for i, _ in enumerate(hypotheses): logging.debug(f"{i}") logging.debug(f"ref:'{asr_references[i]}'") logging.debug(f"hyp:'{hypotheses[i]}'") wer_value = word_error_rate(hypotheses=hypotheses, references=asr_references) if wer_value > args.wer_tolerance: raise ValueError( f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}") logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
def main(cfg: ParallelTranscriptionConfig):
    if cfg.model.endswith(".nemo"):
        logging.info("Attempting to initialize from .nemo file")
        model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu")
    elif cfg.model.endswith(".ckpt"):
        logging.info("Attempting to initialize from .ckpt file")
        model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model, map_location="cpu")
    else:
        logging.info(
            "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt"
        )
        model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu")

    trainer = ptl.Trainer(**cfg.trainer)

    cfg.predict_ds.return_sample_id = True
    cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)
    data_loader = model._setup_dataloader_from_config(cfg.predict_ds)

    os.makedirs(cfg.output_path, exist_ok=True)
    # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank.
    global_rank = trainer.node_rank * trainer.num_gpus + int(os.environ.get("LOCAL_RANK", 0))
    output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json")

    predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset, output_file=output_file)
    trainer.callbacks.extend([predictor_writer])

    predictions = trainer.predict(model=model, dataloaders=data_loader, return_predictions=cfg.return_predictions)
    if predictions is not None:
        predictions = list(itertools.chain.from_iterable(predictions))
    samples_num = predictor_writer.close_output_file()

    logging.info(
        f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}."
    )

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    samples_num = 0
    pred_text_list = []
    text_list = []
    if is_global_rank_zero():
        output_file = os.path.join(cfg.output_path, "predictions_all.json")
        logging.info(f"Prediction files are being aggregated in {output_file}.")
        with open(output_file, 'w') as outf:
            for rank in range(trainer.world_size):
                input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json")
                with open(input_file, 'r') as inpf:
                    lines = inpf.readlines()
                    for line in lines:
                        item = json.loads(line)
                        pred_text_list.append(item["pred_text"])
                        text_list.append(item["text"])
                        outf.write(json.dumps(item) + "\n")
                        samples_num += 1
        wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer)
        logging.info(
            f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
        )
        logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer))
def main(args):
    torch.set_grad_enabled(False)
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = nemo_asr.models.EncDecCTCModelBPE.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model)

    cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(cfg.preprocessor, False)

    # Some changes for the streaming scenario
    cfg.preprocessor.dither = 0.0
    cfg.preprocessor.pad_to = 0

    if cfg.preprocessor.normalize != "per_feature":
        logging.error("Only EncDecRNNTBPEModel models trained with per_feature normalization are supported currently")

    device = args.device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logging.info(f"Inference will be done on device : {device}")

    # Disable config overwriting
    OmegaConf.set_struct(cfg.preprocessor, True)
    asr_model.freeze()
    asr_model = asr_model.to(device)

    # Change decoding config
    decoding_cfg = asr_model.cfg.decoding
    with open_dict(decoding_cfg):
        if args.stateful_decoding:
            decoding_cfg.strategy = "greedy"
        else:
            decoding_cfg.strategy = "greedy_batch"

        decoding_cfg.preserve_alignments = True  # required to compute the middle token for transducers
        decoding_cfg.fused_batch_size = -1  # temporarily stop fused batch during inference

    asr_model.change_decoding_strategy(decoding_cfg)

    feature_stride = cfg.preprocessor['window_stride']
    model_stride_in_secs = feature_stride * args.model_stride
    total_buffer = args.total_buffer_in_secs
    chunk_len = float(args.chunk_len_in_secs)

    tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
    mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
    print("Tokens per chunk :", tokens_per_chunk, "Mid Delay :", mid_delay)

    if args.merge_algo == 'middle':
        frame_asr = BatchedFrameASRRNNT(
            asr_model=asr_model,
            frame_len=chunk_len,
            total_buffer=args.total_buffer_in_secs,
            batch_size=args.batch_size,
            max_steps_per_timestep=args.max_steps_per_timestep,
            stateful_decoding=args.stateful_decoding,
        )
    elif args.merge_algo == 'lcs':
        frame_asr = LongestCommonSubsequenceBatchedFrameASRRNNT(
            asr_model=asr_model,
            frame_len=chunk_len,
            total_buffer=args.total_buffer_in_secs,
            batch_size=args.batch_size,
            max_steps_per_timestep=args.max_steps_per_timestep,
            stateful_decoding=args.stateful_decoding,
            alignment_basepath=args.lcs_alignment_dir,
        )
        # Set the LCS algorithm delay.
        frame_asr.lcs_delay = math.floor((total_buffer - chunk_len) / model_stride_in_secs)
    else:
        raise ValueError("Invalid choice of merge algorithm for transducer buffered inference.")

    hyps, refs, wer = get_wer_feat(
        mfst=args.test_manifest,
        asr=frame_asr,
        tokens_per_chunk=tokens_per_chunk,
        delay=mid_delay,
        model_stride_in_secs=model_stride_in_secs,
        batch_size=args.batch_size,
    )
    logging.info(f"WER is {round(wer, 4)} when decoded with a delay of {round(mid_delay * model_stride_in_secs, 2)}s")

    if args.output_path is not None:
        fname = (
            os.path.splitext(os.path.basename(args.asr_model))[0]
            + "_"
            + os.path.splitext(os.path.basename(args.test_manifest))[0]
            + "_"
            + str(args.chunk_len_in_secs)
            + "_"
            + str(int(total_buffer * 1000))
            + "_"
            + args.merge_algo
            + ".json"
        )
        hyp_json = os.path.join(args.output_path, fname)
        os.makedirs(args.output_path, exist_ok=True)
        with open(hyp_json, "w") as out_f:
            for i, hyp in enumerate(hyps):
                record = {
                    "pred_text": hyp,
                    "text": refs[i],
                    "wer": round(word_error_rate(hypotheses=[hyp], references=[refs[i]]) * 100, 2),
                }
                out_f.write(json.dumps(record) + '\n')
def main(): parser = ArgumentParser() """Training arguments""" parser.add_argument("--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: '******'") parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data") parser.add_argument("--batch_size", type=int, default=8) parser.add_argument( "--normalize_text", default=True, type=bool, help="Normalize transcripts or not. Set to False for non-English.") parser.add_argument("--shuffle", action='store_true', help="Shuffle test data.") """Calibration arguments""" parser.add_argument("--load", type=str, default=None, help="load path for the synthetic data") parser.add_argument( "--percentile", type=float, default=None, help="Max/min percentile for outlier handling. e.g., 99.9") """Quantization arguments""" parser.add_argument("--weight_bit", type=int, default=8, help="quantization bit for weights") parser.add_argument("--act_bit", type=int, default=8, help="quantization bit for activations") parser.add_argument("--dynamic", action='store_true', help="Dynamic quantization mode.") parser.add_argument("--no_quant", action='store_true', help="No quantization mode.") """Debugging arguments""" parser.add_argument("--eval_early_stop", type=int, default=None, help="early stop for debugging") parser.add_argument("--calib_early_stop", type=int, default=None, help="early stop calibration") args = parser.parse_args() torch.set_grad_enabled(False) if args.asr_model.endswith('.nemo'): logging.info(f"Using local ASR model from {args.asr_model}") asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model) else: logging.info(f"Using NGC cloud ASR model {args.asr_model}") asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model) asr_model.setup_test_data( test_data_config={ 'sample_rate': 16000, 'manifest_filepath': args.dataset, 'labels': asr_model.decoder.vocabulary, 'batch_size': args.batch_size, 'normalize_transcripts': args.normalize_text, 'shuffle': args.shuffle, }) if args.load is not None: print('Data loaded from %s' % args.load) with open(args.load, 'rb') as f: distilled_data = pickle.load(f) synthetic_batch_size, _, synthetic_seqlen = distilled_data[0].shape else: assert args.dynamic, \ "synthetic data must be loaded unless running with the dynamic quantization mode" ############################## Calibration ##################################### torch.set_grad_enabled(False) # disable backward graph generation asr_model.eval() # evaluation mode asr_model.set_quant_bit(args.weight_bit, mode='weight') asr_model.set_quant_bit(args.act_bit, mode='act') # set percentile if args.percentile is not None: qm.set_percentile(asr_model, args.percentile) if args.no_quant: asr_model.set_quant_mode('none') else: asr_model.encoder.bn_folding() # BN folding # if not dynamic quantization, calibrate min/max/range for the activations using synthetic data # if dynamic, we can skip calibration if not args.dynamic: print('Calibrating...') qm.calibrate(asr_model) length = torch.tensor([synthetic_seqlen] * synthetic_batch_size).cuda() for batch_idx, inputs in enumerate(distilled_data): if args.calib_early_stop is not None and batch_idx == args.calib_early_stop: break inputs = inputs.cuda() encoded, encoded_len, encoded_scaling_factor = asr_model.encoder( audio_signal=inputs, length=length) log_probs = asr_model.decoder( encoder_output=encoded, encoder_output_scaling_factor=encoded_scaling_factor) ############################## Evaluation ##################################### print('Evaluating...') 
qm.evaluate(asr_model) qm.set_dynamic( asr_model, args.dynamic) # if dynamic quantization, this will be enabled labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))]) wer = WER(vocabulary=asr_model.decoder.vocabulary) hypotheses = [] references = [] progress_bar = tqdm(asr_model.test_dataloader()) for i, test_batch in enumerate(progress_bar): if i == args.eval_early_stop: break test_batch = [x.cuda().float() for x in test_batch] with autocast(): log_probs, encoded_len, greedy_predictions = asr_model( input_signal=test_batch[0], input_signal_length=test_batch[1]) hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions) for batch_ind in range(greedy_predictions.shape[0]): reference = ''.join([ labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy() ]) references.append(reference) del test_batch wer_value = word_error_rate(hypotheses=hypotheses, references=references) print('WER:', wer_value)
print("\ttags=", " ".join(predicted_tags[i])) print("\tpred=", predictions[i]) print("\tsemiotic=", predicted_semiotic[i]) print("\tref=", references[i][-1]) # last reference is actual reference sentences_with_errors_on_digits += 1 elif ok_all: correct_sentences_disregarding_space += 1 elif args.print_other_errors: print("other error:") print("\tinput=", " ".join(inputs[i])) print("\ttags=", " ".join(predicted_tags[i])) print("\tpred=", predictions[i]) print("\tsemiotic=", predicted_semiotic[i]) print("\tref=", references[i][-1]) # last reference is actual reference wer = word_error_rate(refs_for_wer, preds_for_wer) print("WER: ", wer) print( "Sentence accuracy: ", correct_sentences_disregarding_space / (len(inputs) - len(skip_ids)), correct_sentences_disregarding_space, ) print( "digit errors: ", sentences_with_errors_on_digits / (len(inputs) - len(skip_ids)), sentences_with_errors_on_digits, ) print( "other errors: ", (len(inputs) - len(skip_ids) - correct_sentences_disregarding_space - sentences_with_errors_on_digits) / (len(inputs) - len(skip_ids)),