def test_EncDecCTCModel(self):
    """Round-trip a pretrained CTC model through save/restore in another location.

    Checks that decoder input-feature and class counts survive the restore.
    """
    # TODO: Switch to using named configs because here we don't really care about weights
    model = EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")
    self.__test_restore_elsewhere(
        model=model,
        attr_for_eq_check={"decoder._feat_in", "decoder._num_classes"},
    )
def set_asr_model(self, asr_model):
    """
    Set up decoding parameters for the given ASR model name and load the model.

    Dispatch is by substring match on the model name. Supported families
    (per the branches below): QuartzNet (char CTC), conformer_ctc (BPE CTC),
    citrinet (BPE CTC), conformer_transducer / contextnet (BPE RNNT), e.g.:
        stt_en_conformer_ctc_large
        stt_en_conformer_ctc_medium
        stt_en_conformer_ctc_small
        QuartzNet15x5Base-En

    Side effects: binds self.run_ASR to the matching decode routine and sets
    self.params['offset'/'time_stride'], self.model_stride_in_secs,
    self.asr_delay_sec, self.asr_batch_size and self.audio_file_list.

    Returns:
        The loaded NeMo ASR model, switched to eval mode.

    Raises:
        ValueError: if the name matches no known model family.
    """
    if 'QuartzNet' in asr_model:
        # Character-level CTC model.
        self.run_ASR = self.run_ASR_QuartzNet_CTC
        asr_model = EncDecCTCModel.from_pretrained(model_name=asr_model, strict=False)
        # Empirical timestamp offset in seconds; the delay cancels it out.
        self.params['offset'] = -0.18
        self.model_stride_in_secs = 0.02
        self.asr_delay_sec = -1 * self.params['offset']
    elif 'conformer_ctc' in asr_model:
        # BPE CTC conformer, decoded with buffered/chunked streaming.
        self.run_ASR = self.run_ASR_BPE_CTC
        asr_model = EncDecCTCModelBPE.from_pretrained(model_name=asr_model, strict=False)
        self.model_stride_in_secs = 0.04
        self.asr_delay_sec = 0.0
        self.params['offset'] = 0
        self.chunk_len_in_sec = 1.6
        self.total_buffer_in_secs = 4
    elif 'citrinet' in asr_model:
        # Citrinet: same BPE CTC path but a larger (8x) encoder stride.
        self.run_ASR = self.run_ASR_BPE_CTC
        asr_model = EncDecCTCModelBPE.from_pretrained(model_name=asr_model, strict=False)
        self.model_stride_in_secs = 0.08
        self.asr_delay_sec = 0.0
        self.params['offset'] = 0
        self.chunk_len_in_sec = 1.6
        self.total_buffer_in_secs = 4
    elif 'conformer_transducer' in asr_model or 'contextnet' in asr_model:
        # RNNT (transducer) models.
        self.run_ASR = self.run_ASR_BPE_RNNT
        asr_model = EncDecRNNTBPEModel.from_pretrained(
            model_name=asr_model, strict=False)
        self.model_stride_in_secs = 0.04
        self.asr_delay_sec = 0.0
        self.params['offset'] = 0
        self.chunk_len_in_sec = 1.6
        self.total_buffer_in_secs = 4
    else:
        raise ValueError(f"ASR model name not found: {asr_model}")
    self.params['time_stride'] = self.model_stride_in_secs
    self.asr_batch_size = 16
    asr_model.eval()
    # Collect the audio paths for every utterance in the RTTM mapping.
    self.audio_file_list = [
        value['audio_filepath'] for _, value in self.AUDIO_RTTM_MAP.items()
    ]
    return asr_model
def main():
    """Evaluate a CTC ASR model on a manifest; raise if WER exceeds the tolerance."""

    def str_to_bool(value):
        """Parse a boolean CLI value.

        argparse's `type=bool` treats any non-empty string (including
        "False") as True; parsing the common spellings explicitly and
        raising ValueError otherwise makes argparse report bad input.
        """
        lowered = value.lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    # FIX: was `required=True`, which made the documented default dead code;
    # optional matches the sibling evaluation scripts.
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=False,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    # BUG FIX: `type=bool` parsed "--normalize_text False" as True.
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=str_to_bool,
        help="Normalize transcripts or not. Set to False for non-English."
    )
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Load either a local .nemo checkpoint or a pretrained NGC model.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        }
    )
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Map label indices back to characters to rebuild reference transcripts.
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1]
            )
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join(
                [labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy()])
            references.append(reference)
        del test_batch  # release GPU memory before the next batch

    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    if wer_value > args.wer_tolerance:
        raise ValueError(f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}")
    logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
