Example 1
    def test_clamp(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu6 = torch.nn.ReLU6()
                self.relu6_ = torch.nn.ReLU6(True)
                self.hardtanh = torch.nn.Hardtanh()
                self.hardtanh_ = torch.nn.Hardtanh(inplace=True)

            def forward(self, x):
                x = self.conv(x)
                x = self.relu6(x)
                self.relu6_(x)
                x = F.relu6(x)
                x = torch.clamp(x, -3, 3)
                x = x.clamp(-2.5, 2.5)
                # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
                x = self.hardtanh(x)
                self.hardtanh_(x)
                x = F.hardtanh(x)
                F.hardtanh_(x)
                return x

        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
        # list of nodes that should occur in order
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_function(F.hardtanh_),
            ns.call_method('dequantize')
        ]
        for quant_type in self.static_quant_types:
            m = self.checkGraphModeFxOp(
                M(), data, quant_type, expected_node_list=node_list)
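Note: checkGraphModeFxOp is a shared test harness that is not shown in these examples. A minimal sketch of the prepare/calibrate/convert flow it wraps, written as a standalone function with illustrative names (not the harness's actual implementation); depending on the PyTorch release, prepare_fx may also require example_inputs:

# Sketch of the flow a helper like checkGraphModeFxOp wraps (assumed, not the
# actual harness): prepare with observers, calibrate, convert, then inspect nodes.
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

def quantize_static_fx(model, calibration_data):
    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    prepared = prepare_fx(model.eval(), qconfig_dict)  # insert observers
    prepared(*calibration_data)                        # calibration pass
    quantized = convert_fx(prepared)                   # swap to quantized modules/ops
    return quantized                                   # node lists/occurrences are checked on this graph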
Example 2
    def test_qconfig_dict_object_type_function_global_none(self):
        """
        Verifies that the 'object_type' option of qconfig_dict works
        on function types when global qconfig is None.
        """
        class M(nn.Module):
            def forward(self, x):
                x = x + x
                return x

        m = M()
        qconfig_dict = {
            '': None,
            'object_type': [
                (torch.add, torch.quantization.default_qconfig),
            ],
        }
        example_args = (torch.randn(1, 1, 1, 1),)
        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
        mp(*example_args)
        mq = _quantize_dbr.convert(mp)
        mq(*example_args)
        rewritten = mq.rewrite_for_scripting()
        expected_occurrence = {
            NodeSpec.call_function(torch.add): 0,
            NodeSpec.call_function(toq.add): 1,
        }
        self.checkGraphModuleNodes(
            rewritten, expected_node_occurrence=expected_occurrence)
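The qconfig_dict shape used above (a global '' entry set to None plus an 'object_type' list) is the same one accepted by FX graph mode quantization. A hedged sketch of an FX equivalent, with an illustrative module `N` that calls torch.add explicitly (not taken from the test):

# Assumed FX-graph-mode counterpart of the object_type-only qconfig_dict above.
import torch
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class N(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, x)

qconfig_dict = {
    "": None,                                       # no global qconfig
    "object_type": [(torch.add, default_qconfig)],  # only torch.add calls get quantized
}
example_args = (torch.randn(1, 1, 1, 1),)
# Older releases take prepare_fx(model, qconfig_dict); newer ones also want example_inputs.
mp = prepare_fx(N().eval(), qconfig_dict)
mp(*example_args)
mq = convert_fx(mp)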
Example 3
    def test_standalone_module_float_interface(self):
        float_interface_config = {
            "input_quantized_idxs": [],  # float input
            "output_quantized_idxs": [],  # float output
        }
        interface_config = float_interface_config
        # input and output of the first conv; observers for the standalone module
        # will be inserted in the standalone module itself
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 2
        }
        # for the input and output of the conv in the standalone module
        standalone_prepare_count_check = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 2
        }
        convert_count_check = {
            # input and output of the reference conv
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            ns.call_method("dequantize"): 2,
        }
        standalone_convert_count_check = {
            # the standalone module takes float input and produces float output,
            # so we'll see quantize and dequantize inside the module itself
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            ns.call_method("dequantize"): 2,
        }
        self._test_standalone_module(
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check)
Example 4
    def test_addmm(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(5, 5)
                self.bias = torch.randn(5)

            def forward(self, x):
                return torch.addmm(self.bias, x, self.weight)

        m = M().eval()
        prepared = prepare_fx(
            m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # weight
            ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
            # activation
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
        quantized = _convert_fx_do_not_use(
            prepared, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # input activation, output activation and weight
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_function(torch.addmm): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
Example 5
    def test_functional(self):
        """ Test quantizing functional conv and linear
        """
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight):
                return F.conv2d(x, weight, None, self.stride, self.padding,
                                self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight):
                return F.linear(x, weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        tests = [
            (False, Conv, (conv_input, conv_weight),
             ns.call_function(torch.ops.quantized.conv2d)),
            (True, Linear, (linear_input, linear_weight),
             ns.call_function(torch.ops.quantized.linear_dynamic)),
            (False, Linear, (linear_input, linear_weight),
             ns.call_function(torch.ops.quantized.linear)),
            (True, LinearModule, (linear_module_input, ),
             ns.call_module(nnqd.Linear)),
            (False, LinearModule, (linear_module_input, ),
             ns.call_module(nnq.Linear)),
        ]

        for is_dynamic, M, inputs, quantized_node in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            self.checkGraphModeFxOp(M(), inputs, quant_type, quantized_node)
Example 6
    def test_qconfig_none(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        m = symbolic_trace(m)
        qconfig_dict = {'': default_qconfig, 'conv2': None}
        m = prepare_static_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        m(data)
        m = convert_static_fx(m)
        m(data)
        # first conv is quantized, second conv is not quantized
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
            ns.call_module(nn.Conv2d),
        ]
        self.checkGraphModuleNodes(m, expected_node_list=node_list)
Example 7
    def _test_activation_impl(self, float_module, float_op, quantized_module,
                              quantized_op):
        ''' Test for an activation op (with inplace options); float_op can be
        a torch op or a functional op
        '''
        class M(torch.nn.Module):
            def __init__(self, is_module, inplace):
                super(M, self).__init__()
                self.is_module = is_module
                self.inplace = inplace
                if self.is_module:
                    self.op = float_module(self.inplace)
                else:
                    self.op = float_op

            def forward(self, input):
                if self.is_module:
                    return self.op(input)
                else:
                    return self.op(input, self.inplace)

        options = itertools.product([True, False], [True, False],
                                    self.static_quant_types)
        quantized_nodes = {
            # is_module
            True: ns.call_module(quantized_module),
            False: ns.call_function(quantized_op),
        }

        for is_module, is_inplace, quant_type in options:
            self.checkGraphModeFxOp(M(is_module, is_inplace), self.img_data_2d,
                                    quant_type, quantized_nodes[is_module])
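For reference, a caller of this helper (an assumed example, not part of the snippet) might exercise Hardswish like so, pairing the float module and functional with their quantized counterparts:

    # Hypothetical caller inside the same TestCase class; ops assumed available as
    # torch.nn.Hardswish / F.hardswish / torch.nn.quantized.Hardswish / quantized.hardswish.
    def test_hardswish(self):
        self._test_activation_impl(
            torch.nn.Hardswish, F.hardswish,
            torch.nn.quantized.Hardswish, torch.ops.quantized.hardswish)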
Example 8
    def test_linear(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
            weight=torch.quantization.default_weight_observer)
        prepared = prepare_fx(
            m, {"": qconfig},
            backend_config_dict=get_tensorrt_backend_config_dict())
        # calibration
        prepared(linear_module_input)
        quantized = convert_fx(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
        # lower to trt
        trt_mod = lower_to_trt(quantized, linear_module_input,
                               [((1, *linear_module_input.shape[1:]),
                                 (5, *linear_module_input.shape[1:]),
                                 (10, *linear_module_input.shape[1:]))])
        # make sure it runs
        trt_mod(linear_module_input.cuda())
Example 9
    def test_unsupported_qconfig(self):
        """ Check that we won't quantize the model if the qconfig is not supported
        """
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        trt_unsupported_qconfig = default_qconfig
        prepared = prepare_fx(m, {"": trt_unsupported_qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        # calibration
        prepared(linear_module_input)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 0,
            ns.call_method("dequantize"): 0,
            ns.call_module(torch.nn.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Linear): 0,
        }
        # check model is not quantized
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
Example 10
    def test_ops(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.linear = torch.nn.Linear(5, 5)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.linear(x)
                x = x + 3
                x = self.relu(x)
                x = x + 6
                return x

        m = M().eval()
        m = prepare_fx(m, {"": default_qconfig})
        m = _convert_fx_do_not_use(m, is_reference=True)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 5,
            ns.call_method("dequantize"): 5,
            ns.call_module(torch.nn.quantized._reference.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
        }
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=expected_occurrence)
Example 11
    def _test_norm_impl(
            self, float_module, float_op, op_args, data, quantized_module, quantized_op,
            skip_op_arg_for_functional=False):
        ''' Test for a normalization op; float_op can be a torch op or a functional op,
        and op_args is a list of positional arguments for the module/op
        '''
        class M(torch.nn.Module):
            def __init__(self, is_module):
                super(M, self).__init__()
                self.is_module = is_module
                if self.is_module:
                    self.op = float_module(*op_args)
                else:
                    self.op = float_op

            def forward(self, input):
                if self.is_module:
                    return self.op(input)
                else:
                    args = [input]
                    if not skip_op_arg_for_functional:
                        args += op_args
                    return self.op(*args)

        options = itertools.product([True, False], self.static_quant_types)
        quantized_nodes = {
            # is_module
            True: ns.call_module(quantized_module),
            False: ns.call_function(quantized_op),
        }

        for is_module, quant_type in options:
            self.checkGraphModeFxOp(
                M(is_module), data, quant_type, quantized_nodes[is_module])
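As an assumed usage example (not shown in the snippet), a LayerNorm test could drive this helper as follows:

    # Hypothetical caller: LayerNorm over the trailing (2, 5, 5) dims of the input.
    def test_layer_norm(self):
        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
        self._test_norm_impl(
            torch.nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data,
            nnq.LayerNorm, torch.ops.quantized.layer_norm)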
Example 12
    def test_quantized_cat(self):
        """ quantization of the output of cat will be depend on the
        input of cat. we only quantize the output of cat when its inputs are quantized.
        """
        class QuantizedCat(torch.nn.Module):
            def __init__(self):
                super(QuantizedCat, self).__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                return torch.cat([x, y], 1)

        # TODO: decide whether to quantize in this case
        # class NonQuantizedCat(torch.nn.Module):
        #     def __init__(self):
        #         super(NonQuantizedCat, self).__init__()

        #     def forward(self, x, y):
        #         return torch.cat([x, y], 1)

        data = (torch.randn(1, 2, 5, 5, dtype=torch.float),
                torch.randn(1, 2, 5, 5, dtype=torch.float))
        quantized_node = ns.call_function(torch.ops.quantized.cat)
        for quant_type in self.static_quant_types:
            self.checkGraphModeFxOp(QuantizedCat(), data, quant_type,
                                    quantized_node)
Example 13
    def test_conv(self):
        class Conv2d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv2d(*args)

            def forward(self, x):
                return self.conv(x)

        conv2d_input = torch.rand(1, 3, 224, 224)
        conv2d_module_args = (3, 3, 3)

        m = Conv2d(*conv2d_module_args).eval()
        qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
            ),
            weight=torch.quantization.default_weight_observer
        )
        prepared = prepare_fx(m, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
        # calibration
        prepared(conv2d_input)
        quantized = convert_fx(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
        # lower to trt
        trt_mod = lower_to_trt(quantized, conv2d_input, [((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))])
        # make sure it runs
        trt_mod(conv2d_input.cuda())
Example 14
    def test_linear_relu_module(self):
        class LinearModule(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.linear(x))

        linear_input = torch.rand(8, 5)

        shape_ranges = [((1, 5), (5, 5), (10, 5))]
        for has_relu, f_relu in itertools.product([True, False],
                                                  [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # an extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor):
                2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
            }
            self._test_module(LinearModule(has_relu, f_relu), [linear_input],
                              shape_ranges,
                              no_convert=no_convert)
Example 15
    def _test_quantized_binary_op_impl(self, binary_op, ibinary_op, quantized_op):
        class Op(torch.nn.Module):
            def __init__(self, is_inplace, is_scalar):
                super(Op, self).__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
                self.is_scalar = is_scalar
                self.op = ibinary_op if is_inplace else binary_op

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                x = self.op(x, y)
                return x

        # TODO: decide whether we want to quantize or not
        # in this case
        # class NonQuantizedOp(torch.nn.Module):
        #     def __init__(self, is_inplace, is_scalar):
        #         super(NonQuantizedOp, self).__init__()
        #         self.is_scalar = is_scalar
        #         self.op = ibinary_op if is_inplace else binary_op

        #     def forward(self, x, y):
        #         y = 3 if self.is_scalar else y
        #         x = self.op(x, y)
        #         return x

        data = (torch.randn(1, 2, 3, 3, dtype=torch.float),
                torch.randn(1, 2, 3, 3, dtype=torch.float))
        quantized_node = ns.call_function(quantized_op)
        options = itertools.product([True, False], [True, False], self.static_quant_types)
        for is_inplace, is_scalar, quant_type in options:
            self.checkGraphModeFxOp(
                Op(is_inplace, is_scalar), data, quant_type, quantized_node)
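A typical caller (assumed, mirroring how the other parameterized helpers in these examples are used) would bind the float, inplace, and quantized variants of one binary op:

    # Hypothetical callers; `operator` would need to be imported at module level.
    def test_quantized_add(self):
        self._test_quantized_binary_op_impl(
            operator.add, operator.iadd, torch.ops.quantized.add)

    def test_quantized_mul(self):
        self._test_quantized_binary_op_impl(
            operator.mul, operator.imul, torch.ops.quantized.mul)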
Example 16
    def test_add_shadow_loggers_multiple_dtype_casts(self):
        """
        Verifies that for nodes where the first input arg is a list,
        such as `cat`, we insert an individual dtype cast for each
        arg of the list.
        """
        class M(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.cat([x, x, x], dim=0)
                return x

        m = M().eval()
        expected_occurrence = {
            # 3 dequantize function calls from the 3 dtype casts for [x, x, x]
            ns.call_function(torch.dequantize):
            3,
            # 1 dequantize method call for module output
            ns.call_method("dequantize"):
            1,
        }
        self._test_match_shadow_activations(
            m, (torch.randn(4, 4), ),
            prepared_expected_node_occurrence=expected_occurrence,
            results_len=1)
Example 17
    def _test_quantized_binary_op_relu_impl(self, binary_op, ibinary_op,
                                            quantized_op):
        class OpRelu(torch.nn.Module):
            def __init__(self, is_inplace, is_functional_relu, is_scalar):
                super(OpRelu, self).__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
                self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
                self.op = ibinary_op if is_inplace else binary_op
                self.is_functional_relu = is_functional_relu
                self.is_scalar = is_scalar
                self.relu = F.relu if self.is_functional_relu \
                    else torch.nn.ReLU()

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                x = self.op(x, y)
                x = self.relu(x)
                return x

        data = (torch.rand((1, 1, 1, 1), dtype=torch.float),
                torch.rand((1, 1, 1, 1), dtype=torch.float))
        quant_type = QuantType.STATIC
        quantized_node = ns.call_function(quantized_op)
        options = itertools.product([True, False], [True, False],
                                    [True, False])
        for is_inplace_op, is_functional_relu, is_scalar in options:
            self.checkGraphModeFxOp(
                OpRelu(is_inplace_op, is_functional_relu, is_scalar), data,
                quant_type, quantized_node)
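The fused op-plus-ReLU variant could be driven the same way (again an assumed caller, not part of the snippet):

    # Hypothetical caller for the fused add + ReLU pattern.
    def test_quantized_add_relu(self):
        self._test_quantized_binary_op_relu_impl(
            operator.add, operator.iadd, torch.ops.quantized.add_relu)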
Example 18
    def test_dynamic_quant_fp16(self):
        class Linear(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)

            def forward(self, x):
                return F.linear(x, self.weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        tests = [
            (Linear, (linear_weight,), (linear_input,),
             ns.call_function(torch.ops.quantized.linear_dynamic),
             ns.call_function(torch.ops.quantized.linear_prepack_fp16)),
            (LinearModule, (), (linear_module_input,),
             ns.call_module(nnqd.Linear),
             None),
        ]
        for (ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in tests:
            for debug in [True, False]:
                node_occurrence = dict()
                if weight_prepack_node:
                    if debug:
                        node_occurrence[weight_prepack_node] = 1
                    else:
                        node_occurrence[weight_prepack_node] = 0
                m = ModuleClass(*module_constructor_inputs).eval()
                m = symbolic_trace(m)
                qconfig_dict = {"": float16_dynamic_qconfig}
                m = quantize_dynamic_fx(m, qconfig_dict, debug=debug)
                self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
Example 19
    def test_cat(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.cat([x, x], 1)

        m = M().eval()
        prepared = prepare_fx(m, {"": self.qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        self.assertTrue(len(dict(prepared.named_children())) == 1)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_function(torch.cat): 1,
            ns.call_method("dequantize"): 2,
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
Example 20
    def test_fp32_input_fp32_output(self):
        prepare_custom_config_dict = {}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
        }
        convert_count_check = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_method('dequantize'): 3,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)
Example 21
    def test_conv_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x, y):
                return self.conv(x) + y

        weighted_op_qint8_dtype_config = {
            # optional, input activation dtype
            "input_dtype": torch.qint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.qint8
        }

        conv_add_config = {
            "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode),
            "observation_type":
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_qint8_dtype_config,
            ],
            "root_module":
            torch.nn.Conv2d,
            "reference_quantized_module_for_root":
            torch.nn.quantized._reference.Conv2d,
        }

        m = M().eval()
        modified_backend_config_dict = copy.deepcopy(
            self.trt_backend_config_dict)
        modified_backend_config_dict["configs"].insert(0, conv_add_config)
        m = prepare_fx(m, {"": self.qconfig},
                       backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        m = _convert_fx_do_not_use(
            m,
            is_reference=True,
            backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
Example 22
    def _get_conv_linear_test_cases(self):
        ''' Returns a list of test cases, with format:
        is_dynamic, ModuleClass, module_constructor_inputs,
        inputs, quantized_node, weight_prepack_op
        '''
        class Conv(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x):
                return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)

            def forward(self, x):
                return F.linear(x, self.weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        tests = [
            (False, Conv, (conv_weight,), (conv_input,),
             ns.call_function(torch.ops.quantized.conv2d),
             ns.call_function(torch.ops.quantized.conv2d_prepack)),
            (True, Linear, (linear_weight,), (linear_input,),
             ns.call_function(torch.ops.quantized.linear_dynamic),
             ns.call_function(torch.ops.quantized.linear_prepack)),
            (False, Linear, (linear_weight,), (linear_input,),
             ns.call_function(torch.ops.quantized.linear),
             ns.call_function(torch.ops.quantized.linear_prepack)),
            (True, LinearModule, (), (linear_module_input,),
             ns.call_module(nnqd.Linear),
             None),
            (False, LinearModule, (), (linear_module_input,),
             ns.call_module(nnq.Linear),
             None),
        ]
        return tests
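These tuples are consumed by a driver loop elsewhere in the suite; a sketch of such a driver, modeled on the loops in Examples 5 and 18 (the method name and the omission of prepack-node checks are assumptions):

    # Illustrative driver for the returned test cases.
    def test_conv_linear(self):
        for (is_dynamic, ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in self._get_conv_linear_test_cases():
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            self.checkGraphModeFxOp(
                ModuleClass(*module_constructor_inputs), inputs, quant_type,
                quantized_node)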
Example 23
    def test_quantized_input_fp32_output(self):
        prepare_custom_config_dict = {
            'input_quantized_idxs': [0]}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
        }
        convert_count_check = {
            # output of conv1, conv2
            ns.call_function(torch.quantize_per_tensor): 2,
            # input of ref conv1, input of ref conv2, final output
            ns.call_method('dequantize'): 3,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)
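Examples 20 and 23 both lean on _test_quantized_inputs_outputs, which is not shown. A rough, assumed sketch of what such a helper does with the custom config (the model and the count assertions are placeholders; in newer releases the keyword is prepare_custom_config rather than prepare_custom_config_dict):

# Assumed sketch of a helper like _test_quantized_inputs_outputs.
import torch
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

def run_quantized_io_check(model, prepare_custom_config_dict,
                           prepare_count_check, convert_count_check):
    qconfig_dict = {"": default_qconfig}
    prepared = prepare_fx(
        model.eval(), qconfig_dict,
        prepare_custom_config_dict=prepare_custom_config_dict)
    # observer counts would be asserted here against prepare_count_check
    quantized = convert_fx(prepared)
    # quantize/dequantize counts would be asserted here against convert_count_check
    return quantized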
Example 24
    def test_conv_relu_module(self):
        conv_module = {
            1: torch.nn.Conv1d,
            2: torch.nn.Conv2d,
            3: torch.nn.Conv3d
        }

        conv1d_input = torch.rand(1, 3, 10)
        conv2d_input = torch.rand(1, 3, 10, 10)
        conv3d_input = torch.rand(1, 3, 10, 10, 10)
        conv_input = {1: conv1d_input, 2: conv2d_input, 3: conv3d_input}

        class ConvNdModule(torch.nn.Module):
            def __init__(self, dim, has_relu=False, f_relu=False):
                super().__init__()
                self.conv = conv_module[dim](3, 3, 3).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.conv(x))

        # just testing conv2d since conv1d and conv3d are not supported in fx2trt
        for dim, has_relu, f_relu, is_qat in itertools.product([2],
                                                               [True, False],
                                                               [True, False],
                                                               [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # an extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor):
                2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
            }
            self._test_module(ConvNdModule(dim, has_relu,
                                           f_relu), [conv_input[dim]],
                              [((1, *conv_input[dim].shape[1:]),
                                (5, *conv_input[dim].shape[1:]),
                                (10, *conv_input[dim].shape[1:]))],
                              no_convert=no_convert,
                              is_qat=is_qat)
Example 25
    def test_conv(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        conv2d_input = torch.rand(1, 3, 224, 224)
        no_convert = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2
        }
        self._test_module(Conv2dModule(), [conv2d_input],
                          [((1, 3, 224, 224), (5, 3, 224, 224),
                            (10, 3, 224, 224))],
                          no_convert=no_convert)
Example 26
    def test_linear(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_input = torch.rand(8, 5)

        shape_ranges = [((1, 5), (5, 5), (10, 5))]
        no_convert = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2,
        }
        self._test_module(LinearModule(), [linear_input],
                          shape_ranges,
                          no_convert=no_convert)
Example 27
    def test_selective_equalization(self):
        """ Tests that we are able to run numeric suite on the equalized model
        and construct a valid equalization_qconfig_dict equalizing only the top
        4 layers with the highest quantization errors.
        """

        torch.manual_seed(1)

        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
                self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

            def forward(self, x):
                x = self.bot(x)
                x = torch.add(x, 5)
                x = self.top(x)
                return x

        float_model = M().eval()
        # Hard coded so that the top layer has a higher quantization error
        x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                          [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                          [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                          [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                          [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

        # Quantize the float model
        prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
        prepared_model(x)
        quantized_model = convert_fx(copy.deepcopy(prepared_model))

        # Get the SQNR between the float and quantized model
        layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model), quantized_model, x)

        # Construct the equalization_qconfig_dict equalizing layers with the highest
        # quantization errors
        selective_equalization_qconfig_dict = get_equalization_qconfig_dict(layer_to_sqnr_dict, 1)

        # Create the selectively equalized model
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            equalization_qconfig_dict=selective_equalization_qconfig_dict,
        )
        prepared_model(x)
        equalized_model = convert_fx(prepared_model)

        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
Example 28
    def test_input_weight_equalization_graphs(self):
        """ Tests that the modified model for equalization has the same graph
        structure as the model without equalization (before and after
        quantization).
        """

        linear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        linearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_method('dequantize')
        ]

        linearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_method('dequantize')
        ]

        functionalLinearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        conv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        conv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        convRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_method('dequantize')
        ]

        convReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConvRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_method('dequantize')
        ]

        functionalConvReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        tests = [(SingleLayerLinearModel, linear_node_list),
                 (LinearAddModel, linearAdd_node_list),
                 (TwoLayerLinearModel, linear2_node_list),
                 (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
                 (FunctionalLinearAddModel, functionalLinearAdd_node_list),
                 (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
                 (LinearReluModel, linearRelu_node_list),
                 (LinearReluLinearModel, linearReluLinear_node_list),
                 (FunctionalLinearReluModel, functionalLinearRelu_node_list),
                 (FunctionalLinearReluLinearModel, functionalLinearReluLinear_node_list),
                 (ConvModel, conv_node_list),
                 (TwoLayerConvModel, conv2_node_list),
                 (SingleLayerFunctionalConvModel, functionalConv_node_list),
                 (TwoLayerFunctionalConvModel, functionalConv2_node_list),
                 (ConvReluModel, convRelu_node_list),
                 (ConvReluConvModel, convReluConv_node_list),
                 (FunctionalConvReluModel, functionalConvRelu_node_list),
                 (FunctionalConvReluConvModel, functionalConvReluConv_node_list)]

        for (M, node_list) in tests:
            m = M().eval()
            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
            equalized_quantized_model = convert_fx(prepared)

            # Check the order of nodes in the graph
            self.checkGraphModuleNodes(equalized_quantized_model, expected_node_list=node_list)
Example 29
    def test_general_value_ops(self):
        """ A test that checks correct patterns are produced for
        all supported general value ops like aten::avg_pool2d \
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.avg_pool1d = torch.nn.AvgPool1d(3)
                self.avg_pool2d = torch.nn.AvgPool2d(3)
                self.avg_pool3d = torch.nn.AvgPool3d(3)
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d(
                    (1, 1, 1))
                self.leaky_relu = torch.nn.LeakyReLU()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.avg_pool1d(x)
                x = self.avg_pool2d(x)
                x = self.avg_pool3d(x)
                x = self.adaptive_avg_pool1d(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.adaptive_avg_pool3d(x)
                x = F.avg_pool1d(x, 3)
                x = F.avg_pool2d(x, 3)
                x = F.avg_pool3d(x, 3)
                x = F.adaptive_avg_pool1d(x, (1))
                x = F.adaptive_avg_pool2d(x, (1, 1))
                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
                x = torch.mean(x)
                x = torch.mean(x, [2, 3], False)
                x = x.mean()
                x = x.mean([2, 3], True)
                x = F.interpolate(x, 4, mode='nearest')
                x = F.interpolate(x, 4, mode='linear')
                x = self.leaky_relu(x)
                x = F.leaky_relu(x)
                x = F.leaky_relu(x, inplace=True)
                x = x.leaky_relu()
                x.leaky_relu_()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x = F.hardsigmoid(x, inplace=True)
                x = x.hardsigmoid()
                x.hardsigmoid_()
                x = self.sigmoid(x)
                x = torch.sigmoid(x)
                # F.sigmoid is deprecated
                x = x.sigmoid()
                x.sigmoid_()
                x = self.tanh(x)
                # F.tanh is deprecated
                x = torch.tanh(x)
                x = x.tanh()
                x.tanh_()
                x = self.conv(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(original, qconfig_dict)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)
Example 30
    def test_general_shape_ops(self):
        """ A test that checks dequantize will be swapped for
        all supported general shape ops like aten::flatten
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
                self.dropout = torch.nn.Dropout()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                # add_scalar
                x = x + 3
                # mul_scalar
                x = x * 3
                # add_scalar_out
                x += 3
                # mul_scalar_out
                x *= 3
                # add_scalar_relu
                x = x + 3
                x = F.relu(x)
                # add_scalar_relu_out
                x += 3
                x = F.relu(x)
                # mul_scalar_relu
                x = x * 3
                x = F.relu(x)
                # mul_scalar_relu_out
                x *= 3
                x = F.relu(x)
                x = self.maxpool1d(x)
                x = self.maxpool2d(x)
                x = self.maxpool3d(x)
                x = torch.flatten(x)
                x = torch.max(x)
                x = torch.min(x)
                x = x.reshape([-1])
                x = x.resize_(1, 1, x.numel())
                x = x.view(-1)
                # prim::ListConstruct
                xs = [x, x]
                # prim::ListUnpack
                x, y = xs
                # prim::TupleConstruct
                xs = (x, x)
                # prim::TupleUnpack
                x, y = xs
                x = x.transpose(1, 2)
                x = x.contiguous()
                x, y = torch.chunk(x, 2)
                x = F.dropout(x)
                x = self.dropout(x)
                x, _ = torch.sort(x)
                x = x.permute(0, 2, 3, 1)
                x = x.repeat_interleave(3, 1)
                x = torch.repeat_interleave(x, 3, 1)
                x = self.relu(x)
                x = F.relu(x)
                x = F.relu(x, inplace=True)
                x = x.relu()
                x.relu_()
                x = x.squeeze(0)
                x.squeeze_(0)
                x = torch.squeeze(x, 0)
                x = x.unsqueeze(0)
                x.unsqueeze_(0)
                x = torch.unsqueeze(x, 0)
                x = x.detach()
                x.detach_()
                x = x.repeat(4, 2)
                y = []
                y.append(x)
                z = torch.stack(y, 0)
                z = [z, z]
                x, _ = z
                x = self.conv2(x)
                return x

        data = torch.rand(1, 3, 10, 10)
        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(original, qconfig_dict)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers and also successfully fused two quantized::conv2d
        # patterns
        # one quantize_per_tensor for input
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)