Code example #1
0
def compare_weights_fx(float_dict, quantized_dict):
    r"""Compare weights of a prepared float module with its quantized counterpart.

    Given the state dict of a model prepared via ``prepare_fx`` and the state
    dict of the model produced by ``convert_fx``, build a dict keyed by module
    name where each value is a dict with two entries, 'float' and 'quantized',
    holding the corresponding weight tensors. The result can be used to
    compare and compute the quantization error of the weights.

    Note the float module is the float module which has been prepared by
    calling prepare_fx.

    Example usage:
        prepared_model = prepare_fx(float_model, qconfig_dict)
        prepared_float_model = copy.deepcopy(prepared_model)
        quantized_model = convert_fx(prepared_float_model)

        qmodel = quantized_model
        wt_compare_dict = compare_weights_fx(prepared_float_model.state_dict(), qmodel.state_dict())
        for key in wt_compare_dict:
            print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))

    Args:
        float_dict: state dict of the float model (after prepare)
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict keyed by module name; each value is a dict with
        'float' and 'quantized' entries containing the respective weights
    """
    # Record API usage for PyTorch telemetry, then delegate to the shared
    # eager-mode implementation.
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.compare_weights_fx"
    )
    weight_dict = compare_weights(float_dict, quantized_dict)
    return weight_dict
Code example #2
0
 def compare_and_validate_results(float_model, q_model):
     # Compare weights of the float vs. quantized model and check that
     # exactly one weight pair is reported, with matching shapes.
     weight_dict = compare_weights(
         float_model.state_dict(), q_model.state_dict()
     )
     self.assertEqual(len(weight_dict), 1)
     # Module-name keys are unused here, so iterate values directly
     # (avoids the items()-with-unused-key smell).
     for v in weight_dict.values():
         self.assertTrue(v["float"].shape == v["quantized"].shape)
Code example #3
0
 def test_compare_weights(self):
     r"""Compare the weights of float and quantized conv layer.
     """
     # eager mode: quantize the annotated model, then diff its weights
     # against the original float model's weights.
     annotated_conv_model = AnnotatedConvModel().eval()
     quantized_annotated_conv_model = quantize(annotated_conv_model,
                                               default_eval_fn,
                                               self.img_data)
     weight_dict = compare_weights(
         annotated_conv_model.state_dict(),
         quantized_annotated_conv_model.state_dict(),
     )
     # The conv model exposes exactly one weight entry.
     self.assertEqual(len(weight_dict), 1)
     # Module-name keys are unused; iterate values directly instead of
     # items() with a discarded key.
     for v in weight_dict.values():
         self.assertTrue(v["float"].shape == v["quantized"].shape)
Code example #4
0
# Prepare the float model for eager-mode static quantization: move to CPU,
# switch to eval mode, fuse modules, and attach the default qconfig.
float_model.to('cpu')
float_model.eval()
float_model.fuse_model()
float_model.qconfig = torch.quantization.default_qconfig

# Build a tiny calibration set of (input, label) pairs.
img_data = []
for _ in range(2):
    images = torch.rand(2, 3, 10, 10, dtype=torch.float)
    labels = torch.randint(0, 1, (2,), dtype=torch.long)
    img_data.append((images, labels))

qmodel = quantize(float_model, default_eval_fn, [img_data], inplace=False)

##############################################################################
# 1. Compare the weights of float and quantized models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The first thing we usually want to compare are the weights of quantized model and float model.
# We can call ``compare_weights()`` from PyTorch Numeric Suite to get a dictionary ``wt_compare_dict`` with key corresponding to module names and each entry is a dictionary with two keys 'float' and 'quantized', containing the float and quantized weights.
# ``compare_weights()`` takes in floating point and quantized state dict and returns a dict, with keys corresponding to the
# floating point weights and values being a dictionary of floating point and quantized weights

wt_compare_dict = ns.compare_weights(float_model.state_dict(), qmodel.state_dict())

print('keys of wt_compare_dict:')
print(wt_compare_dict.keys())

print("\nkeys of wt_compare_dict entry for conv1's weight:")
print(wt_compare_dict['conv1.weight'].keys())
print(wt_compare_dict['conv1.weight']['float'].shape)
print(wt_compare_dict['conv1.weight']['quantized'].shape)


##############################################################################
# Once we have ``wt_compare_dict``, users can process this dictionary in whatever way they want. Here, as an example, we compute the quantization error of the weights of the float and quantized models as follows.
# Compute the Signal-to-Quantization-Noise Ratio (SQNR) of the quantized tensor ``y``. The SQNR reflects the
# relationship between the maximum nominal signal strength and the quantization error introduced in the
# quantization. Higher SQNR corresponds to lower quantization error.