def set_asr_model(self, ASR_model_name):
    """Load the ASR model selected by name and bind the matching run_ASR callback.

    Dispatch is by substring match on the model name.

    Args:
        ASR_model_name: pretrained NeMo model name, e.g. "QuartzNet15x5Base-En".

    Returns:
        The loaded ASR model.

    Raises:
        NotImplementedError: for conformer/citrinet names (not supported yet).
        ValueError: if the name matches no known model family.
    """
    if 'QuartzNet' in ASR_model_name:
        self.run_ASR = self.run_ASR_QuartzNet_CTC
        asr_model = EncDecCTCModel.from_pretrained(
            model_name=ASR_model_name, strict=False)
    elif 'conformer' in ASR_model_name:
        # FIX: fail fast instead of downloading a model that is then
        # discarded — this branch previously fetched the checkpoint and
        # immediately raised anyway.
        raise NotImplementedError(
            "conformer models have not been implemented yet")
    elif 'citrinet' in ASR_model_name:
        raise NotImplementedError(
            "citrinet models have not been implemented yet")
    else:
        # BUG FIX: previously formatted self.params['ASR_model_name'], which
        # could raise KeyError and mask the real error; report the argument.
        raise ValueError(f"ASR model name not found: {ASR_model_name}")
    return asr_model
