コード例 #1
0
    def test_compare_model_outputs_linear_static(self):
        r"""Compare the activations of the linear layer in a static quantized
        model with the matching activations of the linear layer in the float
        model.

        ``compare_model_outputs`` must return exactly the ``fc1.quant.stats``
        and ``fc1.module.stats`` entries; within each entry the float and
        quantized records must be equally long and agree in shape pairwise.
        """
        engine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model, data):
            wanted_keys = {"fc1.quant.stats", "fc1.module.stats"}
            outputs = compare_model_outputs(float_model, q_model, data)

            self.assertTrue(outputs.keys() == wanted_keys)
            for entry in outputs.values():
                float_acts = entry["float"]
                quant_acts = entry["quantized"]
                self.assertTrue(len(float_acts) == len(quant_acts))
                # Lengths are equal at this point, so zip visits every pair.
                for float_act, quant_act in zip(float_acts, quant_acts):
                    self.assertTrue(float_act.shape == quant_act.shape)

        # First input of the first calibration batch feeds the comparison.
        linear_data = self.calib_data[0][0]
        for model in [AnnotatedSingleLayerLinearModel(engine)]:
            model.eval()
            # Fuse only when the model exposes a fuse_model hook.
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            quantized_model = quantize(
                model, test_only_eval_fn, [self.calib_data])
            compare_and_validate_results(model, quantized_model, linear_data)
コード例 #2
0
    def test_compare_model_stub_linear_static(self):
        r"""Compare a static quantized linear layer against its float shadow
        module.

        ``compare_model_stub`` is run with ``nn.Linear`` as the only swappable
        module type; it must report exactly one entry whose float and
        quantized records are equally long and agree in shape pairwise.
        """
        engine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model,
                                         module_swap_list, data):
            stub_results = compare_model_stub(
                float_model, q_model, module_swap_list, data)
            self.assertEqual(len(stub_results), 1)
            for entry in stub_results.values():
                float_outs = entry["float"]
                quant_outs = entry["quantized"]
                self.assertTrue(len(float_outs) == len(quant_outs))
                # Lengths are equal at this point, so zip visits every pair.
                for float_out, quant_out in zip(float_outs, quant_outs):
                    self.assertTrue(float_out.shape == quant_out.shape)

        # First input of the first calibration batch feeds the comparison.
        linear_data = self.calib_data[0][0]
        swap_types = [nn.Linear]
        for model in [AnnotatedSingleLayerLinearModel(engine)]:
            model.eval()
            # Fuse only when the model exposes a fuse_model hook.
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            quantized_model = quantize(
                model, test_only_eval_fn, [self.calib_data])
            compare_and_validate_results(
                model, quantized_model, swap_types, linear_data)
コード例 #3
0
    def test_compare_model_outputs_linear_static(self):
        r"""Compare the activations of the linear layer in a static quantized
        model with the matching activations of the linear layer in the float
        model.

        ``compare_model_outputs`` must return exactly the ``fc1.quant.stats``
        and ``fc1.module.stats`` entries, and in each entry the float value
        must have the same shape as the quantized value.
        """
        engine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model, data):
            wanted_keys = {"fc1.quant.stats", "fc1.module.stats"}
            outputs = compare_model_outputs(float_model, q_model, data)

            self.assertTrue(outputs.keys() == wanted_keys)
            for entry in outputs.values():
                float_shape = entry["float"].shape
                quant_shape = entry["quantized"].shape
                self.assertTrue(float_shape == quant_shape)

        # Two (input, target) pairs of random data used for calibration.
        img_data = []
        for _ in range(2):
            inputs = torch.rand(3, 5, dtype=torch.float)
            targets = torch.randint(0, 1, (2, ), dtype=torch.long)
            img_data.append((inputs, targets))

        # The input of the first pair is also the comparison input.
        linear_data = img_data[0][0]
        for model in [AnnotatedSingleLayerLinearModel(engine)]:
            model.eval()
            # Fuse only when the model exposes a fuse_model hook.
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            quantized_model = quantize(model, default_eval_fn, img_data)
            compare_and_validate_results(model, quantized_model, linear_data)
