def compare_weights_fx(float_dict, quantized_dict):
    r"""Compare weights between a prepared float model and its quantized model.

    The float model must have been processed by ``prepare_fx`` before its
    state dict is taken; the quantized model is the result of ``convert_fx``.

    Example usage::

        prepared_model = prepare_fx(float_model, qconfig_dict)
        prepared_float_model = copy.deepcopy(prepared_model)
        quantized_model = convert_fx(prepared_float_model)
        wt_compare_dict = compare_weights_fx(
            prepared_float_model.state_dict(), quantized_model.state_dict())
        for key in wt_compare_dict:
            print(key, compute_error(wt_compare_dict[key]['float'],
                                     wt_compare_dict[key]['quantized'].dequantize()))

    Args:
        float_dict: state dict of the float model (after prepare)
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict keyed by module name; each value is itself a dict
            with keys 'float' and 'quantized' holding the corresponding
            weight tensors, suitable for computing quantization error.
    """
    # Record API usage for telemetry before delegating to the eager-mode helper.
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.compare_weights_fx")
    weight_dict = compare_weights(float_dict, quantized_dict)
    return weight_dict
def compare_and_validate_results(float_model, q_model):
    # Build the float/quantized weight-comparison dict from both state dicts.
    # NOTE(review): ``self`` and ``compare_weights`` are free variables here —
    # presumably this is a closure inside a unittest method; confirm the
    # enclosing scope provides them.
    results = compare_weights(
        float_model.state_dict(), q_model.state_dict()
    )
    # Exactly one weight entry is expected for the model under test.
    self.assertEqual(len(results), 1)
    for entry in results.values():
        self.assertTrue(entry["float"].shape == entry["quantized"].shape)
def test_compare_weights(self):
    r"""Compare the weights of float and quantized conv layer """
    # Eager-mode path: quantize the annotated conv model with the default
    # evaluation function and the test fixture's calibration images.
    float_conv = AnnotatedConvModel().eval()
    quantized_conv = quantize(float_conv, default_eval_fn, self.img_data)
    wt_dict = compare_weights(
        float_conv.state_dict(), quantized_conv.state_dict()
    )
    # The model has a single conv layer, so exactly one weight pair.
    self.assertEqual(len(wt_dict), 1)
    for pair in wt_dict.values():
        self.assertTrue(pair["float"].shape == pair["quantized"].shape)
# Prepare the float model for eager-mode quantization: CPU, eval mode,
# fused modules, and the default qconfig.
float_model.to('cpu')
float_model.eval()
float_model.fuse_model()
float_model.qconfig = torch.quantization.default_qconfig
# Small synthetic calibration set: two batches of random images with
# random labels (labels are unused by calibration itself).
img_data = [(torch.rand(2, 3, 10, 10, dtype=torch.float),
             torch.randint(0, 1, (2,), dtype=torch.long))
            for _ in range(2)]
qmodel = quantize(float_model, default_eval_fn, [img_data], inplace=False)

##############################################################################
# 1. Compare the weights of float and quantized models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The first thing we usually want to compare are the weights of the quantized
# model and the float model.
# We can call ``compare_weights()`` from PyTorch Numeric Suite to get a
# dictionary ``wt_compare_dict`` with keys corresponding to module names;
# each entry is a dictionary with two keys, 'float' and 'quantized',
# containing the float and quantized weights.
# ``compare_weights()`` takes in the floating point and quantized state dicts
# and returns a dict, with keys corresponding to the floating point weights
# and values being a dictionary of floating point and quantized weights.

wt_compare_dict = ns.compare_weights(float_model.state_dict(), qmodel.state_dict())

print('keys of wt_compare_dict:')
print(wt_compare_dict.keys())

print("\nkeys of wt_compare_dict entry for conv1's weight:")
print(wt_compare_dict['conv1.weight'].keys())
print(wt_compare_dict['conv1.weight']['float'].shape)
print(wt_compare_dict['conv1.weight']['quantized'].shape)

##############################################################################
# Once we have ``wt_compare_dict``, users can process this dictionary in
# whatever way they want. Here, as an example, we compute the quantization
# error of the weights of the float and quantized models as follows.
#
# Compute the Signal-to-Quantization-Noise Ratio (SQNR) of the quantized
# tensor ``y``. The SQNR reflects the relationship between the maximum
# nominal signal strength and the quantization error introduced in the
# quantization. Higher SQNR corresponds to lower quantization error.