def batch_inference(args: argparse.Namespace):
    """Transcribe a corpus with a CTC model and persist references, hypotheses and WER.

    Loads a local .nemo checkpoint or an NGC pretrained model, runs decoding
    over a generated manifest, writes refs/hyps/stats under args.results_dir,
    and returns the stats dict.
    """
    torch.set_grad_enabled(False)

    # A local checkpoint path takes precedence over an NGC model name.
    if args.asr_model.endswith(".nemo"):
        print(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        print(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    manifest = prepare_manifest(args.corpora_dir, args.limit)
    asr_model.setup_test_data(
        test_data_config={
            "sample_rate": 16000,
            "manifest_filepath": manifest,
            "labels": asr_model.decoder.vocabulary,
            "batch_size": args.batch_size,
            "normalize_transcripts": args.normalize_text,
        })

    # Each element is a (reference, hypothesis) pair; split into two columns.
    pairs = list(tqdm(generate_ref_hyps(asr_model, args.search, args.arpa)))
    references, hypotheses = (list(column) for column in zip(*pairs))

    os.makedirs(args.results_dir, exist_ok=True)
    data_io.write_lines(f"{args.results_dir}/refs.txt.gz", references)
    data_io.write_lines(f"{args.results_dir}/hyps.txt.gz", hypotheses)

    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    sys.stdout.flush()
    stats = {
        "wer": wer_value,
        "args": args.__dict__,
    }
    data_io.write_json(f"{args.results_dir}/stats.txt", stats)
    print(f"Got WER of {wer_value}")
    return stats
def main():
    """TTS->ASR round trip: synthesize test strings, transcribe them, check WER.

    Generates speech for LIST_OF_TEST_STRINGS with a spectrogram generator +
    vocoder, writes the audio to numbered .wav files, transcribes them with a
    CTC ASR model, and raises if WER against the parsed references exceeds
    the tolerance.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        choices=[
            x.pretrained_model_name for x in EncDecCTCModel.list_available_models()
        ],
    )
    parser.add_argument(
        "--tts_model_spec",
        type=str,
        default="Tacotron2-22050Hz",
        choices=[
            x.pretrained_model_name for x in SpectrogramGenerator.list_available_models()
        ],
    )
    parser.add_argument(
        "--tts_model_vocoder",
        type=str,
        default="WaveGlow-22050Hz",
        choices=[
            x.pretrained_model_name for x in Vocoder.list_available_models()
        ],
    )
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument("--trim", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    torch.set_grad_enabled(False)
    if args.debug:
        logging.set_verbosity(logging.DEBUG)

    # Load all three models from NGC.
    logging.info(f"Using NGC cloud ASR model {args.asr_model}")
    asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    logging.info(
        f"Using NGC cloud TTS Spectrogram Generator model {args.tts_model_spec}"
    )
    tts_model_spec = SpectrogramGenerator.from_pretrained(
        model_name=args.tts_model_spec)
    logging.info(f"Using NGC cloud TTS Vocoder model {args.tts_model_vocoder}")
    tts_model_vocoder = Vocoder.from_pretrained(
        model_name=args.tts_model_vocoder)

    # Move everything to GPU (when available) and switch to eval mode.
    models = [asr_model, tts_model_spec, tts_model_vocoder]
    if torch.cuda.is_available():
        for i, m in enumerate(models):
            models[i] = m.cuda()
    for m in models:
        m.eval()
    asr_model, tts_model_spec, tts_model_vocoder = models

    # Text parser used to build ASR reference transcripts (note: rebinds the
    # `parser` name that previously held the ArgumentParser).
    parser = parsers.make_parser(
        labels=asr_model.decoder.vocabulary,
        name="en",
        unk_id=-1,
        blank_id=-1,
        do_normalize=True,
    )
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])

    # Tokenize every test string for TTS and build its ASR reference string,
    # tracking the longest token sequence for padding.
    tts_input = []
    asr_references = []
    longest_tts_input = 0
    for test_str in LIST_OF_TEST_STRINGS:
        tts_parsed_input = tts_model_spec.parse(test_str)
        if len(tts_parsed_input[0]) > longest_tts_input:
            longest_tts_input = len(tts_parsed_input[0])
        tts_input.append(tts_parsed_input.squeeze())
        asr_parsed = parser(test_str)
        asr_parsed = ''.join([labels_map[c] for c in asr_parsed])
        asr_references.append(asr_parsed)

    # Pad TTS Inputs
    for i, text in enumerate(tts_input):
        pad = (0, longest_tts_input - len(text))
        # value=68 is presumably the pad token id of the TTS vocabulary
        # — TODO confirm against the spectrogram generator's tokenizer.
        tts_input[i] = torch.nn.functional.pad(text, pad, value=68)
    logging.debug(tts_input)

    # Do TTS
    tts_input = torch.stack(tts_input)
    if torch.cuda.is_available():
        tts_input = tts_input.cuda()
    specs = tts_model_spec.generate_spectrogram(tokens=tts_input)
    # Vocode the spectrograms in 4 chunks (presumably to bound memory usage),
    # then flatten the per-chunk lists back into one list of waveforms.
    audio = []
    step = ceil(len(specs) / 4)
    for i in range(4):
        audio.append(
            tts_model_vocoder.convert_spectrogram_to_audio(
                spec=specs[i * step:i * step + step]))
    audio = [item for sublist in audio for item in sublist]
    audio_file_paths = []

    # Save audio
    logging.debug(f"args.trim: {args.trim}")
    for i, aud in enumerate(audio):
        aud = aud.cpu().numpy()
        if args.trim:
            aud = librosa.effects.trim(aud, top_db=40)[0]
        # NOTE(review): librosa.output.write_wav was removed in librosa >= 0.8;
        # this requires librosa < 0.8 or a migration to soundfile.write.
        librosa.output.write_wav(f"{i}.wav", aud, sr=22050)
        audio_file_paths.append(str(Path(f"{i}.wav")))

    # Do ASR
    hypotheses = asr_model.transcribe(audio_file_paths)
    for i, _ in enumerate(hypotheses):
        logging.debug(f"{i}")
        logging.debug(f"ref:'{asr_references[i]}'")
        logging.debug(f"hyp:'{hypotheses[i]}'")
    wer_value = word_error_rate(hypotheses=hypotheses, references=asr_references)
    if wer_value > args.wer_tolerance:
        raise ValueError(
            f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}")
    logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
def main():
    """Transcribe a manifest with a CTC model and emit sclite trn files.

    Writes hyp.trn/ref.trn under --out_dir and, when an sctk installation is
    present at --sctk_dir, scores them with sclite.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    # NOTE(review): argparse `type=bool` turns ANY non-empty string (even
    # "False") into True — this flag cannot actually be disabled from the CLI.
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    # NOTE(review): --sclite_fmt is parsed but the score_with_sctk call below
    # hardcodes fmt="trn"; --ref_stm is parsed but never used here (and its
    # help text looks copy-pasted from --glm).
    parser.add_argument(
        "--sclite_fmt",
        default="trn",
        type=str,
        help="sclite output format. Only trn and ctm are supported")
    parser.add_argument("--out_dir", type=str, required=True, help="Destination dir for output files")
    parser.add_argument("--sctk_dir", type=str, required=False, default="", help="Path to sctk root dir")
    parser.add_argument("--glm", type=str, required=False, default="", help="Path to glm file")
    parser.add_argument("--ref_stm", type=str, required=False, default="", help="Path to glm file")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    # Only score with sctk if the given directory actually exists.
    use_sctk = os.path.exists(args.sctk_dir)

    # Load either a local .nemo checkpoint or a pretrained NGC model.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        })
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Map label indices back to characters to rebuild reference transcripts.
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    # NOTE(review): all_log_probs is filled but never consumed in this
    # function — possibly intended for ctm output; confirm before removing.
    all_log_probs = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        for r in log_probs.cpu().numpy():
            all_log_probs.append(r)
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([
                labels_map[c]
                for c in test_batch[2][batch_ind].cpu().detach().numpy()
            ])
            references.append(reference)
        del test_batch  # release GPU memory before the next batch

    # Write one utterance per line in sclite trn format: " text (utt_id)".
    info_list = get_utt_info(args.dataset)
    hypfile = os.path.join(args.out_dir, "hyp.trn")
    reffile = os.path.join(args.out_dir, "ref.trn")
    with open(hypfile, "w") as hyp_f, open(reffile, "w") as ref_f:
        for i in range(len(hypotheses)):
            utt_id = os.path.splitext(
                os.path.basename(info_list[i]['audio_filepath']))[0]
            # rfilter in sctk likes each transcript to have a space at the beginning
            hyp_f.write(" " + hypotheses[i] + " (" + utt_id + ")" + "\n")
            ref_f.write(" " + references[i] + " (" + utt_id + ")" + "\n")
    if use_sctk:
        score_with_sctk(args.sctk_dir, reffile, hypfile, args.out_dir, glm=args.glm, fmt="trn")
def main():
    """Evaluate a quantized CTC model; optionally run per-layer sensitivity
    analysis and/or export the model to ONNX with fake-quant nodes.

    Sensitivity analysis enables one quantized layer at a time, records the
    WER margin versus the tolerance, then re-enables everything and disables
    layers from most to least sensitive until the tolerance is met.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    # NOTE(review): --wer_target is parsed but never read in this function.
    parser.add_argument("--wer_target", type=float, default=None, help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    # NOTE(review): argparse `type=bool` turns ANY non-empty string (even
    # "False") into True — this flag cannot actually be disabled from the CLI.
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English."
    )
    parser.add_argument('--sensitivity', action="store_true", help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Monkey-patch supported torch layers with their quantized counterparts.
    quant_modules.initialize()

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        }
    )
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    # Baseline WER with all quantizers active.
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if args.sensitivity:
        if wer_quant < args.wer_tolerance:
            logging.info("Tolerance is already met. Skip sensitivity analyasis.")
            return
        # Disable every quantizer, collecting the unique layer names
        # (input/weight quantizers of a layer collapse to one name).
        quant_layer_names = []
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.disable()
                layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "")
                if layer_name not in quant_layer_names:
                    quant_layer_names.append(layer_name)
        logging.info(F"{len(quant_layer_names)} quantized layers found.")

        # Build sensitivity profile
        # Enable exactly one layer's quantizers at a time and record the
        # margin (tolerance - WER); more negative means more sensitive.
        quant_layer_sensitivity = {}
        for i, quant_layer in enumerate(quant_layer_names):
            logging.info(F"Enable {quant_layer}")
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.enable()
                    logging.info(F"{name:40}: {module}")
            # Eval the model
            wer_value = evaluate(asr_model, labels_map, wer)
            logging.info(F"WER: {wer_value}")
            quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.disable()
                    logging.info(F"{name:40}: {module}")

        # Skip most sensitive layers until WER target is met
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.enable()
        # Ascending sort puts the most sensitive (smallest margin) first.
        quant_layer_sensitivity = collections.OrderedDict(
            sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
        pprint(quant_layer_sensitivity)
        skipped_layers = []
        for quant_layer, _ in quant_layer_sensitivity.items():
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer):
                    if quant_layer in name:
                        logging.info(F"Disable {name}")
                        if not quant_layer in skipped_layers:
                            skipped_layers.append(quant_layer)
                        module.disable()
            wer_value = evaluate(asr_model, labels_map, wer)
            if wer_value <= args.wer_tolerance:
                logging.info(
                    F"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
                )
                print(skipped_layers)
                return
        raise ValueError(f"WER tolerance {args.wer_tolerance} can not be met with any layer quantized!")

    if args.onnx:
        if args.asr_model.endswith("nemo"):
            onnx_name = args.asr_model.replace(".nemo", ".onnx")
        else:
            onnx_name = args.asr_model
        # NOTE(review): logging.info is being called print-style here; the
        # second argument is treated as a %-format arg and the name is never
        # interpolated — should be logging.info("Export to %s", onnx_name).
        logging.info("Export to ", onnx_name)
        # Temporarily switch quantizers to fake-quant ops so they export.
        quant_nn.TensorQuantizer.use_fb_fake_quant = True
        asr_model.export(onnx_name, onnx_opset_version=13)
        quant_nn.TensorQuantizer.use_fb_fake_quant = False
def main():
    """Calibrate a quantized CTC model's activation ranges and save checkpoints.

    Runs a few test batches through the model with quantizers in calibration
    mode, then computes amax values and saves one .nemo file per calibration
    method ("max", or per-percentile plus mse/entropy for "histogram").
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=256)
    # NOTE(review): argparse `type=bool` turns ANY non-empty string (even
    # "False") into True — this flag cannot actually be disabled from the CLI.
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    parser.add_argument('--num_calib_batch', default=1, type=int, help="Number of batches for calibration.")
    parser.add_argument('--calibrator', type=str, choices=["max", "histogram"], default="max")
    parser.add_argument('--percentile', nargs='+', type=float, default=[99.9, 99.99, 99.999, 99.9999])
    parser.add_argument("--amp", action="store_true", help="Use AMP in calibration.")
    parser.set_defaults(amp=False)
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Initialize quantization
    # Apply the chosen calibration method to the default input quantizers.
    quant_desc_input = QuantDescriptor(calib_method=args.calibrator)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(
        quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    # Load the config first, flip the encoder into quantized mode, then load
    # the weights with that overridden config.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
            'shuffle': True,
        })
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Enable calibrators
    # Quantizers with a calibrator collect statistics instead of quantizing;
    # quantizers without one are switched off entirely.
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed calibration batches through the model (outputs are discarded).
    for i, test_batch in enumerate(asr_model.test_dataloader()):
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        if args.amp:
            with autocast():
                _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
        else:
            _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
        # NOTE(review): this break fires AFTER processing batch i, so
        # num_calib_batch+1 batches are consumed (off-by-one) — confirm intent.
        if i >= args.num_calib_batch:
            break

    # Save calibrated model(s)
    model_name = args.asr_model.replace(
        ".nemo", "") if args.asr_model.endswith(".nemo") else args.asr_model
    if not args.calibrator == "histogram":
        compute_amax(asr_model, method="max")
        asr_model.save_to(
            F"{model_name}-max-{args.num_calib_batch*args.batch_size}.nemo")
    else:
        for percentile in args.percentile:
            print(F"{percentile} percentile calibration")
            # NOTE(review): the loop's `percentile` value is not passed to
            # compute_amax — presumably it should be; verify compute_amax's
            # signature before relying on per-percentile checkpoints.
            compute_amax(asr_model, method="percentile")
            asr_model.save_to(
                F"{model_name}-percentile-{percentile}-{args.num_calib_batch*args.batch_size}.nemo"
            )
        for method in ["mse", "entropy"]:
            print(F"{method} calibration")
            compute_amax(asr_model, method=method)
            asr_model.save_to(
                F"{model_name}-{method}-{args.num_calib_batch*args.batch_size}.nemo"
            )
