Example #1
    def test_compare_model_outputs_functional_static(self):
        r"""Compare the output of functional layer in static quantized model and corresponding
        output of conv layer in float model
        """
        qengine = torch.backends.quantized.engine

        model = ModelWithFunctionals().eval()
        model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        q_model = prepare(model, inplace=False)
        q_model(self.img_data_2d[0][0])  # calibration pass to record observer statistics
        q_model = convert(q_model)
        act_compare_dict = compare_model_outputs(model, q_model,
                                                 self.img_data_2d[0][0])
        self.assertEqual(len(act_compare_dict), 7)
        expected_act_compare_dict_keys = {
            "mycat.stats",
            "myadd.stats",
            "mymul.stats",
            "myadd_relu.stats",
            "my_scalar_add.stats",
            "my_scalar_mul.stats",
            "quant.stats",
        }
        self.assertTrue(
            act_compare_dict.keys() == expected_act_compare_dict_keys)
        for k, v in act_compare_dict.items():
            self.assertTrue(len(v["float"]) == len(v["quantized"]))
            for i, val in enumerate(v["quantized"]):
                self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
Example #2
        def compare_and_validate_results(float_model, q_model, data):
            act_compare_dict = compare_model_outputs(float_model, q_model, data)
            expected_act_compare_dict_keys = {"conv.stats", "quant.stats"}

            self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
            for k, v in act_compare_dict.items():
                self.assertTrue(v["float"][0].shape == v["quantized"][0].shape)
Example #3
        def compare_and_validate_results(float_model, q_model, data):
            act_compare_dict = compare_model_outputs(float_model, q_model, data)
            expected_act_compare_dict_keys = {"fc1.stats"}

            self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
            for k, v in act_compare_dict.items():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
Example #4
        def compare_and_validate_results(float_model, q_model, input, hidden):
            act_compare_dict = compare_model_outputs(float_model, q_model,
                                                     input, hidden)
            expected_act_compare_dict_keys = {"lstm.stats"}

            self.assertTrue(
                act_compare_dict.keys() == expected_act_compare_dict_keys)
            for k, v in act_compare_dict.items():
                # LSTM outputs are nested: an output tensor plus the (h_n, c_n)
                # hidden-state tuple, so lengths and shapes are checked per element.
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(
                        len(v["float"][i]) == len(v["quantized"][i]))
                    if i == 0:
                        self.assertTrue(v["float"][i][0].shape ==
                                        v["quantized"][i][0].shape)
                    else:
                        self.assertTrue(v["float"][i][0].shape ==
                                        v["quantized"][i][0].shape)
                        self.assertTrue(v["float"][i][1].shape ==
                                        v["quantized"][i][1].shape)
Example #5
    def _test_vision_model(self, float_model):
        float_model.to('cpu')
        float_model.eval()
        float_model.fuse_model()
        float_model.qconfig = torch.quantization.default_qconfig
        # Two batches of random 224x224 images paired with all-zero dummy labels.
        img_data = [(torch.rand(2, 3, 224, 224, dtype=torch.float),
                     torch.randint(0, 1, (2, ), dtype=torch.long))
                    for _ in range(2)]
        qmodel = quantize(float_model,
                          torch.quantization.default_eval_fn, [img_data],
                          inplace=False)

        wt_compare_dict = compare_weights(float_model.state_dict(),
                                          qmodel.state_dict())

        def compute_error(x, y):
            # Signal-to-quantized-noise ratio (SQNR) in dB; higher means y is closer to x.
            Ps = torch.norm(x)
            Pn = torch.norm(x - y)
            return 20 * torch.log10(Ps / Pn)
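        # Illustrative sanity check (hypothetical values, not part of the original
        # test): round-tripping a tensor through quantize/dequantize should give a
        # finite, positive SQNR.
        _x = torch.rand(16)
        _y = torch.quantize_per_tensor(_x, scale=0.1, zero_point=0,
                                       dtype=torch.quint8).dequantize()
        compute_error(_x, _y)  # in dB; higher means less quantization noise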

        data = img_data[0][0]
        # Takes the floating point and quantized models as well as input data, and returns a dict
        # whose keys are the quantized module names and whose values are dicts with two keys,
        # 'float' and 'quantized', holding the activations of the two models at matching locations.
        act_compare_dict = compare_model_outputs(float_model, qmodel, data)

        for key in act_compare_dict:
            compute_error(act_compare_dict[key]['float'][0],
                          act_compare_dict[key]['quantized'][0].dequantize())

        prepare_model_outputs(float_model, qmodel)

        for data in img_data:
            float_model(data[0])
            qmodel(data[0])

        # Finds the matching activations between the floating point and quantized modules, and
        # returns a dict whose keys are the quantized module names and whose values are dicts with
        # 'float' and 'quantized' entries holding the matching activations recorded by the loggers.
        act_compare_dict = get_matching_activations(float_model, qmodel)
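
The helper above is presumably driven by a test that passes in a torchvision quantizable model, which is what supplies the fuse_model() hook used earlier. A hedged sketch of such a driver, assuming torchvision's quantization-ready MobileNetV2 (the method name is hypothetical):

from torchvision.models.quantization import mobilenet_v2

def test_mobilenet_v2(self):  # hypothetical driver inside the same TestCase class
    self._test_vision_model(mobilenet_v2(pretrained=False, quantize=False))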