def set_asr_model(self, asr_model):
    """
    Set up the parameters for the given ASR model.
    Currently, the following models are supported:
        stt_en_conformer_ctc_large
        stt_en_conformer_ctc_medium
        stt_en_conformer_ctc_small
        QuartzNet15x5Base-En
    """
    if 'QuartzNet' in asr_model:
        self.run_ASR = self.run_ASR_QuartzNet_CTC
        asr_model = EncDecCTCModel.from_pretrained(model_name=asr_model, strict=False)
        self.params['offset'] = -0.18
        self.model_stride_in_secs = 0.02
        self.asr_delay_sec = -1 * self.params['offset']

    elif 'conformer_ctc' in asr_model:
        self.run_ASR = self.run_ASR_BPE_CTC
        asr_model = EncDecCTCModelBPE.from_pretrained(model_name=asr_model, strict=False)
        self.model_stride_in_secs = 0.04
        self.asr_delay_sec = 0.0
        self.params['offset'] = 0
        self.chunk_len_in_sec = 1.6
        self.total_buffer_in_secs = 4

    elif 'citrinet' in asr_model:
        self.run_ASR = self.run_ASR_BPE_CTC
        asr_model = EncDecCTCModelBPE.from_pretrained(model_name=asr_model, strict=False)
        self.model_stride_in_secs = 0.08
        self.asr_delay_sec = 0.0
        self.params['offset'] = 0
        self.chunk_len_in_sec = 1.6
        self.total_buffer_in_secs = 4

    elif 'conformer_transducer' in asr_model or 'contextnet' in asr_model:
        self.run_ASR = self.run_ASR_BPE_RNNT
        asr_model = EncDecRNNTBPEModel.from_pretrained(model_name=asr_model, strict=False)
        self.model_stride_in_secs = 0.04
        self.asr_delay_sec = 0.0
        self.params['offset'] = 0
        self.chunk_len_in_sec = 1.6
        self.total_buffer_in_secs = 4

    else:
        raise ValueError(f"ASR model name not found: {asr_model}")

    self.params['time_stride'] = self.model_stride_in_secs
    self.asr_batch_size = 16
    asr_model.eval()

    self.audio_file_list = [value['audio_filepath'] for _, value in self.AUDIO_RTTM_MAP.items()]

    return asr_model
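# The chunk_len_in_sec / total_buffer_in_secs / model_stride_in_secs values set
# above are consumed by the buffered (streaming) decoding path. A minimal sketch
# of how such values are typically combined is shown below; the helper name
# buffered_decoding_params is hypothetical and only illustrates the arithmetic.
import math


def buffered_decoding_params(chunk_len_in_sec=1.6, total_buffer_in_secs=4.0, model_stride_in_secs=0.04):
    # Number of output tokens the model produces per audio chunk.
    tokens_per_chunk = math.ceil(chunk_len_in_sec / model_stride_in_secs)
    # Delay (in tokens) to the middle of the overlapping buffer, used to select
    # which tokens of the buffer belong to the current chunk.
    mid_delay = math.ceil(
        (chunk_len_in_sec + (total_buffer_in_secs - chunk_len_in_sec) / 2) / model_stride_in_secs
    )
    return tokens_per_chunk, mid_delay


# Example: a Conformer-CTC model with a 0.04 s stride yields 40 tokens per
# 1.6 s chunk and a mid-buffer delay of 70 tokens.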
def test_EncDecCTCModelBPE_v2(self):
    # TODO: Switch to using named configs because here we don't really care about weights
    cn = EncDecCTCModelBPE.from_pretrained(model_name="stt_en_conformer_ctc_small")
    self.__test_restore_elsewhere(
        model=cn, attr_for_eq_check=set(["decoder._feat_in", "decoder._num_classes"])
    )
def test_EncDecCTCModelBPE(self):
    # TODO: Switch to using named configs because here we don't really care about weights
    cn = EncDecCTCModelBPE.from_pretrained(model_name="ContextNet-192-WPE-1024-8x-Stride")
    self.__test_restore_elsewhere(
        model=cn, attr_for_eq_check=set(["decoder._feat_in", "decoder._num_classes"])
    )
def main(cfg):
    if cfg.n_gpus > 0:
        cfg.model.train_ds.batch_size //= cfg.n_gpus

    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg, resolve=True)}')

    pl.utilities.seed.seed_everything(cfg.seed)
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))

    if "tokenizer" in cfg.model:
        asr_model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)
    else:
        asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config
    asr_model.maybe_init_from_pretrained_checkpoint(cfg)

    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        gpu = 1 if cfg.trainer.gpus != 0 else 0
        test_trainer = pl.Trainer(
            gpus=gpu,
            precision=trainer.precision,
            amp_level=trainer.accelerator_connector.amp_level,
            amp_backend=cfg.trainer.get("amp_backend", "native"),
        )
        if asr_model.prepare_test(test_trainer):
            test_trainer.test(asr_model)
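# For reference, the config fields this entry point actually reads are roughly
# the following. This is a hypothetical, minimal example built only from the
# keys used above (real NeMo training configs carry many more entries, and the
# tokenizer dir/type values shown here are assumptions):
from omegaconf import OmegaConf

example_cfg = OmegaConf.create({
    "n_gpus": 1,
    "seed": 42,
    "trainer": {"gpus": 1, "max_epochs": 1, "precision": 32},
    "exp_manager": {"name": "ctc_bpe_example"},
    "model": {
        # Presence of a "tokenizer" section selects EncDecCTCModelBPE above.
        "tokenizer": {"dir": "tokenizer_dir", "type": "bpe"},
        "train_ds": {"manifest_filepath": "train_manifest.json", "batch_size": 32},
        "test_ds": {"manifest_filepath": None},
    },
})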
def set_asr_model(self, ASR_model_name):
    if 'QuartzNet' in ASR_model_name:
        self.run_ASR = self.run_ASR_QuartzNet_CTC
        asr_model = EncDecCTCModel.from_pretrained(model_name=ASR_model_name, strict=False)
    elif 'conformer' in ASR_model_name:
        self.run_ASR = self.run_ASR_Conformer_CTC
        _ = EncDecCTCModelBPE.from_pretrained(model_name=ASR_model_name, strict=False)
        # This option has not been implemented yet.
        raise NotImplementedError
    elif 'citrinet' in ASR_model_name:
        raise NotImplementedError
    else:
        raise ValueError(f"ASR model name not found: {ASR_model_name}")
    return asr_model
def main(cfg):
    name_prefix, checkpoint_paths, save_ckpt_only = process_config(cfg)

    if not save_ckpt_only:
        trainer = pl.Trainer(**cfg.trainer)

        # <<< Change model class here ! >>>
        # Model architecture which will contain the averaged checkpoints
        # Change the model constructor to the one you would like (if needed)
        model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)

    """ < Checkpoint Averaging Logic > """

    # load state dicts
    n = len(checkpoint_paths)
    avg_state = None

    logging.info(f"Averaging {n} checkpoints ...")

    for ix, path in enumerate(checkpoint_paths):
        checkpoint = torch.load(path, map_location='cpu')

        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']

        if ix == 0:
            # Initial state
            avg_state = checkpoint
            logging.info(f"Initialized average state dict with checkpoint : {path}")
        else:
            # Accumulated state
            for k in avg_state:
                avg_state[k] = avg_state[k] + checkpoint[k]

            logging.info(f"Updated average state dict with state from checkpoint : {path}")

    for k in avg_state:
        if str(avg_state[k].dtype).startswith("torch.int"):
            # For int type, not averaged, but only accumulated.
            # e.g. BatchNorm.num_batches_tracked
            pass
        else:
            avg_state[k] = avg_state[k] / n

    # Save model
    if save_ckpt_only:
        ckpt_name = name_prefix + '-averaged.ckpt'
        torch.save(avg_state, ckpt_name)

        logging.info(f"Averaged pytorch checkpoint saved as : {ckpt_name}")
    else:
        # Set model state
        logging.info("Loading averaged state dict in provided model")
        model.load_state_dict(avg_state, strict=True)

        ckpt_name = name_prefix + '-averaged.nemo'
        model.save_to(ckpt_name)

        logging.info(f"Averaged model saved as : {ckpt_name}")
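# Once the averaged .nemo file has been written, it can be restored like any
# other NeMo checkpoint. The file name below only follows the pattern produced
# above, assuming name_prefix happened to be "my_model":
averaged_model = EncDecCTCModelBPE.restore_from("my_model-averaged.nemo")
averaged_model.eval()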
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument(
        "--dont_normalize_text",
        default=False,
        action='store_false',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument('--num_calib_batch', default=1, type=int, help="Number of batches for calibration.")
    parser.add_argument('--calibrator', type=str, choices=["max", "histogram"], default="max")
    parser.add_argument('--percentile', nargs='+', type=float, default=[99.9, 99.99, 99.999, 99.9999])
    parser.add_argument("--amp", action="store_true", help="Use AMP in calibration.")
    parser.set_defaults(amp=False)

    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Initialize quantization
    quant_desc_input = QuantDescriptor(calib_method=args.calibrator)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.restore_from(restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.restore_from(restore_path=args.asr_model, override_config_path=asr_model_cfg)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model, override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.dont_normalize_text,
            'shuffle': True,
        }
    )
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Enable calibrators
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, test_batch in enumerate(asr_model.test_dataloader()):
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        if args.amp:
            with autocast():
                _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
        else:
            _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
        if i >= args.num_calib_batch:
            break

    # Save calibrated model(s)
    model_name = args.asr_model.replace(".nemo", "") if args.asr_model.endswith(".nemo") else args.asr_model
    if not args.calibrator == "histogram":
        compute_amax(asr_model, method="max")
        asr_model.save_to(F"{model_name}-max-{args.num_calib_batch*args.batch_size}.nemo")
    else:
        for percentile in args.percentile:
            print(F"{percentile} percentile calibration")
            compute_amax(asr_model, method="percentile")
            asr_model.save_to(
                F"{model_name}-percentile-{percentile}-{args.num_calib_batch*args.batch_size}.nemo"
            )

        for method in ["mse", "entropy"]:
            print(F"{method} calibration")
            compute_amax(asr_model, method=method)
            asr_model.save_to(F"{model_name}-{method}-{args.num_calib_batch*args.batch_size}.nemo")
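# compute_amax() is called above but not defined in this snippet. A minimal
# sketch of the usual pytorch-quantization pattern such a helper follows is
# shown below; this is an assumption about its implementation, not necessarily
# the exact helper used by the script:
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn


def compute_amax(model, **kwargs):
    # Load the collected calibration statistics (amax) into every TensorQuantizer.
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    # Histogram-based calibrators accept method/percentile kwargs,
                    # e.g. method="percentile", "mse" or "entropy".
                    module.load_calib_amax(**kwargs)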
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--wer_target", type=float, default=None, help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument(
        "--dont_normalize_text",
        default=False,
        action='store_false',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer", default=False, action='store_true', help="Use Character Error Rate as the evaluation metric"
    )
    parser.add_argument('--sensitivity', action="store_true", help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    parser.add_argument('--quant-disable-keyword', type=str, nargs='+', help='disable quantizers by keyword')

    args = parser.parse_args()
    torch.set_grad_enabled(False)
    quant_modules.initialize()

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.restore_from(restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.restore_from(restore_path=args.asr_model, override_config_path=asr_model_cfg)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model, override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.dont_normalize_text,
        }
    )
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    if args.quant_disable_keyword:
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                for keyword in args.quant_disable_keyword:
                    if keyword in name:
                        logging.warning(F"Disable {name}")
                        module.disable()

    labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary, use_cer=args.use_cer)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if args.sensitivity:
        if wer_quant < args.wer_tolerance:
            logging.info("Tolerance is already met. Skip sensitivity analysis.")
            return

        quant_layer_names = []
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.disable()
                layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "")
                if layer_name not in quant_layer_names:
                    quant_layer_names.append(layer_name)
        logging.info(F"{len(quant_layer_names)} quantized layers found.")

        # Build sensitivity profile
        quant_layer_sensitivity = {}
        for i, quant_layer in enumerate(quant_layer_names):
            logging.info(F"Enable {quant_layer}")
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.enable()
                    logging.info(F"{name:40}: {module}")

            # Eval the model
            wer_value = evaluate(asr_model, labels_map, wer)
            logging.info(F"WER: {wer_value}")
            quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value

            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.disable()
                    logging.info(F"{name:40}: {module}")

        # Skip most sensitive layers until WER target is met
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.enable()
        quant_layer_sensitivity = collections.OrderedDict(sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
        pprint(quant_layer_sensitivity)

        skipped_layers = []
        for quant_layer, _ in quant_layer_sensitivity.items():
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer):
                    if quant_layer in name:
                        logging.info(F"Disable {name}")
                        if quant_layer not in skipped_layers:
                            skipped_layers.append(quant_layer)
                        module.disable()
            wer_value = evaluate(asr_model, labels_map, wer)
            if wer_value <= args.wer_tolerance:
                logging.info(
                    F"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
                )
                print(skipped_layers)
                export_onnx(args, asr_model)
                return
        raise ValueError(f"WER tolerance {args.wer_tolerance} cannot be met with any layer quantized!")

    export_onnx(args, asr_model)
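# evaluate() and export_onnx() are called above but not defined in this snippet.
# Below is a minimal sketch of what an evaluate() helper for this script could
# look like; it is a hypothetical reconstruction, and the decoding/WER calls are
# assumptions based on the NeMo 1.x CTC/WER API:
from nemo.collections.asr.metrics.wer import word_error_rate


def evaluate(asr_model, labels_map, wer):
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        # Greedy CTC predictions from the (possibly quantized) model.
        log_probs, encoded_len, greedy_predictions = asr_model(
            input_signal=test_batch[0], input_signal_length=test_batch[1]
        )
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        # Reconstruct the reference transcripts from the target token ids.
        for batch_ind in range(greedy_predictions.shape[0]):
            seq_len = test_batch[3][batch_ind].cpu().detach().numpy()
            seq_ids = test_batch[2][batch_ind].cpu().detach().numpy()
            references.append(''.join([labels_map[c] for c in seq_ids[0:seq_len]]))
    return word_error_rate(hypotheses=hypotheses, references=references, use_cer=wer.use_cer)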