def main():
    """Integer-quantization evaluation: calibrate on synthetic (distilled) data
    if static quantization is used, then measure WER on the test set.
    """
    parser = ArgumentParser()
    """Training arguments"""
    parser.add_argument("--asr_model",
                        type=str,
                        default="QuartzNet15x5Base-En",
                        required=True,
                        help="Pass: '******'")
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=8)
    # NOTE(review): argparse `type=bool` turns ANY non-empty string (even
    # "False") into True — this flag cannot actually be disabled from the CLI.
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    parser.add_argument("--shuffle", action='store_true', help="Shuffle test data.")
    """Calibration arguments"""
    parser.add_argument("--load", type=str, default=None, help="load path for the synthetic data")
    parser.add_argument(
        "--percentile",
        type=float,
        default=None,
        help="Max/min percentile for outlier handling. e.g., 99.9")
    """Quantization arguments"""
    parser.add_argument("--weight_bit", type=int, default=8, help="quantization bit for weights")
    parser.add_argument("--act_bit", type=int, default=8, help="quantization bit for activations")
    parser.add_argument("--dynamic", action='store_true', help="Dynamic quantization mode.")
    parser.add_argument("--no_quant", action='store_true', help="No quantization mode.")
    """Debugging arguments"""
    parser.add_argument("--eval_early_stop", type=int, default=None, help="early stop for debugging")
    parser.add_argument("--calib_early_stop", type=int, default=None, help="early stop calibration")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Load either a local .nemo checkpoint or a pretrained NGC model.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
            'shuffle': args.shuffle,
        })

    # Synthetic (distilled) calibration data is mandatory for static
    # quantization; only dynamic quantization can run without it.
    if args.load is not None:
        print('Data loaded from %s' % args.load)
        with open(args.load, 'rb') as f:
            distilled_data = pickle.load(f)
            synthetic_batch_size, _, synthetic_seqlen = distilled_data[0].shape
    else:
        assert args.dynamic, \
            "synthetic data must be loaded unless running with the dynamic quantization mode"

    ############################## Calibration #####################################
    torch.set_grad_enabled(False)  # disable backward graph generation
    asr_model.eval()  # evaluation mode
    asr_model.set_quant_bit(args.weight_bit, mode='weight')
    asr_model.set_quant_bit(args.act_bit, mode='act')

    # set percentile
    if args.percentile is not None:
        qm.set_percentile(asr_model, args.percentile)

    if args.no_quant:
        asr_model.set_quant_mode('none')
    else:
        asr_model.encoder.bn_folding()  # BN folding

    # if not dynamic quantization, calibrate min/max/range for the activations using synthetic data
    # if dynamic, we can skip calibration
    if not args.dynamic:
        print('Calibrating...')
        qm.calibrate(asr_model)
        # All synthetic batches share one fixed sequence length.
        length = torch.tensor([synthetic_seqlen] * synthetic_batch_size).cuda()
        for batch_idx, inputs in enumerate(distilled_data):
            if args.calib_early_stop is not None and batch_idx == args.calib_early_stop:
                break
            inputs = inputs.cuda()
            # Forward passes only populate the calibration statistics;
            # the outputs themselves are discarded.
            encoded, encoded_len, encoded_scaling_factor = asr_model.encoder(
                audio_signal=inputs, length=length)
            log_probs = asr_model.decoder(
                encoder_output=encoded,
                encoder_output_scaling_factor=encoded_scaling_factor)

    ############################## Evaluation #####################################
    print('Evaluating...')
    qm.evaluate(asr_model)
    qm.set_dynamic(
        asr_model, args.dynamic)  # if dynamic quantization, this will be enabled

    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    progress_bar = tqdm(asr_model.test_dataloader())
    for i, test_batch in enumerate(progress_bar):
        if i == args.eval_early_stop:
            break
        test_batch = [x.cuda().float() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([
                labels_map[c]
                for c in test_batch[2][batch_ind].cpu().detach().numpy()
            ])
            references.append(reference)
        del test_batch  # release GPU memory before the next batch
    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    print('WER:', wer_value)
def main():
    """Evaluate a quantized CTC model via its ONNX export and report WER/CER."""
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument(
        "--asr_onnx",
        type=str,
        default="./QuartzNet15x5Base-En-max-32.onnx",
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    # BUG FIX: this flag previously used action='store_false' with
    # default=False and was passed straight through as
    # normalize_transcripts, so normalization was ALWAYS off regardless of
    # the flag. With store_true + negation below, normalization is on by
    # default and the flag genuinely disables it.
    parser.add_argument(
        "--dont_normalize_text",
        default=False,
        action='store_true',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer",
        default=False,
        action='store_true',
        help="Use Character Error Rate as the evaluation metric")
    parser.add_argument('--qat', action="store_true", help="Use onnx file exported from QAT tools")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Load the config first, flip the encoder into quantized mode, then load
    # the weights with that overridden config.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': not args.dont_normalize_text,
        })
    # Deterministic features for evaluation: no dither, no padding.
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary, use_cer=args.use_cer)
    wer_result = evaluate(asr_model, args.asr_onnx, labels_map, wer, args.qat)
    logging.info(f'Got WER of {wer_result}.')
