Example 1
    def test_ops(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.linear = torch.nn.Linear(5, 5)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.linear(x)
                x = x + 3
                x = self.relu(x)
                x = x + 6
                return x

        m = M().eval()
        m = prepare_fx(m, {"": default_qconfig})
        m = _convert_fx_do_not_use(m, is_reference=True)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 5,
            ns.call_method("dequantize"): 5,
            ns.call_module(torch.nn.quantized._reference.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
        }
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=expected_occurrence)
Example 2
    def _test_module(self,
                     m,
                     inputs,
                     shape_ranges,
                     no_prepare=None,
                     no_convert=None):
        """
        Args:
          m: the float module we want to test
          inputs: list of inputs for the module
          shape_ranges: a list of shape_range, where each shape_range is a tuple of
          three tuples ((min_input_shape), (optimized_input_shape), (max_input_shape)).
          Each shape_range is used to populate a TensorRT optimization profile.
          e.g. if the input shape varies from (1, 224) to (100, 224) and we want to
          optimize for (25, 224) because it's the most common input shape, then
          we set shape_ranges to [((1, 224), (25, 224), (100, 224))]
          no_prepare: expected node occurrence after prepare
          no_convert: expected node occurrence after convert
        """
        m = m.eval()
        prepared = prepare_fx(m, {"": self.qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        self.checkGraphModuleNodes(prepared,
                                   expected_node_occurrence=no_prepare)
        # calibration
        prepared(*inputs)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=no_convert)
        # lower to trt
        trt_mod = lower_to_trt(quantized, inputs, shape_ranges)
        inputs_cuda = [i.cuda() for i in inputs]
        # make sure it runs
        trt_mod(*inputs_cuda)
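For orientation, here is a hedged sketch of how a test might call _test_module; the module, input shapes, and node counts below are illustrative assumptions rather than values taken from the original suite.

    # Hypothetical caller of _test_module (illustrative module, shapes, and
    # node counts; not taken from the original test suite).
    def test_linear_module_trt(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(224, 224)

            def forward(self, x):
                return self.linear(x)

        # batch size varies from 1 to 100; optimize the TensorRT profile for 25
        shape_ranges = [((1, 224), (25, 224), (100, 224))]
        self._test_module(
            LinearModule(),
            inputs=[torch.randn(25, 224)],
            shape_ranges=shape_ranges,
            no_prepare={
                # assumed: one observer for the input and one for the output
                ns.call_module(torch.ao.quantization.HistogramObserver): 2,
            },
            no_convert={
                # assumed: dequantize before and after the reference linear
                ns.call_method("dequantize"): 2,
            })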
Example 3
    def test_unsupported_qconfig(self):
        """ Check that we won't quantize the model if the qconfig is not supported
        """
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        trt_unsupported_qconfig = default_qconfig
        prepared = prepare_fx(m, {"": trt_unsupported_qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        # calibration
        prepared(linear_module_input)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 0,
            ns.call_method("dequantize"): 0,
            ns.call_module(torch.nn.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Linear): 0,
        }
        # check model is not quantized
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
Example 4
    def _test_quantized_inputs_outputs(
            self, prepare_custom_config_dict, prepare_count_check,
            convert_count_check):
        """
        Test the option to have inputs and outputs of the graph quantized
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        # quantized input, quantized output
        m = M()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        m.eval()
        mp = torch.ao.quantization.quantize_fx.prepare_fx(
            m, qconfig_dict,
            prepare_custom_config_dict=prepare_custom_config_dict)
        self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check)
        mp(torch.randn(1, 1, 4, 4))
        mq = _convert_fx_do_not_use(
            mp, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check)
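A hedged sketch of how this helper might be exercised with quantized graph inputs and outputs; the observer and node counts are illustrative assumptions, since the exact numbers depend on the qconfig and backend config.

    # Hypothetical caller of _test_quantized_inputs_outputs; counts are
    # illustrative assumptions, not taken from the original suite.
    def test_quantized_input_quantized_output(self):
        prepare_custom_config_dict = {
            "input_quantized_idxs": [0], "output_quantized_idxs": [0]}
        prepare_count_check = {
            # assumed: one activation observer after conv1 and one after conv2;
            # no observer at the graph input since it arrives already quantized
            ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
        }
        convert_count_check = {
            # assumed: no quantize at the graph input and no dequantize at the
            # graph output, only the ones around the two reference convs
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)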
Example 5
    def test_addmm(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(5, 5)
                self.bias = torch.randn(5)

            def forward(self, x):
                return torch.addmm(self.bias, x, self.weight)

        m = M().eval()
        prepared = prepare_fx(
            m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # weight
            ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
            # activation
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
        quantized = _convert_fx_do_not_use(
            prepared, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # input activation, output activation and weight
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_function(torch.addmm): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
Example 6
    def test_conv_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x, y):
                return self.conv(x) + y

        weighted_op_qint8_dtype_config = {
            # optional, input activation dtype
            "input_dtype": torch.qint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.qint8
        }

        conv_add_config = {
            "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode),
            "observation_type":
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_qint8_dtype_config,
            ],
            "root_module":
            torch.nn.Conv2d,
            "reference_quantized_module_for_root":
            torch.nn.quantized._reference.Conv2d,
        }
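        # Informal reading of the pattern above (FX quantization patterns list
        # the outermost op first): (operator.add, torch.nn.Conv2d, MatchAllNode)
        # matches an add node with one input produced by a Conv2d module and
        # the other input unconstrained, i.e. `self.conv(x) + y` in M.forward.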

        m = M().eval()
        modified_backend_config_dict = copy.deepcopy(
            self.trt_backend_config_dict)
        modified_backend_config_dict["configs"].insert(0, conv_add_config)
        m = prepare_fx(m, {"": self.qconfig},
                       backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        m = _convert_fx_do_not_use(
            m,
            is_reference=True,
            backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
Example 7
    def test_cat(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.cat([x, x], 1)

        m = M().eval()
        prepared = prepare_fx(m, {"": self.qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        self.assertTrue(len(dict(prepared.named_children())) == 1)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_function(torch.cat): 1,
            ns.call_method("dequantize"): 2,
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
Example 8
    def test_conv_add_standalone_module(self):
        class Standalone(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x, y):
                return self.relu(self.conv(x) + y)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.standalone = Standalone()

            def forward(self, x, y):
                y = self.conv(x)
                return self.standalone(x, y)

        from torch.ao.quantization.fx.backend_config_dict.observation_type import ObservationType
        weighted_op_quint8_dtype_config = {
            # optional, input activation dtype
            # TODO: change back to torch.qint8 after input_quantized_idxs and output_quantized_idxs
            # are more flexible
            "input_dtype": torch.quint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.quint8
        }

        conv_add_config = {
            "pattern": (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode)),
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_quint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }

        conv_config = {
            "pattern": torch.nn.Conv2d,
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_quint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }

        m = M().eval()
        backend_config_dict = {
            "configs": [
                conv_add_config,
                conv_config,
            ]
        }
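        # Each "standalone_module_name" entry below is a 4-tuple; from these
        # tests the format appears to be (module name, qconfig_dict, the
        # submodule's own prepare_custom_config_dict, backend_config_dict),
        # with None meaning "inherit the parent's / use the default". Here both
        # inputs of `standalone` are declared as already quantized.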
        prepare_custom_config_dict = {
            "standalone_module_name": [("standalone", None, {"input_quantized_idxs": [0, 1]}, None)]
        }
        # TODO: use self.qconfig after input_quantized_idxs and output_quantized_idxs
        # are more flexible
        qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
            ),
            weight=torch.ao.quantization.default_weight_observer
        )
        m = prepare_fx(
            m,
            {"": qconfig},
            prepare_custom_config_dict=prepare_custom_config_dict,
            backend_config_dict=backend_config_dict)
        node_occurrence = {
            # for input and output of conv, where input is used twice, once in conv and
            # once in standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            # output of the standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 1,
        }
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
        m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_module(nn.Conv2d): 1,
            ns.call_method("dequantize"): 1,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_module(nn.Conv2d): 1,
            ns.call_module(torch.nn.ReLU): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
Example 9
    def _test_standalone_module(
            self,
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check,
            qconfig=None,
            backend_config_dict=None):
        """ Test standalone module with different quantized input/quantized output
        configurations
        """
        class StandaloneModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv(x)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.standalone = StandaloneModule()

            def forward(self, x):
                x = self.conv(x)
                x = self.standalone(x)
                return x

        class RefM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        if backend_config_dict is None:
            backend_config_dict = self.trt_backend_config_dict
        if qconfig is None:
            qconfig = self.trt_qconfig

        data = torch.randn(1, 1, 1, 1)
        # instantiate M and RefM and align the parameters
        original_m = M().eval()
        original_ref_m = RefM().eval()
        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())

        prepare_config = {
            "standalone_module_name": [("standalone", None, interface_config, backend_config_dict)]
        }

        original_m_copy = copy.deepcopy(original_m)
        original_ref_m_copy = copy.deepcopy(original_ref_m)

        qconfig_dict = {"": qconfig}
        # check prepared model
        m = prepare_fx(
            original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config, backend_config_dict=backend_config_dict)
        # calibration
        m(data)
        self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)

        # check converted/quantized model
        m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
        self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
        res = m(data)

        # quantize the reference model
        ref_m = prepare_fx(original_ref_m_copy, qconfig_dict, backend_config_dict=backend_config_dict)
        ref_m(data)
        ref_m = _convert_fx_do_not_use(ref_m, is_reference=True, backend_config_dict=backend_config_dict)
        ref_res = ref_m(data)
        self.assertEqual(res, ref_res)
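A hedged sketch of a caller that exercises the float input/output interface of the standalone module; the node counts are illustrative assumptions and would need to be adjusted to the actual qconfig and backend config.

    # Hypothetical caller of _test_standalone_module using a float interface
    # for the standalone submodule; all counts are illustrative assumptions.
    def test_standalone_module_float_interface(self):
        interface_config = {
            "input_quantized_idxs": [],   # standalone module takes float input
            "output_quantized_idxs": [],  # and produces float output
        }
        prepare_count_check = {
            # assumed: observers around the outer conv only
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        standalone_prepare_count_check = {
            # assumed: observers around the conv inside the standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        convert_count_check = {
            # assumed: quantize/dequantize around the outer reference conv
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2,
        }
        standalone_convert_count_check = {
            # assumed: the standalone module quantizes and dequantizes
            # internally since its interface is float
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2,
        }
        self._test_standalone_module(
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check)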