def test_initialize_deactivate(self):
    no_replace_list = ["Linear"]
    custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]

    quant_modules.initialize(no_replace_list, custom_quant_modules)

    assert type(quant_nn.QuantLinear(16, 256, 3)) == type(torch.nn.Linear(16, 256, 3))
    assert type(quant_nn.QuantConv2d(16, 256, 3)) == type(torch.nn.Conv2d(16, 256, 3))

    quant_modules.deactivate()
def test_quant_module_replacement(self):
    """test monkey patching of modules with their quantized versions"""
    lenet = LeNet()
    qlenet = QuantLeNet()
    mod_list = [type(mod) for name, mod in lenet.named_modules()]
    mod_list = mod_list[1:]
    qmod_list = [type(mod) for name, mod in qlenet.named_modules()]
    qmod_list = qmod_list[1:]

    # Before any monkey patching, the networks should be different
    assert mod_list != qmod_list

    # Monkey patch the modules
    no_replace_list = ["Linear"]
    custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]
    quant_modules.initialize(no_replace_list, custom_quant_modules)

    lenet = LeNet()
    qlenet = QuantLeNet()
    mod_list = [type(mod) for name, mod in lenet.named_modules()]
    mod_list = mod_list[1:]
    qmod_list = [type(mod) for name, mod in qlenet.named_modules()]
    qmod_list = qmod_list[1:]

    # After monkey patching, the networks should be the same
    assert mod_list == qmod_list

    # Reverse monkey patching
    quant_modules.deactivate()

    lenet = LeNet()
    qlenet = QuantLeNet()
    mod_list = [type(mod) for name, mod in lenet.named_modules()]
    mod_list = mod_list[1:]
    qmod_list = [type(mod) for name, mod in qlenet.named_modules()]
    qmod_list = qmod_list[1:]

    # After reversing monkey patching, the networks should again be different
    assert mod_list != qmod_list
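# The tests here rely on LeNet and QuantLeNet fixtures defined elsewhere. A minimal
# sketch of what such fixtures might look like (layer sizes and attribute names are
# illustrative assumptions, not the originals): the two networks must mirror each
# other module-for-module, differing only in torch.nn vs. quant_nn layer types.
import torch.nn as nn
import torch.nn.functional as F
from pytorch_quantization import quant_nn


class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


class QuantLeNet(LeNet):
    # Same structure and module names, but quantized layer types.
    def __init__(self):
        super().__init__()
        self.conv1 = quant_nn.QuantConv2d(1, 10, kernel_size=5)
        self.conv2 = quant_nn.QuantConv2d(10, 20, kernel_size=5)
        self.fc1 = quant_nn.QuantLinear(320, 50)
        self.fc2 = quant_nn.QuantLinear(50, 10)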
def test_asp(self):
    """test Sparsity (ASP) and QAT toolkits together"""
    try:
        from apex.contrib.sparsity import ASP
    except ImportError:
        pytest.skip("ASP is not available.")

    quant_modules.initialize()
    model = LeNet()
    quant_modules.deactivate()

    optimizer = optim.SGD(model.parameters(), lr=0.01)

    ASP.init_model_for_pruning(
        model,
        mask_calculator="m4n2_1d",
        verbosity=2,
        whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d, quant_nn.modules.quant_linear.QuantLinear],
        allow_recompute_mask=False,
        custom_layer_dict={
            quant_nn.QuantConv1d: ['weight'],
            quant_nn.QuantConv2d: ['weight'],
            quant_nn.QuantConv3d: ['weight'],
            quant_nn.QuantConvTranspose1d: ['weight'],
            quant_nn.QuantConvTranspose2d: ['weight'],
            quant_nn.QuantConvTranspose3d: ['weight'],
            quant_nn.QuantLinear: ['weight']
        })
    ASP.init_optimizer_for_pruning(optimizer)

    ASP.compute_sparse_masks()

    model = model.to('cuda')
    output = model(torch.empty(16, 1, 28, 28).to('cuda'))
    optimizer.zero_grad()
    # The target must live on the same device as the model output.
    loss = F.nll_loss(output, torch.randint(10, (16,), dtype=torch.int64).to('cuda'))
    loss.backward()
    optimizer.step()
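# Aside: "m4n2_1d" selects ASP's 2:4 structured-sparsity mask calculator, i.e. in
# every group of m=4 consecutive weights at most n=2 remain nonzero. A hypothetical
# spot check that could be appended to the end of test_asp (not part of the original;
# assumes the LeNet fixture has an `fc1` linear layer):
w = model.fc1.weight.detach().cpu().reshape(-1, 4)
assert ((w != 0).sum(dim=1) <= 2).all()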
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--wer_target", type=float, default=None, help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument(
        "--normalize_text",
        default=True,
        # `type=bool` would treat any non-empty string (including "False") as True,
        # so parse the flag value explicitly.
        type=lambda s: s.lower() not in ("false", "0", "no"),
        help="Normalize transcripts or not. Set to False for non-English.",
    )
    parser.add_argument('--sensitivity', action="store_true", help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    quant_modules.initialize()

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        }
    )
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if args.sensitivity:
        if wer_quant < args.wer_tolerance:
            logging.info("Tolerance is already met. Skipping sensitivity analysis.")
            return
        quant_layer_names = []
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.disable()
                layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "")
                if layer_name not in quant_layer_names:
                    quant_layer_names.append(layer_name)
        logging.info(f"{len(quant_layer_names)} quantized layers found.")

        # Build sensitivity profile
        quant_layer_sensitivity = {}
        for i, quant_layer in enumerate(quant_layer_names):
            logging.info(f"Enable {quant_layer}")
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.enable()
                    logging.info(f"{name:40}: {module}")

            # Eval the model
            wer_value = evaluate(asr_model, labels_map, wer)
            logging.info(f"WER: {wer_value}")
            quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value

            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.disable()
                    logging.info(f"{name:40}: {module}")

        # Skip most sensitive layers until WER target is met
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.enable()
        quant_layer_sensitivity = collections.OrderedDict(sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
        pprint(quant_layer_sensitivity)

        skipped_layers = []
        for quant_layer, _ in quant_layer_sensitivity.items():
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer):
                    if quant_layer in name:
                        logging.info(f"Disable {name}")
                        if quant_layer not in skipped_layers:
                            skipped_layers.append(quant_layer)
                        module.disable()
            wer_value = evaluate(asr_model, labels_map, wer)
            if wer_value <= args.wer_tolerance:
                logging.info(
                    f"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
                )
                print(skipped_layers)
                return
        raise ValueError(f"WER tolerance {args.wer_tolerance} cannot be met with any layer quantized!")

    if args.onnx:
        if args.asr_model.endswith("nemo"):
            onnx_name = args.asr_model.replace(".nemo", ".onnx")
        else:
            onnx_name = args.asr_model
        logging.info(f"Export to {onnx_name}")
        quant_nn.TensorQuantizer.use_fb_fake_quant = True
        asr_model.export(onnx_name, onnx_opset_version=13)
        quant_nn.TensorQuantizer.use_fb_fake_quant = False
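# Both ASR scripts call an `evaluate(asr_model, labels_map, wer)` helper that is not
# shown here. A plausible sketch, assuming NeMo 1.x's greedy CTC decoding and WER
# APIs (WER.ctc_decoder_predictions_tensor, word_error_rate); the original helper
# may differ:
from nemo.collections.asr.metrics.wer import word_error_rate


def evaluate(asr_model, labels_map, wer):
    hypotheses, references = [], []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        # Greedy CTC decode of the current batch.
        log_probs, encoded_len, greedy_predictions = asr_model(
            input_signal=test_batch[0], input_signal_length=test_batch[1]
        )
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        # Rebuild reference transcripts from the token ids in the batch.
        for batch_ind in range(greedy_predictions.shape[0]):
            seq_len = test_batch[3][batch_ind].cpu().detach().numpy()
            seq_ids = test_batch[2][batch_ind].cpu().detach().numpy()
            references.append(''.join(labels_map[c] for c in seq_ids[0:seq_len]))
    return word_error_rate(hypotheses=hypotheses, references=references,
                           use_cer=getattr(wer, "use_cer", False))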
def prepare_model(model_name,
                  data_dir,
                  per_channel_quantization,
                  batch_size_train,
                  batch_size_test,
                  batch_size_onnx,
                  calibrator,
                  pretrained=True,
                  ckpt_path=None,
                  ckpt_url=None):
    """
    Prepare the model for the classification flow.

    Arguments:
        model_name: name to use when accessing the torchvision model dictionary
        data_dir: directory with train and val subdirs prepared "imagenet style"
        per_channel_quantization: if true, use per-channel quantization for weights;
            note that this isn't currently supported in ONNX-RT/PyTorch
        batch_size_train: batch size to use when training
        batch_size_test: batch size to use when testing in PyTorch
        batch_size_onnx: batch size to use when testing with ONNX-RT
        calibrator: calibration type to use (max/histogram)
        pretrained: if true, a pretrained model will be loaded from torchvision
        ckpt_path: path to load a model checkpoint from, if not pretrained
        ckpt_url: url to download a model checkpoint from, if not pretrained and no path was given
        * at least one of {pretrained, ckpt_path, ckpt_url} must be valid

    Returns the following list:
        [model, data loader for training, data loader for PyTorch testing, data loader for ONNX testing]
    """
    # Use 'spawn' to avoid CUDA reinitialization with forked subprocess
    torch.multiprocessing.set_start_method('spawn')

    ## Initialize quantization, model and data loaders
    if per_channel_quantization:
        quant_desc_input = QuantDescriptor(calib_method=calibrator)
        quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
    else:
        ## Force per-tensor quantization for onnx runtime
        quant_desc_input = QuantDescriptor(calib_method=calibrator, axis=None)
        quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

        quant_desc_weight = QuantDescriptor(calib_method=calibrator, axis=None)
        quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)
        quant_nn.QuantConvTranspose2d.set_default_quant_desc_weight(quant_desc_weight)
        quant_nn.QuantLinear.set_default_quant_desc_weight(quant_desc_weight)

    quant_modules.initialize()

    model = torchvision.models.__dict__[model_name](pretrained=pretrained)
    if not pretrained:
        if ckpt_path:
            checkpoint = torch.load(ckpt_path)
        else:
            checkpoint = load_state_dict_from_url(ckpt_url)
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        elif 'model' in checkpoint.keys():
            checkpoint = checkpoint['model']
        model.load_state_dict(checkpoint)
    model.eval()
    model.cuda()

    ## Prepare the data loaders
    traindir = os.path.join(data_dir, 'train')
    valdir = os.path.join(data_dir, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(traindir, valdir, False, False)

    data_loader_train = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size_train, sampler=train_sampler, num_workers=16, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=batch_size_test, sampler=test_sampler, num_workers=4, pin_memory=True)

    data_loader_onnx = torch.utils.data.DataLoader(
        dataset_test, batch_size=batch_size_onnx, sampler=test_sampler, num_workers=4, pin_memory=True)

    return model, data_loader_train, data_loader_test, data_loader_onnx
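# Hypothetical usage of prepare_model (the model name, data directory, and batch
# sizes below are illustrative assumptions, not values from the original flow):
model, loader_train, loader_test, loader_onnx = prepare_model(
    model_name="resnet50",
    data_dir="/path/to/imagenet",
    per_channel_quantization=False,  # per-tensor, so the exported ONNX runs in ONNX-RT
    batch_size_train=64,
    batch_size_test=64,
    batch_size_onnx=16,
    calibrator="histogram",
)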
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--wer_target", type=float, default=None, help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument(
        "--dont_normalize_text",
        # store_false: defaults to True (normalize) and passing the flag turns
        # normalization off.
        default=True,
        action='store_false',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer", default=False, action='store_true',
        help="Use Character Error Rate as the evaluation metric")
    parser.add_argument('--sensitivity', action="store_true", help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    parser.add_argument('--quant-disable-keyword', type=str, nargs='+', help='disable quantizers by keyword')

    args = parser.parse_args()
    torch.set_grad_enabled(False)

    quant_modules.initialize()

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.restore_from(restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.restore_from(restore_path=args.asr_model, override_config_path=asr_model_cfg)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model, override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.dont_normalize_text,
        })
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    if args.quant_disable_keyword:
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                for keyword in args.quant_disable_keyword:
                    if keyword in name:
                        logging.warning(f"Disable {name}")
                        module.disable()

    labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary, use_cer=args.use_cer)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if args.sensitivity:
        if wer_quant < args.wer_tolerance:
            logging.info("Tolerance is already met. Skipping sensitivity analysis.")
            return
        quant_layer_names = []
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.disable()
                layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "")
                if layer_name not in quant_layer_names:
                    quant_layer_names.append(layer_name)
        logging.info(f"{len(quant_layer_names)} quantized layers found.")

        # Build sensitivity profile
        quant_layer_sensitivity = {}
        for i, quant_layer in enumerate(quant_layer_names):
            logging.info(f"Enable {quant_layer}")
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.enable()
                    logging.info(f"{name:40}: {module}")

            # Eval the model
            wer_value = evaluate(asr_model, labels_map, wer)
            logging.info(f"WER: {wer_value}")
            quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value

            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.disable()
                    logging.info(f"{name:40}: {module}")

        # Skip most sensitive layers until WER target is met
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.enable()
        quant_layer_sensitivity = collections.OrderedDict(sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
        pprint(quant_layer_sensitivity)

        skipped_layers = []
        for quant_layer, _ in quant_layer_sensitivity.items():
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer):
                    if quant_layer in name:
                        logging.info(f"Disable {name}")
                        if quant_layer not in skipped_layers:
                            skipped_layers.append(quant_layer)
                        module.disable()
            wer_value = evaluate(asr_model, labels_map, wer)
            if wer_value <= args.wer_tolerance:
                logging.info(
                    f"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
                )
                print(skipped_layers)
                export_onnx(args, asr_model)
                return
        raise ValueError(f"WER tolerance {args.wer_tolerance} cannot be met with any layer quantized!")

    export_onnx(args, asr_model)
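# The `export_onnx(args, asr_model)` helper called above is not shown. A sketch of
# what it presumably does, based on the inline export block in the earlier script
# (the helper's actual body may differ):
def export_onnx(args, asr_model):
    if not args.onnx:
        return
    if args.asr_model.endswith(".nemo"):
        onnx_name = args.asr_model.replace(".nemo", ".onnx")
    else:
        onnx_name = args.asr_model
    logging.info(f"Export to {onnx_name}")
    # Emit ONNX-friendly fake-quant nodes while exporting, then restore the default.
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    asr_model.export(onnx_name, onnx_opset_version=13)
    quant_nn.TensorQuantizer.use_fb_fake_quant = False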