def main():
    """Generate synthetic calibration data by distilling from a teacher ASR
    model, then pickle it (optionally under --dump_path)."""
    parser = ArgumentParser()
    parser.add_argument("--asr_model",
                        type=str,
                        default="QuartzNet15x5Base-En",
                        required=True,
                        help="Pass: '******'")
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--num_batch", type=int, default=50, help="number of batches of the synthetic data")
    parser.add_argument("--batch_size", type=int, default=8, help="batch size of the synthetic data")
    parser.add_argument("--seqlen", type=int, default=500, help="sequence length of the synthetic data")
    parser.add_argument(
        "--train_iter",
        type=int,
        default=200,
        help="training iterations for the synthetic data generation")
    parser.add_argument("--dump_path", type=str, default=None, help="path to dump the synthetic data")
    parser.add_argument(
        "--dump_prefix",
        type=str,
        default='syn',
        help="prefix for the filename of the dumped synthetic data")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate for the synthetic data generation")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Load the teacher: either a local .nemo checkpoint or an NGC model.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        teacher_model = EncDecCTCModel.restore_from(
            restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        teacher_model = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model)

    teacher_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': teacher_model.decoder.vocabulary,
            # NOTE(review): hardcoded 8 — presumably should be
            # args.batch_size like the distillation call below; confirm.
            'batch_size': 8,
            'normalize_transcripts': True,
            'shuffle': True,
        })

    ############################## Distillation #####################################
    teacher_model.set_quant_mode(
        'none')  # disable quantization mode for the teacher model
    torch.set_grad_enabled(True)  # enable backward graph generation
    print("Num batches: %d, Batch size: %d, Training iterations: %d, Learning rate: %.3f " \
        % (args.num_batch, args.batch_size, args.train_iter, args.lr))
    print('Synthesizing...')
    # dim=64 is the fixed feature dimension fed to the encoder — presumably
    # the teacher's mel-feature count; TODO confirm against its preprocessor.
    synthetic_data = get_synthetic_data(teacher_model.encoder,
                                        teacher_model.decoder,
                                        batch_size=args.batch_size,
                                        dim=64,
                                        seqlen=args.seqlen,
                                        num_batch=args.num_batch,
                                        train_iter=args.train_iter,
                                        lr=args.lr)

    # Encode the generation settings into the dump filename.
    file_name = '%s_nb%d_iter%d_lr%.3f.pkl' % \
        (args.dump_prefix, args.num_batch, args.train_iter, args.lr)
    if args.dump_path is not None:
        if not os.path.exists(args.dump_path):
            os.makedirs(args.dump_path)
        file_name = os.path.join(args.dump_path, file_name)
    print('Synthetic data dumped as ', file_name)
    # Move tensors to CPU before pickling so the dump loads without a GPU.
    with open(file_name, 'wb') as f:
        pickle.dump([x.cpu() for x in synthetic_data], f)