def quantize(model):
    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {"": qconfig}
    return convert_fx(prepare_fx(model, qconfig_dict))
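# Note that `quantize` above converts immediately after preparation, so no
# calibration data ever flows through the inserted observers. A minimal
# sketch of the usual three-step flow, assuming a hypothetical `model` and
# a `calibration_loader` that yields input batches:
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import convert_fx, prepare_fx

def quantize_with_calibration(model, calibration_loader):
    # Insert observers for the fbgemm (x86 server) backend.
    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    prepared = prepare_fx(model.eval(), qconfig_dict)
    # Run representative data through the model so the observers can
    # record activation ranges before scales/zero-points are chosen.
    with torch.no_grad():
        for batch in calibration_loader:
            prepared(batch)
    return convert_fx(prepared)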
def test_input_weight_equalization_graphs(self):
    """ Tests that the modified model for equalization has the same graph
    structure as the model without equalization (before and after
    quantization).
    """

    linear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]

    linearAdd_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]

    linear2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]

    functionalLinear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize')
    ]

    functionalLinearAdd_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize')
    ]

    functionalLinear2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize')
    ]

    linearRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.LinearReLU),
        ns.call_method('dequantize')
    ]

    linearReluLinear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.LinearReLU),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]

    functionalLinearRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear_relu),
        ns.call_method('dequantize')
    ]

    functionalLinearReluLinear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear_relu),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize')
    ]

    conv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize')
    ]

    conv2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize')
    ]

    functionalConv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_method('dequantize')
    ]

    functionalConv2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_method('dequantize')
    ]

    convRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.ConvReLU2d),
        ns.call_method('dequantize')
    ]

    convReluConv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.ConvReLU2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize')
    ]

    functionalConvRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d_relu),
        ns.call_method('dequantize')
    ]

    functionalConvReluConv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d_relu),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_method('dequantize')
    ]

    tests = [
        (SingleLayerLinearModel, 2, linear_node_list),
        (LinearAddModel, 2, linearAdd_node_list),
        (TwoLayerLinearModel, 2, linear2_node_list),
        (SingleLayerFunctionalLinearModel, 2, functionalLinear_node_list),
        (FunctionalLinearAddModel, 2, functionalLinearAdd_node_list),
        (TwoLayerFunctionalLinearModel, 2, functionalLinear2_node_list),
        (LinearReluModel, 2, linearRelu_node_list),
        (LinearReluLinearModel, 2, linearReluLinear_node_list),
        (FunctionalLinearReluModel, 2, functionalLinearRelu_node_list),
        (FunctionalLinearReluLinearModel, 2, functionalLinearReluLinear_node_list),
        (ConvModel, 4, conv_node_list),
        (TwoLayerConvModel, 4, conv2_node_list),
        (SingleLayerFunctionalConvModel, 4, functionalConv_node_list),
        (TwoLayerFunctionalConvModel, 4, functionalConv2_node_list),
        (ConvReluModel, 4, convRelu_node_list),
        (ConvReluConvModel, 4, convReluConv_node_list),
        (FunctionalConvReluModel, 4, functionalConvRelu_node_list),
        (FunctionalConvReluConvModel, 4, functionalConvReluConv_node_list),
    ]

    for (M, ndim, node_list) in tests:
        m = M().eval()

        if ndim == 2:
            x = torch.rand((5, 5))
        elif ndim == 4:
            x = torch.rand((16, 3, 224, 224))

        prepared = prepare_fx(
            m, qconfig_dict,
            equalization_qconfig_dict=default_equalization_qconfig_dict)
        prepared(x)
        equalized_quantized_model = convert_fx(prepared)

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(
            equalized_quantized_model, expected_node_list=node_list)
def load(checkpoint_dir, model, **kwargs):
    """Restore a quantized model from the specified checkpoint.

    Args:
        checkpoint_dir (str): The checkpoint folder. It must contain
            'best_configure.yaml' and 'best_model_weights.pt'. The
            'checkpoint' directory lives under the workspace folder,
            which is defined in the configuration yaml file.
        model (object): The fp32 model to be quantized.

    Returns:
        (object): quantized model
    """

    tune_cfg_file = os.path.join(
        os.path.abspath(os.path.expanduser(checkpoint_dir)),
        'best_configure.yaml')
    weights_file = os.path.join(
        os.path.abspath(os.path.expanduser(checkpoint_dir)),
        'best_model_weights.pt')
    assert os.path.exists(
        tune_cfg_file), "tune configure file %s does not exist" % tune_cfg_file
    assert os.path.exists(
        weights_file), "weight file %s does not exist" % weights_file

    with open(tune_cfg_file, 'r') as f:
        tune_cfg = yaml.safe_load(f)

    version = get_torch_version()
    if tune_cfg['approach'] != "post_training_dynamic_quant":
        if version < '1.7':
            q_mapping = tq.default_mappings.DEFAULT_MODULE_MAPPING
        elif version < '1.8':
            q_mapping = \
                tq.quantization_mappings.get_static_quant_module_mappings()
        else:
            q_mapping = \
                tq.quantization_mappings.get_default_static_quant_module_mappings()
    else:
        if version < '1.7':
            q_mapping = \
                tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING
        elif version < '1.8':
            q_mapping = \
                tq.quantization_mappings.get_dynamic_quant_module_mappings()
        else:
            q_mapping = \
                tq.quantization_mappings.get_default_dynamic_quant_module_mappings()

    if version < '1.7':
        white_list = \
            tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.default_mappings.DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}
    elif version < '1.8':
        white_list = \
            tq.quantization_mappings.get_dynamic_quant_module_mappings() \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.quantization_mappings.get_qconfig_propagation_list() - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}
    else:
        white_list = \
            tq.quantization_mappings.get_default_dynamic_quant_module_mappings() \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.quantization_mappings.get_default_qconfig_propagation_list() - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}

    if tune_cfg['approach'] == "post_training_dynamic_quant":
        op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg['approach'])
    else:
        op_cfgs = _cfg_to_qconfig(tune_cfg)

    if tune_cfg['framework'] == "pytorch_fx":  # pragma: no cover
        # For torch.fx approach
        assert version >= '1.8', \
            "Please use PyTorch 1.8 or a higher version with the pytorch_fx backend"
        from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
        q_model = copy.deepcopy(model.eval())
        fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg['approach'])
        if tune_cfg['approach'] == "quant_aware_training":
            q_model.train()
            q_model = prepare_qat_fx(
                q_model,
                fx_op_cfgs,
                prepare_custom_config_dict=kwargs if kwargs != {} else None)
        else:
            q_model = prepare_fx(
                q_model,
                fx_op_cfgs,
                prepare_custom_config_dict=kwargs if kwargs != {} else None)
        q_model = convert_fx(q_model)
        weights = torch.load(weights_file)
        q_model.load_state_dict(weights)
        return q_model

    q_model = copy.deepcopy(model.eval())
    _propagate_qconfig(q_model, op_cfgs, white_list=white_list,
                       approach=tune_cfg['approach'])
    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in q_model.modules()):
        logger.warning(
            "None of the submodules got qconfig applied. Make sure you "
            "passed the correct configuration through `qconfig_dict` or "
            "by assigning the `.qconfig` attribute directly to submodules")
    if tune_cfg['approach'] != "post_training_dynamic_quant":
        add_observer_(q_model)
    q_model = convert(q_model, mapping=q_mapping, inplace=True)
    weights = torch.load(weights_file)
    q_model.load_state_dict(weights)
    return q_model
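# A minimal usage sketch for `load`, assuming a hypothetical
# './workspace/checkpoint' directory produced by a previous tuning run
# (it must contain 'best_configure.yaml' and 'best_model_weights.pt')
# and an fp32 torchvision resnet18 as the original model:
import torchvision.models as models

fp32_model = models.resnet18(pretrained=True)
quantized_model = load('./workspace/checkpoint', fp32_model)
quantized_model.eval()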
def prep_qat_eval(self):
    self.model = quantize_fx.convert_fx(self.model)
    if self.jit:
        self.model = torch.jit.script(self.model)
    self.model.eval()
def _test_auto_tracing(
    self,
    m,
    qconfig,
    example_args,
    fuse_modules=True,
    do_fx_comparison=True,
    do_torchscript_checks=True,
):
    m_copy = copy.deepcopy(m)

    qconfig_dict = {'': qconfig}

    mp = _quantize_dbr.prepare(
        m, qconfig_dict, example_args, fuse_modules=fuse_modules)
    out_p = mp(*example_args)
    mq = _quantize_dbr.convert(mp)
    # verify it runs
    out_q = mq(*example_args)

    # compare it against FX
    if do_fx_comparison:
        m_copy_p = prepare_fx(m_copy, {'': qconfig})
        out_m_copy_p = m_copy_p(*example_args)
        m_copy_q = convert_fx(m_copy_p)
        out_q_fx = m_copy_q(*example_args)
        self.assertTrue(_allclose(out_p, out_m_copy_p))
        self.assertTrue(_allclose(out_q, out_q_fx))

    if do_torchscript_checks:
        # verify torch.jit.trace works
        mq_jit_traced = torch.jit.trace(mq, example_args, check_trace=False)
        traced_out = mq_jit_traced(*example_args)
        self.assertTrue(_allclose(traced_out, out_q))

        # verify torch.jit.script works
        rewritten = mq.rewrite_for_scripting()
        rewritten_out = rewritten(*example_args)
        self.assertTrue(_allclose(rewritten_out, out_q))

        scripted_rewritten = torch.jit.script(rewritten)
        scripted_rewritten_out = scripted_rewritten(*example_args)
        self.assertTrue(_allclose(scripted_rewritten_out, out_q))

        traced_rewritten = torch.jit.trace(
            rewritten, example_args, check_trace=False)
        traced_rewritten_out = traced_rewritten(*example_args)
        self.assertTrue(_allclose(traced_rewritten_out, out_q))
def test_selective_equalization(self):
    """ Tests that we are able to run numeric suite on the equalized model
    and construct a valid equalization_qconfig_dict that equalizes only the
    layer with the highest quantization error.
    """

    torch.manual_seed(1)

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
            self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

        def forward(self, x):
            x = self.bot(x)
            x = torch.add(x, 5)
            x = self.top(x)
            return x

    float_model = M().eval()
    # Hard coded so that the top layer has a higher quantization error
    x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                      [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                      [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                      [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                      [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

    # Quantize the float model
    prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
    prepared_model(x)
    quantized_model = convert_fx(copy.deepcopy(prepared_model))

    # Get the SQNR between the float and quantized model
    layer_to_sqnr_dict = get_layer_sqnr_dict(
        copy.deepcopy(prepared_model), quantized_model, x)

    # Construct the equalization_qconfig_dict equalizing the layers with the
    # highest quantization errors
    selective_equalization_qconfig_dict = get_equalization_qconfig_dict(
        layer_to_sqnr_dict, 1)

    # Create the selectively equalized model
    prepared_model = prepare_fx(
        copy.deepcopy(float_model),
        specific_qconfig_dict,
        equalization_qconfig_dict=selective_equalization_qconfig_dict,
    )
    prepared_model(x)
    equalized_model = convert_fx(prepared_model)

    node_list = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]

    # Check the order of nodes in the graph
    self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
float_model.eval()

# deepcopy the model since we need to keep the original model around
model_to_quantize = copy.deepcopy(float_model)
model_to_quantize.eval()

qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
print(prepared_model.graph)

calibrate(prepared_model, data_loader_train)

quantized_model = convert_fx(prepared_model)
print(quantized_model)

script_module = torch.jit.trace(quantized_model, torch.randn(1, 3, 224, 224)).eval()
print(script_module)
torch._C._jit_pass_inline(script_module.graph)

print("Size of model before quantization")
print_size_of_model(float_model)
print("Size of model after quantization")
print_size_of_model(script_module)

top1, top5 = evaluate(script_module, criterion, data_loader_test)
print("[before serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"
      % (top1.avg, top5.avg))

top1, top5 = evaluate(float_model, criterion, data_loader_test)
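# `calibrate` is not defined in the snippet above; a minimal sketch in the
# spirit of the FX graph mode quantization tutorial, assuming the loader
# yields (image, target) pairs:
def calibrate(model, data_loader):
    # Feed representative samples through the prepared model so the
    # inserted observers can record activation statistics.
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)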
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        repo = 'alantess/vigilant-driving:main/1.0.75'
        self.model = torch.hub.load(repo, 'segnet', pretrained=True)

    def forward(self, x):
        x = self.model(x).squeeze(0).argmax(0)
        return x.mul(100).clamp(0, 255)


model_fp = Net()
model_fp.eval()

model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}

# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

# Fusion
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)

# Save model
scripted_model = torch.jit.script(model_fused)
scripted_optimized_model = optimize_for_mobile(scripted_model)
torch.jit.save(scripted_optimized_model, "models/segnet_fx_mobile.pt")
def main(args):
    # data
    train_transform = tv.transforms.Compose([])
    if args.data_augmentation:
        train_transform.transforms.append(
            tv.transforms.RandomCrop(32, padding=4))
        train_transform.transforms.append(tv.transforms.RandomHorizontalFlip())
    train_transform.transforms.append(tv.transforms.ToTensor())
    normalize = tv.transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    train_transform.transforms.append(normalize)

    test_transform = tv.transforms.Compose(
        [tv.transforms.ToTensor(), normalize])

    train_dataset = tv.datasets.CIFAR10(root='data/',
                                        train=True,
                                        transform=train_transform,
                                        download=True)
    test_dataset = tv.datasets.CIFAR10(root='data/',
                                       train=False,
                                       transform=test_transform,
                                       download=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.bs,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=4)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.bs,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=4)

    # net
    net = tv.models.mobilenet_v2(num_classes=10)
    net.load_state_dict(torch.load('mobilenet_v2.pth', map_location='cpu'))
    net.dropout = torch.nn.Sequential()

    # quantization
    model_to_quantize = copy.deepcopy(net).to(device)
    qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')}
    model_to_quantize.train()
    model_prepared = prepare_qat_fx(model_to_quantize, qconfig_dict)

    # optimizer and loss
    optimizer = torch.optim.SGD(model_prepared.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 16], 0.1)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    # train
    model_prepared.train()
    for i_epoch in range(20):
        time_s = time.time()
        for i_iter, data in enumerate(train_loader):
            img, label = data
            img, label = img.to(device), label.to(device)
            optimizer.zero_grad()
            feat = model_prepared(img)
            loss = criterion(feat, label)
            loss.backward()
            optimizer.step()
            time_e = time.time()
            print('Epoch:{:3}/20 || Iter: {:4}/{} || '
                  'Loss: {:2.4f} '
                  'ETA: {:.2f}min'.format(
                      i_epoch + 1, i_iter + 1, len(train_loader),
                      loss.item(),
                      (time_e - time_s) * (20 - i_epoch) *
                      len(train_loader) / (i_iter + 1) / 60))
        scheduler.step()

    # to int8
    model_int8 = convert_fx(model_prepared)
    torch.jit.save(torch.jit.script(model_int8), 'int8-qat.pth')

    # valid
    loaded_quantized_model = torch.jit.load('int8-qat.pth')
    correct = 0.
    total = 0.
    with torch.no_grad():
        loaded_quantized_model.eval()
        for images, labels in tqdm(test_loader):
            pred = loaded_quantized_model(images)
            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()
    val_acc = correct / total
    print(val_acc)
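# If training above ran on a GPU, the prepared model usually has to come
# back to the CPU before conversion, since the quantized fbgemm kernels
# produced by `convert_fx` are CPU-only; conversion is also normally done
# in eval mode. A hedged variant of the conversion step under those
# assumptions (`model_prepared` as in `main` above):
model_prepared.to('cpu')
model_prepared.eval()
model_int8 = convert_fx(model_prepared)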
def convert(self, model, submodules, attrs):
    model.another_layer = convert_fx(model.another_layer)
    return model
def prepare_for_quant_convert(self, cfg):
    self.avgpool = convert_fx(
        self.avgpool, convert_custom_config_dict=self.custom_config_dict)
    return self
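# For reference, a minimal sketch of what `self.custom_config_dict` might
# hold; "preserved_attributes" is a stock `convert_custom_config_dict` key,
# while the attribute name listed here is purely illustrative:
custom_config_dict = {
    # keep non-forward attributes on the converted GraphModule
    "preserved_attributes": ["qconfig_applied"],
}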
def quantize_fx(model, inputs, data_loader, dynamic=True, selective=False):
    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):
        static = not dynamic
        if dynamic:
            qconfig = per_channel_dynamic_qconfig
        else:
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_weight_observer,
            )
        # Only linear layers
        qconfig_dict = {"": None}
        if static and selective:
            qconfig_dict["module_name"] = []
            layers = model.encoder.encoder.transformer.layers.layers.layers
            layers_str = "layers"
            # skip first layer
            for layer_idx in range(1, len(layers)):
                qconfig_dict["module_name"].append((
                    layers_str + ".{}.attention.input_projection".format(layer_idx),
                    qconfig,
                ))
                qconfig_dict["module_name"].append((
                    layers_str + ".{}.attention.output_projection".format(layer_idx),
                    qconfig,
                ))
                for mlp_idx, m in enumerate(layers[layer_idx].residual_mlp.mlp):
                    # Only quantize the first linear, otherwise there are
                    # accuracy issues with static quantization
                    if type(m) == torch.nn.Linear and mlp_idx < 1:
                        qconfig_dict["module_name"].append((
                            layers_str + ".{}.residual_mlp.mlp.{}".format(
                                layer_idx, mlp_idx),
                            qconfig,
                        ))
        else:
            qconfig_dict["object_type"] = [(torch.nn.Linear, qconfig)]

        def calibrate(model, loader, max_samples=-1):
            model.eval()
            with torch.no_grad():
                for (idx, d) in enumerate(loader):
                    print("Running sample input #" + str(idx))
                    model(d[1]["tokens"])
                    if idx == max_samples:
                        break

        # fuse modules and insert observers
        prepared_model = prepare_fx(
            model.encoder.encoder.transformer.layers.layers, qconfig_dict)
        model.encoder.encoder.transformer.layers.layers = prepared_model
        if static:
            # run calibration on sample data
            calibrate(model, data_loader)
        model.encoder.encoder.transformer.layers.layers = convert_fx(
            prepared_model)

        # Trace the submodule in order to fix the interface
        if static:
            input1 = torch.randn([2, 1, 1024], dtype=torch.float)
            input2 = torch.randn([1, 2]).bool()
            traced = torch.jit.trace(
                model.encoder.encoder.transformer.layers.layers,
                (input1, input2))
            model.encoder.encoder.transformer.layers.layers = traced

    # Trace the overall module
    trace = model.trace(inputs)
    return trace
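# A usage sketch for the helper above, assuming a hypothetical PyText
# `model` wrapping a RoBERTaEncoder, `sample_inputs` for tracing, and a
# `calibration_loader` of the shape the inner calibrate helper expects:
traced_int8 = quantize_fx(
    model, sample_inputs, calibration_loader, dynamic=False, selective=True)
torch.jit.save(traced_int8, "roberta_int8.pt")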