def main():
    """Calibrate a quantization-enabled EncDecCTCModelBPE ASR model.

    Loads the model (local ``.nemo`` file or NGC pretrained), runs a few
    batches through it with the quantizers in calibration mode, then computes
    amax values and saves one calibrated ``.nemo`` checkpoint per calibration
    method (max, or percentile/mse/entropy for the histogram calibrator).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=256)
    # BUG FIX: the original passed default=False together with action='store_false',
    # so args.dont_normalize_text was False whether or not the flag was given and
    # transcript normalization could never be enabled. With store_false the implicit
    # default is True (normalize); passing the flag turns normalization off, which
    # matches the option name and help text.
    parser.add_argument(
        "--dont_normalize_text",
        action='store_false',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument('--num_calib_batch', default=1, type=int, help="Number of batches for calibration.")
    parser.add_argument('--calibrator', type=str, choices=["max", "histogram"], default="max")
    parser.add_argument('--percentile', nargs='+', type=float, default=[99.9, 99.99, 99.999, 99.9999])
    parser.add_argument("--amp", action="store_true", help="Use AMP in calibration.")
    parser.set_defaults(amp=False)
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Initialize quantization: every quantized layer type uses the chosen calibrator
    # for its input quantizer.
    quant_desc_input = QuantDescriptor(calib_method=args.calibrator)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    # Restore the model with encoder.quantize=True so quantized modules are built.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.restore_from(restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg
        )
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg
        )

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.dont_normalize_text,
            'shuffle': True,
        }
    )
    # Disable dithering/padding so calibration statistics are deterministic.
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Enable calibrators: gather statistics without actually quantizing;
    # quantizers that have no calibrator are disabled entirely.
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, test_batch in enumerate(asr_model.test_dataloader()):
        # BUG FIX: the original checked `i >= num_calib_batch` *after* running the
        # batch, so num_calib_batch + 1 batches were consumed. Check first so exactly
        # num_calib_batch batches feed the calibrators (matching the saved filenames,
        # which advertise num_calib_batch * batch_size samples).
        if i >= args.num_calib_batch:
            break
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        if args.amp:
            with autocast():
                _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
        else:
            _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])

    # Save calibrated model(s), one checkpoint per calibration method.
    model_name = args.asr_model.replace(".nemo", "") if args.asr_model.endswith(".nemo") else args.asr_model
    if args.calibrator != "histogram":
        compute_amax(asr_model, method="max")
        asr_model.save_to(f"{model_name}-max-{args.num_calib_batch * args.batch_size}.nemo")
    else:
        for percentile in args.percentile:
            print(f"{percentile} percentile calibration")
            # BUG FIX: forward the requested percentile. The original called
            # compute_amax(method="percentile") with no percentile argument, so the
            # calibrator's default (99.99) was used every iteration and all saved
            # "percentile" checkpoints were identical.
            compute_amax(asr_model, method="percentile", percentile=percentile)
            asr_model.save_to(
                f"{model_name}-percentile-{percentile}-{args.num_calib_batch * args.batch_size}.nemo"
            )
        for method in ["mse", "entropy"]:
            print(f"{method} calibration")
            compute_amax(asr_model, method=method)
            asr_model.save_to(
                f"{model_name}-{method}-{args.num_calib_batch * args.batch_size}.nemo"
            )
def main():
    """Evaluate a quantized EncDecCTCModelBPE model's WER.

    Optionally performs a per-layer quantization sensitivity analysis: it
    measures WER with each quantized layer enabled in isolation, then disables
    the most sensitive layers one by one until the WER tolerance is met,
    finally exporting the resulting model to ONNX.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--wer_target", type=float, default=None, help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    # BUG FIX: the original passed default=False together with action='store_false',
    # so normalize_transcripts was always False and the flag was a no-op. With
    # store_false the implicit default is True (normalize); passing the flag turns
    # normalization off, matching the option name and help text.
    parser.add_argument(
        "--dont_normalize_text",
        action='store_false',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer",
        default=False,
        action='store_true',
        help="Use Character Error Rate as the evaluation metric",
    )
    parser.add_argument('--sensitivity', action="store_true", help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    parser.add_argument('--quant-disable-keyword', type=str, nargs='+', help='disable quantizers by keyword')
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Monkey-patch all supported layers with their quantized counterparts.
    quant_modules.initialize()

    # Restore the model with encoder.quantize=True so quantized modules are built.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.restore_from(restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg
        )
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg
        )

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.dont_normalize_text,
        }
    )
    # Disable dithering/padding so evaluation is deterministic.
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Optionally disable quantizers whose module path matches any keyword.
    if args.quant_disable_keyword:
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                for keyword in args.quant_disable_keyword:
                    if keyword in name:
                        logging.warning(f"Disable {name}")
                        module.disable()

    labels_map = dict(enumerate(asr_model.decoder.vocabulary))
    wer = WER(vocabulary=asr_model.decoder.vocabulary, use_cer=args.use_cer)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if args.sensitivity:
        if wer_quant < args.wer_tolerance:
            logging.info("Tolerance is already met. Skip sensitivity analysis.")
            return

        # Collect the distinct quantized layers (each layer owns an input and a
        # weight quantizer; strip those suffixes to de-duplicate) and disable all
        # quantizers so layers can be measured one at a time.
        quant_layer_names = []
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.disable()
                layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "")
                if layer_name not in quant_layer_names:
                    quant_layer_names.append(layer_name)
        logging.info(f"{len(quant_layer_names)} quantized layers found.")

        # Build sensitivity profile: WER margin with only this layer quantized.
        quant_layer_sensitivity = {}
        for i, quant_layer in enumerate(quant_layer_names):
            logging.info(f"Enable {quant_layer}")
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.enable()
                    logging.info(f"{name:40}: {module}")

            # Eval the model with a single layer quantized; a smaller (more
            # negative) margin means the layer is more sensitive.
            wer_value = evaluate(asr_model, labels_map, wer)
            logging.info(f"WER: {wer_value}")
            quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value

            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.disable()
                    logging.info(f"{name:40}: {module}")

        # Skip most sensitive layers until WER target is met: re-enable
        # everything, then disable layers from most to least sensitive.
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.enable()
        quant_layer_sensitivity = collections.OrderedDict(
            sorted(quant_layer_sensitivity.items(), key=lambda x: x[1])
        )
        pprint(quant_layer_sensitivity)

        skipped_layers = []
        for quant_layer, _ in quant_layer_sensitivity.items():
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    logging.info(f"Disable {name}")
                    if quant_layer not in skipped_layers:
                        skipped_layers.append(quant_layer)
                    module.disable()
            wer_value = evaluate(asr_model, labels_map, wer)
            if wer_value <= args.wer_tolerance:
                logging.info(
                    f"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
                )
                print(skipped_layers)
                export_onnx(args, asr_model)
                return
        raise ValueError(
            f"WER tolerance {args.wer_tolerance} can not be met with any layer quantized!"
        )

    export_onnx(args, asr_model)