コード例 #4
0
    def test_compare_model_stub_linear_static(self):
        r"""Compare a static quantized linear layer against its float shadow
        module.

        ``compare_model_stub`` is run with ``nn.Linear`` as the only swappable
        module type and ``ShadowLogger`` as the logger; it must report exactly
        one entry whose float and quantized values have the same shape.
        """
        engine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model,
                                         module_swap_list, data):
            stub_results = compare_model_stub(
                float_model, q_model, module_swap_list, data, ShadowLogger)
            self.assertEqual(len(stub_results), 1)
            for entry in stub_results.values():
                float_shape = entry["float"].shape
                quant_shape = entry["quantized"].shape
                self.assertTrue(float_shape == quant_shape)

        # Two (input, target) pairs of random data used for calibration.
        img_data = []
        for _ in range(2):
            inputs = torch.rand(3, 5, dtype=torch.float)
            targets = torch.randint(0, 1, (2, ), dtype=torch.long)
            img_data.append((inputs, targets))

        # The input of the first pair is also the comparison input.
        linear_data = img_data[0][0]
        swap_types = [nn.Linear]
        for model in [AnnotatedSingleLayerLinearModel(engine)]:
            model.eval()
            # Fuse only when the model exposes a fuse_model hook.
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            quantized_model = quantize(model, default_eval_fn, img_data)
            compare_and_validate_results(
                model, quantized_model, swap_types, linear_data)
コード例 #5
0
    def test_record_observer(self):
        """Check that the debug qconfig's observer records evaluation outputs.

        For every supported quantized engine the model is prepared with
        ``default_debug_qconfig`` and evaluated twice on ``self.calib_data``.
        The observer collected by ``get_observer_dict`` under
        ``fc1.module.activation_post_process`` must then hold
        ``2 * len(self.calib_data)`` tensors, the first of which equals the
        model output for the first calibration input.
        """
        observer_key = 'fc1.module.activation_post_process'
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = AnnotatedSingleLayerLinearModel()
                model.qconfig = default_debug_qconfig
                model = prepare(model)
                # Run the evaluation twice so every tensor is dumped twice.
                for _ in range(2):
                    test_only_eval_fn(model, self.calib_data)
                observer_dict = {}
                get_observer_dict(model, observer_dict)

                self.assertTrue(observer_key in observer_dict,
                                'observer is not recorded in the dict')
                recorder = observer_dict[observer_key]
                self.assertEqual(len(recorder.get_tensor_value()),
                                 2 * len(self.calib_data))
                # The recorded value is fetched before the model is re-run,
                # matching left-to-right argument evaluation.
                self.assertEqual(recorder.get_tensor_value()[0],
                                 model(self.calib_data[0][0]))
コード例 #6
0
    def test_compare_weights_linear_static(self):
        r"""Compare the weights of a float linear layer with those of its
        static quantized counterpart.

        ``compare_weights`` on the two state dicts must report exactly one
        entry whose float and quantized weights have the same shape.
        """
        engine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model):
            float_state = float_model.state_dict()
            quant_state = q_model.state_dict()
            weights = compare_weights(float_state, quant_state)
            self.assertEqual(len(weights), 1)
            for entry in weights.values():
                float_shape = entry["float"].shape
                quant_shape = entry["quantized"].shape
                self.assertTrue(float_shape == quant_shape)

        for model in [AnnotatedSingleLayerLinearModel(engine)]:
            model.eval()
            # Fuse only when the model exposes a fuse_model hook.
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            quantized_model = quantize(
                model, test_only_eval_fn, [self.calib_data])
            compare_and_validate_results(model, quantized_model)
コード例 #7
0
    def test_compare_weights_linear_static(self):
        r"""Compare the weights of a float linear layer with those of its
        static quantized counterpart.

        The model is quantized with randomly generated calibration data;
        ``compare_weights`` on the two state dicts must report exactly one
        entry whose float and quantized weights have the same shape.
        """
        engine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model):
            float_state = float_model.state_dict()
            quant_state = q_model.state_dict()
            weights = compare_weights(float_state, quant_state)
            self.assertEqual(len(weights), 1)
            for entry in weights.values():
                float_shape = entry["float"].shape
                quant_shape = entry["quantized"].shape
                self.assertTrue(float_shape == quant_shape)

        # Two (input, target) pairs of random data used for calibration.
        img_data = []
        for _ in range(2):
            inputs = torch.rand(3, 5, dtype=torch.float)
            targets = torch.randint(0, 1, (2, ), dtype=torch.long)
            img_data.append((inputs, targets))

        for model in [AnnotatedSingleLayerLinearModel(engine)]:
            model.eval()
            # Fuse only when the model exposes a fuse_model hook.
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            quantized_model = quantize(model, default_eval_fn, img_data)
            compare_and_validate_results(model, quantized_model)