Esempio n. 1
0
 def compare_and_validate_results(float_model, q_model):
     """Check the weight comparison between a float model and its quantized form.

     Runs ``compare_weights`` on the two models' state dicts and asserts that
     exactly one weight pair is reported and that, for every reported pair,
     the 'float' and 'quantized' entries have the same shape.

     NOTE(review): ``self`` is not a parameter here; it is presumably the test
     case of an enclosing test method this helper is nested in — confirm.

     Args:
         float_model: floating-point model whose ``state_dict()`` is compared.
         q_model: quantized counterpart whose ``state_dict()`` is compared.
     """
     weight_dict = compare_weights(
         float_model.state_dict(), q_model.state_dict()
     )
     # Exactly one matching weight entry is expected for the models under test.
     self.assertEqual(len(weight_dict), 1)
     for pair in weight_dict.values():
         # assertEqual reports both shapes on mismatch, unlike
         # assertTrue(a == b), which only reports "False is not true".
         self.assertEqual(pair["float"].shape, pair["quantized"].shape)
    def _test_vision_model(self, float_model):
        """Exercise the float-vs-quantized comparison helpers on a vision model.

        Quantizes ``float_model`` with ``torch.quantization.default_qconfig``
        (calibrated on random image batches), then calls ``compare_weights``,
        ``compare_model_outputs``, ``prepare_model_outputs`` and
        ``get_matching_activations`` on the float/quantized pair.

        Args:
            float_model: model under test. It is modified in place (moved to
                CPU, switched to eval mode, fused via its ``fuse_model()``
                method, and assigned a ``qconfig``) and must accept image
                batches of shape (2, 3, 224, 224).
        """
        float_model.to('cpu')
        float_model.eval()
        float_model.fuse_model()
        float_model.qconfig = torch.quantization.default_qconfig
        # Two (images, labels) batches of random data, used both for
        # calibration and to drive the models below. The labels are always 0
        # because torch.randint's upper bound (1) is exclusive.
        img_data = [(torch.rand(2, 3, 224, 224, dtype=torch.float),
                     torch.randint(0, 1, (2, ), dtype=torch.long))
                    for _ in range(2)]
        # inplace=False: quantize is expected to return a separate quantized
        # model, since float_model keeps being used alongside it below.
        qmodel = quantize(float_model,
                          torch.quantization.default_eval_fn, [img_data],
                          inplace=False)

        # NOTE(review): wt_compare_dict appears unused after this point — confirm
        # whether an assertion on it was intended; as-is this only checks that
        # compare_weights runs without raising.
        wt_compare_dict = compare_weights(float_model.state_dict(),
                                          qmodel.state_dict())

        def compute_error(x, y):
            # Ratio, in dB, of the norm of x to the norm of the error (x - y);
            # larger means y is closer to x. Not finite if y equals x exactly
            # (Pn is then 0).
            Ps = torch.norm(x)
            Pn = torch.norm(x - y)
            return 20 * torch.log10(Ps / Pn)

        # The first image batch (without labels) of the calibration data.
        data = img_data[0][0]
        # Takes the floating-point model, the quantized model and input data, and returns a dict keyed by the
        # quantized module names. Each entry is a dict with two keys, 'float' and 'quantized', holding the
        # activations of the floating-point and quantized models at matching locations.
        act_compare_dict = compare_model_outputs(float_model, qmodel, data)

        # NOTE(review): the value returned by compute_error is discarded, so
        # this loop only checks that the activations can be compared, not that
        # the quantization error is acceptable.
        for key in act_compare_dict:
            compute_error(act_compare_dict[key]['float'][0],
                          act_compare_dict[key]['quantized'][0].dequantize())

        # Presumably attaches activation loggers to both models so that the
        # forward passes below are recorded — TODO confirm against the helper.
        prepare_model_outputs(float_model, qmodel)

        # Note: this rebinds `data` (previously the first image batch) to each
        # (images, labels) tuple; only the images are fed to the models.
        for data in img_data:
            float_model(data[0])
            qmodel(data[0])

        # Finds the matching activations between the floating-point and quantized modules, and returns a dict
        # keyed by the quantized module names. Each entry is a dict with two keys, 'float' and 'quantized',
        # holding the matching floating-point and quantized activations logged by the logger.
        act_compare_dict = get_matching_activations(float_model, qmodel)