def compare_and_validate_results(float_model, q_model):
    """Compare matched weights between a float model and its quantized version.

    Builds the weight-comparison dict via ``compare_weights`` over both
    models' state dicts, asserts exactly one matched entry was found, and
    checks each matched float/quantized weight pair has identical shapes.

    NOTE(review): ``self`` is a free variable here — this function is
    presumably nested inside a ``unittest.TestCase`` method whose closure
    provides it; confirm against the enclosing scope.
    """
    weight_dict = compare_weights(
        float_model.state_dict(), q_model.state_dict()
    )
    # Exactly one weight entry is expected to match between the two models.
    self.assertEqual(len(weight_dict), 1)
    # Iterate values directly (keys were unused), and use assertEqual so a
    # mismatch reports both shapes instead of a bare "False is not true".
    for v in weight_dict.values():
        self.assertEqual(v["float"].shape, v["quantized"].shape)
def _test_vision_model(self, float_model):
    """Exercise the eager-mode numeric suite end-to-end on a vision model.

    Quantizes ``float_model`` and runs the three numeric-suite entry points
    (``compare_weights``, ``compare_model_outputs``,
    ``prepare_model_outputs`` + ``get_matching_activations``) against it.

    NOTE(review): assumes ``float_model`` exposes ``fuse_model()`` (a
    torchvision quantizable-model convention) — confirm with callers.
    """
    # Quantization flow requires CPU, eval mode, and fused modules before
    # assigning a qconfig.
    float_model.to('cpu')
    float_model.eval()
    float_model.fuse_model()
    float_model.qconfig = torch.quantization.default_qconfig
    # Two calibration batches of shape (2, 3, 224, 224) with dummy labels.
    # NOTE(review): randint(0, 1, ...) always yields label 0 — presumably
    # intentional for a smoke test; labels are unused by default_eval_fn's
    # forward pass.
    img_data = [(torch.rand(2, 3, 224, 224, dtype=torch.float), torch.randint(0, 1, (2, ), dtype=torch.long)) for _ in range(2)]
    # inplace=False keeps float_model intact so it can be compared below.
    qmodel = quantize(float_model, torch.quantization.default_eval_fn, [img_data], inplace=False)

    # NOTE(review): wt_compare_dict is built but never inspected here —
    # this only smoke-tests that compare_weights runs; confirm whether
    # assertions on it were intended.
    wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())

    def compute_error(x, y):
        # Signal-to-quantization-noise ratio in dB: 20*log10(|x| / |x - y|).
        Ps = torch.norm(x)
        Pn = torch.norm(x - y)
        return 20 * torch.log10(Ps / Pn)

    # Use the first batch's images as the comparison input.
    data = img_data[0][0]
    # Take in floating point and quantized model as well as input data, and returns a dict, with keys
    # corresponding to the quantized module names and each entry being a dictionary with two keys 'float' and
    # 'quantized', containing the activations of floating point and quantized model at matching locations.
    act_compare_dict = compare_model_outputs(float_model, qmodel, data)

    # NOTE(review): the computed error is discarded — this loop only checks
    # that compute_error runs without raising; confirm whether a threshold
    # assertion (e.g. SQNR > some dB) was intended.
    for key in act_compare_dict:
        compute_error(act_compare_dict[key]['float'][0], act_compare_dict[key]['quantized'][0].dequantize())

    # Attach loggers to both models, then run all batches through each so
    # the loggers record matching activations.
    prepare_model_outputs(float_model, qmodel)

    for data in img_data:
        float_model(data[0])
        qmodel(data[0])

    # Find the matching activation between floating point and quantized modules, and return a dict with key
    # corresponding to quantized module names and each entry being a dictionary with two keys 'float'
    # and 'quantized', containing the matching floating point and quantized activations logged by the logger
    act_compare_dict = get_matching_activations(float_model, qmodel)