Example No. 1
 def test_standalone_module_float_interface(self):
     float_interface_config = {
         "input_quantized_idxs": [],  # float input
         "output_quantized_idxs": [],  # float output
     }
     interface_config = float_interface_config
     # observers for input and output of the first conv; observers for the
     # standalone module will be inserted in the standalone module itself
     prepare_count_check = {
         ns.call_module(torch.ao.quantization.HistogramObserver): 2
     }
     # for input and output of conv in the standalone module
     standalone_prepare_count_check = {
         ns.call_module(torch.ao.quantization.HistogramObserver): 2
     }
     convert_count_check = {
         # input and output of reference conv
         ns.call_function(torch.quantize_per_tensor) : 2,
         ns.call_module(nnqr.Conv2d) : 1,
         ns.call_method("dequantize") : 2,
     }
     standalone_convert_count_check = {
         # standalone module will take float as input and output
         # so we'll see quantize and dequantize in the module
         ns.call_function(torch.quantize_per_tensor) : 2,
         ns.call_module(nnqr.Conv2d): 1,
         ns.call_method("dequantize") : 2,
     }
     self._test_standalone_module(
         interface_config,
         prepare_count_check,
         standalone_prepare_count_check,
         convert_count_check,
         standalone_convert_count_check)
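
These snippets come from the PyTorch FX graph mode quantization test suite, so the imports and aliases they rely on are not shown. Below is a minimal sketch of the assumed setup, following the naming conventions of that suite; the exact module paths are an assumption and vary slightly across PyTorch versions around 1.9/1.10.

    # Assumed imports (a sketch, not copied from the original test files):
    import copy
    import itertools
    import operator

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.nn.quantized as nnq                # quantized modules (nnq.Conv2d, nnq.Linear, ...)
    import torch.nn.quantized._reference as nnqr    # reference quantized modules
    import torch.nn.intrinsic.quantized as nniq     # fused quantized modules (nniq.LinearReLU, ...)

    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx
    from torch.testing._internal.common_quantization import NodeSpec as ns  # node matchers used in the checks
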
Example No. 2
    def test_linear(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
            weight=torch.quantization.default_weight_observer)
        prepared = prepare_fx(
            m, {"": qconfig},
            backend_config_dict=get_tensorrt_backend_config_dict())
        # calibration
        prepared(linear_module_input)
        quantized = convert_fx(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
        # lower to trt
        trt_mod = lower_to_trt(quantized, linear_module_input,
                               [((1, *linear_module_input.shape[1:]),
                                 (5, *linear_module_input.shape[1:]),
                                 (10, *linear_module_input.shape[1:]))])
        # make sure it runs
        trt_mod(linear_module_input.cuda())
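
The shape ranges passed to `lower_to_trt` above look like TensorRT-style optimization profiles of (min, opt, max) input shapes that vary only the batch dimension. A hedged reading of the triple used in this test; the interpretation is an assumption inferred from the calls above, not taken from the original test file.

    # Assumed meaning of each shape_ranges entry: one (min_shape, opt_shape, max_shape)
    # profile per input, here varying only the batch dimension of the (N, 5) input.
    shape_ranges = [(
        (1, 5),    # smallest input shape the engine must accept
        (5, 5),    # shape the engine is optimized for
        (10, 5),   # largest input shape the engine must accept
    )]
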
Example No. 3
    def test_unsupported_qconfig(self):
        """ Check that we won't quantize the model if the qconfig is not supported
        """
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        trt_unsupported_qconfig = default_qconfig
        prepared = prepare_fx(m, {"": trt_unsupported_qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        # calibration
        prepared(linear_module_input)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 0,
            ns.call_method("dequantize"): 0,
            ns.call_module(torch.nn.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Linear): 0,
        }
        # check model is not quantized
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
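
For context on why `default_qconfig` counts as unsupported here: the TensorRT tests elsewhere in this listing all build a symmetric qint8 qconfig, whereas `default_qconfig` observes activations as affine quint8. A hedged sketch of the contrast; the observer details are an assumption based on the qconfigs shown in the surrounding examples.

    # Qconfig used by the supported TensorRT tests in this listing: symmetric qint8 activations.
    trt_qconfig = torch.quantization.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
        weight=torch.quantization.default_weight_observer)

    # default_qconfig, by contrast, uses affine quint8 activations; the qint8-based
    # dtype configs shown in these TensorRT tests would not match it, so
    # prepare/convert leave the linear module unquantized.
    trt_unsupported_qconfig = torch.quantization.default_qconfig
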
Example No. 4
    def test_linear_relu_module(self):
        class LinearModule(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.linear(x))

        linear_input = torch.rand(8, 5)

        shape_ranges = [((1, 5), (5, 5), (10, 5))]
        for has_relu, f_relu in itertools.product([True, False],
                                                  [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # an extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor):
                2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
            }
            self._test_module(LinearModule(has_relu, f_relu), [linear_input],
                              shape_ranges,
                              no_convert=no_convert)
Example No. 5
    def test_conv(self):
        class Conv2d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv2d(*args)

            def forward(self, x):
                return self.conv(x)

        conv2d_input = torch.rand(1, 3, 224, 224)
        conv2d_module_args = (3, 3, 3)

        m = Conv2d(*conv2d_module_args).eval()
        qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
            ),
            weight=torch.quantization.default_weight_observer
        )
        prepared = prepare_fx(m, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
        # calibration
        prepared(conv2d_input)
        quantized = convert_fx(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
        # lower to trt
        trt_mod = lower_to_trt(quantized, conv2d_input, [((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))])
        # make sure it runs
        trt_mod(conv2d_input.cuda())
Example No. 6
    def test_addmm(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(5, 5)
                self.bias = torch.randn(5)

            def forward(self, x):
                return torch.addmm(self.bias, x, self.weight)

        m = M().eval()
        prepared = prepare_fx(
            m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # weight
            ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
            # activation
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
        quantized = _convert_fx_do_not_use(
            prepared, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # input activation, output activation and weight
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_function(torch.addmm): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
Example No. 7
    def test_clamp(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu6 = torch.nn.ReLU6()
                self.relu6_ = torch.nn.ReLU6(True)
                self.hardtanh = torch.nn.Hardtanh()
                self.hardtanh_ = torch.nn.Hardtanh(inplace=True)

            def forward(self, x):
                x = self.conv(x)
                x = self.relu6(x)
                self.relu6_(x)
                x = F.relu6(x)
                x = torch.clamp(x, -3, 3)
                x = x.clamp(-2.5, 2.5)
                # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
                x = self.hardtanh(x)
                self.hardtanh_(x)
                x = F.hardtanh(x)
                F.hardtanh_(x)
                return x

        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
        # list of nodes that should occur in order
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_function(F.hardtanh_),
            ns.call_method('dequantize')
        ]
        for quant_type in self.static_quant_types:
            m = self.checkGraphModeFxOp(
                M(), data, quant_type, expected_node_list=node_list)
Example No. 8
    def test_add_shadow_loggers_multiple_dtype_casts(self):
        """
        Verifies that for nodes where the first input arg is a list,
        such as `cat`, we insert an individual dtype cast for each
        arg of the list.
        """
        class M(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.cat([x, x, x], dim=0)
                return x

        m = M().eval()
        expected_occurrence = {
            # 3 dequantize function calls from the 3 dtype casts for [x, x, x]
            ns.call_function(torch.dequantize):
            3,
            # 1 dequantize method call for module output
            ns.call_method("dequantize"):
            1,
        }
        self._test_match_shadow_activations(
            m, (torch.randn(4, 4), ),
            prepared_expected_node_occurrence=expected_occurrence,
            results_len=1)
Example No. 9
    def test_qconfig_none(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        m = symbolic_trace(m)
        qconfig_dict = {'': default_qconfig, 'conv2': None}
        m = prepare_static_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        m(data)
        m = convert_static_fx(m)
        m(data)
        # first conv is quantized, second conv is not quantized
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
            ns.call_module(nn.Conv2d),
        ]
        self.checkGraphModuleNodes(m, expected_node_list=node_list)
Example No. 10
    def test_ops(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.linear = torch.nn.Linear(5, 5)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.linear(x)
                x = x + 3
                x = self.relu(x)
                x = x + 6
                return x

        m = M().eval()
        m = prepare_fx(m, {"": default_qconfig})
        m = _convert_fx_do_not_use(m, is_reference=True)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 5,
            ns.call_method("dequantize"): 5,
            ns.call_module(torch.nn.quantized._reference.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
        }
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=expected_occurrence)
Example No. 11
 def test_fp32_input_fp32_output(self):
     prepare_custom_config_dict = {}
     prepare_count_check = {
         ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
     }
     convert_count_check = {
         ns.call_function(torch.quantize_per_tensor): 3,
         ns.call_method('dequantize'): 3,
     }
     self._test_quantized_inputs_outputs(
         prepare_custom_config_dict, prepare_count_check, convert_count_check)
Example No. 12
    def test_conv_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x, y):
                return self.conv(x) + y

        weighted_op_qint8_dtype_config = {
            # optional, input activation dtype
            "input_dtype": torch.qint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.qint8
        }

        conv_add_config = {
            "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode),
            "observation_type":
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_qint8_dtype_config,
            ],
            "root_module":
            torch.nn.Conv2d,
            "reference_quantized_module_for_root":
            torch.nn.quantized._reference.Conv2d,
        }

        m = M().eval()
        modified_backend_config_dict = copy.deepcopy(
            self.trt_backend_config_dict)
        modified_backend_config_dict["configs"].insert(0, conv_add_config)
        m = prepare_fx(m, {"": self.qconfig},
                       backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        m = _convert_fx_do_not_use(
            m,
            is_reference=True,
            backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
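
The `pattern` tuples in these backend configs appear to be written outermost-op-first: the first element is the op that produces the output and the remaining elements are its inputs. A hedged illustration using the patterns from this test and from the standalone-module test later in the listing; the variable names are for illustration only.

    # (operator.add, torch.nn.Conv2d, MatchAllNode) reads as add(conv2d(x), <anything>),
    # i.e. it matches the `self.conv(x) + y` in the forward above.
    conv_add_pattern = (operator.add, torch.nn.Conv2d, MatchAllNode)

    # Nesting appears to follow the same convention:
    # (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode)) reads as
    # relu(add(conv2d(x), <anything>)).
    conv_add_relu_pattern = (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode))
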
Example No. 13
 def test_quantized_input_fp32_output(self):
     prepare_custom_config_dict = {
         'input_quantized_idxs': [0]}
     prepare_count_check = {
         ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
     }
     convert_count_check = {
         # output of conv1, conv2
         ns.call_function(torch.quantize_per_tensor): 2,
         # input of ref conv1, input of ref conv2, final output
         ns.call_method('dequantize'): 3,
     }
     self._test_quantized_inputs_outputs(
         prepare_custom_config_dict, prepare_count_check, convert_count_check)
Example No. 14
    def test_conv_relu_module(self):
        conv_module = {
            1: torch.nn.Conv1d,
            2: torch.nn.Conv2d,
            3: torch.nn.Conv3d
        }

        conv1d_input = torch.rand(1, 3, 10)
        conv2d_input = torch.rand(1, 3, 10, 10)
        conv3d_input = torch.rand(1, 3, 10, 10, 10)
        conv_input = {1: conv1d_input, 2: conv2d_input, 3: conv3d_input}

        class ConvNdModule(torch.nn.Module):
            def __init__(self, dim, has_relu=False, f_relu=False):
                super().__init__()
                self.conv = conv_module[dim](3, 3, 3).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.conv(x))

        # just testing conv2d since conv1d and conv3d are not supported in fx2trt
        for dim, has_relu, f_relu, is_qat in itertools.product([2],
                                                               [True, False],
                                                               [True, False],
                                                               [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # an extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor):
                2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
            }
            self._test_module(ConvNdModule(dim, has_relu,
                                           f_relu), [conv_input[dim]],
                              [((1, *conv_input[dim].shape[1:]),
                                (5, *conv_input[dim].shape[1:]),
                                (10, *conv_input[dim].shape[1:]))],
                              no_convert=no_convert,
                              is_qat=is_qat)
Example No. 15
    def test_conv(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        conv2d_input = torch.rand(1, 3, 224, 224)
        no_convert = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2
        }
        self._test_module(Conv2dModule(), [conv2d_input],
                          [((1, 3, 224, 224), (5, 3, 224, 224),
                            (10, 3, 224, 224))],
                          no_convert=no_convert)
Example No. 16
    def test_linear(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_input = torch.rand(8, 5)

        shape_ranges = [((1, 5), (5, 5), (10, 5))]
        no_convert = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 2,
        }
        self._test_module(LinearModule(), [linear_input],
                          shape_ranges,
                          no_convert=no_convert)
Example No. 17
    def test_cat(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.cat([x, x], 1)

        m = M().eval()
        prepared = prepare_fx(m, {"": self.qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        self.assertTrue(len(dict(prepared.named_children())) == 1)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_function(torch.cat): 1,
            ns.call_method("dequantize"): 2,
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
Example No. 18
    def test_general_value_ops(self):
        """ A test that checks correct patterns are produced for
        all supported general value ops like aten::avg_pool2d \
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.avg_pool1d = torch.nn.AvgPool1d(3)
                self.avg_pool2d = torch.nn.AvgPool2d(3)
                self.avg_pool3d = torch.nn.AvgPool3d(3)
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d(
                    (1, 1, 1))
                self.leaky_relu = torch.nn.LeakyReLU()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.avg_pool1d(x)
                x = self.avg_pool2d(x)
                x = self.avg_pool3d(x)
                x = self.adaptive_avg_pool1d(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.adaptive_avg_pool3d(x)
                x = F.avg_pool1d(x, 3)
                x = F.avg_pool2d(x, 3)
                x = F.avg_pool3d(x, 3)
                x = F.adaptive_avg_pool1d(x, (1))
                x = F.adaptive_avg_pool2d(x, (1, 1))
                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
                x = torch.mean(x)
                x = torch.mean(x, [2, 3], False)
                x = x.mean()
                x = x.mean([2, 3], True)
                x = F.interpolate(x, 4, mode='nearest')
                x = F.interpolate(x, 4, mode='linear')
                x = self.leaky_relu(x)
                x = F.leaky_relu(x)
                x = F.leaky_relu(x, inplace=True)
                x = x.leaky_relu()
                x.leaky_relu_()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x = F.hardsigmoid(x, inplace=True)
                x = x.hardsigmoid()
                x.hardsigmoid_()
                x = self.sigmoid(x)
                x = torch.sigmoid(x)
                # F.sigmoid is deprecated
                x = x.sigmoid()
                x.sigmoid_()
                x = self.tanh(x)
                # F.tanh is deprecated
                x = torch.tanh(x)
                x = x.tanh()
                x.tanh_()
                x = self.conv(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(original, qconfig_dict)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)
Example No. 19
    def test_general_shape_ops(self):
        """ A test that checks dequantize will be swapped for
        all supported general shape ops like aten::flatten
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
                self.dropout = torch.nn.Dropout()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                # add_scalar
                x = x + 3
                # mul_scalar
                x = x * 3
                # add_scalar_out
                x += 3
                # mul_scalar_out
                x *= 3
                # add_scalar_relu
                x = x + 3
                x = F.relu(x)
                # add_scalar_relu_out
                x += 3
                x = F.relu(x)
                # mul_scalar_relu
                x = x * 3
                x = F.relu(x)
                # mul_scalar_relu_out
                x *= 3
                x = F.relu(x)
                x = self.maxpool1d(x)
                x = self.maxpool2d(x)
                x = self.maxpool3d(x)
                x = torch.flatten(x)
                x = torch.max(x)
                x = torch.min(x)
                x = x.reshape([-1])
                x = x.resize_(1, 1, x.numel())
                x = x.view(-1)
                # prim::ListConstruct
                xs = [x, x]
                # prim::ListUnpack
                x, y = xs
                # prim::TupleConstruct
                xs = (x, x)
                # prim::TupleUnpack
                x, y = xs
                x = x.transpose(1, 2)
                x = x.contiguous()
                x, y = torch.chunk(x, 2)
                x = F.dropout(x)
                x = self.dropout(x)
                x, _ = torch.sort(x)
                x = x.permute(0, 2, 3, 1)
                x = x.repeat_interleave(3, 1)
                x = torch.repeat_interleave(x, 3, 1)
                x = self.relu(x)
                x = F.relu(x)
                x = F.relu(x, inplace=True)
                x = x.relu()
                x.relu_()
                x = x.squeeze(0)
                x.squeeze_(0)
                x = torch.squeeze(x, 0)
                x = x.unsqueeze(0)
                x.unsqueeze_(0)
                x = torch.unsqueeze(x, 0)
                x = x.detach()
                x.detach_()
                x = x.repeat(4, 2)
                y = []
                y.append(x)
                z = torch.stack(y, 0)
                z = [z, z]
                x, _ = z
                x = self.conv2(x)
                return x

        data = torch.rand(1, 3, 10, 10)
        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(original, qconfig_dict)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers and also successfully fused two quantized::conv2d
        # patterns
        # one quantize_per_tensor for input
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)
Example No. 20
    def test_input_weight_equalization_graphs(self):
        """ Tests that the modified model for equalization has the same graph
        structure as the model without equalization (before and after
        quantization).
        """

        linear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        linearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_method('dequantize')
        ]

        linearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_method('dequantize')
        ]

        functionalLinearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        conv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        conv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        convRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_method('dequantize')
        ]

        convReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConvRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_method('dequantize')
        ]

        functionalConvReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        tests = [(SingleLayerLinearModel, linear_node_list),
                 (LinearAddModel, linearAdd_node_list),
                 (TwoLayerLinearModel, linear2_node_list),
                 (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
                 (FunctionalLinearAddModel, functionalLinearAdd_node_list),
                 (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
                 (LinearReluModel, linearRelu_node_list),
                 (LinearReluLinearModel, linearReluLinear_node_list),
                 (FunctionalLinearReluModel, functionalLinearRelu_node_list),
                 (FunctionalLinearReluLinearModel, functionalLinearReluLinear_node_list),
                 (ConvModel, conv_node_list),
                 (TwoLayerConvModel, conv2_node_list),
                 (SingleLayerFunctionalConvModel, functionalConv_node_list),
                 (TwoLayerFunctionalConvModel, functionalConv2_node_list),
                 (ConvReluModel, convRelu_node_list),
                 (ConvReluConvModel, convReluConv_node_list),
                 (FunctionalConvReluModel, functionalConvRelu_node_list),
                 (FunctionalConvReluConvModel, functionalConvReluConv_node_list)]

        for (M, node_list) in tests:
            m = M().eval()
            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
            equalized_quantized_model = convert_fx(prepared)

            # Check the order of nodes in the graph
            self.checkGraphModuleNodes(equalized_quantized_model, expected_node_list=node_list)
Example No. 21
    def test_conv_add_standalone_module(self):
        class Standalone(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x, y):
                return self.relu(self.conv(x) + y)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.standalone = Standalone()

            def forward(self, x, y):
                y = self.conv(x)
                return self.standalone(x, y)

        from torch.ao.quantization.fx.backend_config_dict.observation_type import ObservationType
        weighted_op_quint8_dtype_config = {
            # optional, input activation dtype
            # TODO: change back to torch.qint8 after input_quantized_idxs and output_quantized_idxs
            # are more flexible
            "input_dtype": torch.quint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.quint8
        }

        conv_add_config = {
            "pattern": (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode)),
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_quint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }

        conv_config = {
            "pattern": torch.nn.Conv2d,
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_quint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }

        m = M().eval()
        backend_config_dict = {
            "configs": [
                conv_add_config,
                conv_config,
            ]
        }
        prepare_custom_config_dict = {
            "standalone_module_name": [("standalone", None, {"input_quantized_idxs": [0, 1]}, None)]
        }
        # TODO: use self.qconfig after input_quantized_idxs and output_quantized_idxs
        # are more flexible
        qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
            ),
            weight=torch.ao.quantization.default_weight_observer
        )
        m = prepare_fx(
            m,
            {"": qconfig},
            prepare_custom_config_dict=prepare_custom_config_dict,
            backend_config_dict=backend_config_dict)
        node_occurrence = {
            # for input and output of conv, where input is used twice, once in conv and
            # once in standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            # output of the standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 1,
        }
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
        m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_module(nn.Conv2d): 1,
            ns.call_method("dequantize"): 1,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_module(nn.Conv2d): 1,
            ns.call_module(torch.nn.ReLU): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
Example No. 22
    def test_selective_equalization(self):
        """ Tests that we are able to run numeric suite on the equalized model
        and construct a valid equalization_qconfig_dict equalizing only the
        layers with the highest quantization errors.
        """

        torch.manual_seed(1)

        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
                self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

            def forward(self, x):
                x = self.bot(x)
                x = torch.add(x, 5)
                x = self.top(x)
                return x

        float_model = M().eval()
        # Hard coded so that the top layer has a higher quantization error
        x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                          [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                          [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                          [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                          [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

        # Quantize the float model
        prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
        prepared_model(x)
        quantized_model = convert_fx(copy.deepcopy(prepared_model))

        # Get the SQNR between the float and quantized model
        layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model), quantized_model, x)

        # Construct the equalization_qconfig_dict equalizing layers with the highest
        # quantization errors
        selective_equalization_qconfig_dict = get_equalization_qconfig_dict(layer_to_sqnr_dict, 1)

        # Create the selectively equalized model
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            equalization_qconfig_dict=selective_equalization_qconfig_dict,
        )
        prepared_model(x)
        equalized_model = convert_fx(prepared_model)

        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
Example No. 23
 def test_standalone_module_quantized_interface(self):
     quantized_interface_config = {
         "input_quantized_idxs": [0],  # quantized input
         "output_quantized_idxs": [0],  # quantized output
     }
     interface_config = quantized_interface_config
     # TODO: input_quantized_idxs only supports quint8; we can remove this
     # custom_backend_config_dict once `input_quantized_idxs` supports more
     # complicated configurations. As a first step we can change it to use a
     # dictionary from index to dtype.
     qconfig = torch.ao.quantization.QConfig(
         activation=torch.ao.quantization.observer.HistogramObserver.with_args(
             qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
         ),
         weight=torch.ao.quantization.default_weight_observer
     )
     weighted_op_quint8_dtype_config = {
         # optional, input activation dtype
         "input_dtype": torch.quint8,
         # optional, weight dtype
         "weight_dtype": torch.qint8,
         # optional, bias dtype
         "bias_dtype": torch.float,
         # optional, output activation dtype
         "output_dtype": torch.quint8
     }
     conv_module_config = {
         "pattern": torch.nn.Conv2d,
         "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
         "dtype_configs": [
             weighted_op_quint8_dtype_config,
         ],
         "root_module": torch.nn.Conv2d,
         "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
     }
     custom_backend_config_dict = {
         "configs": [conv_module_config]
     }
     # observer for input and output of first conv
     prepare_count_check = {
         ns.call_module(torch.ao.quantization.HistogramObserver): 2
     }
     # for output of conv in the standalone module
     standalone_prepare_count_check = {
         ns.call_module(torch.ao.quantization.HistogramObserver): 1
     }
     convert_count_check = {
         # quantizing input/output for reference conv
         ns.call_function(torch.quantize_per_tensor) : 2,
         ns.call_module(nnqr.Conv2d) : 1,
         # dequantizing the input of the reference conv and
         # the output of the standalone module
         ns.call_method("dequantize") : 2,
     }
     standalone_convert_count_check = {
         # quantization of input happens in parent module
         # quantization of output happens in the standalone module
         ns.call_function(torch.quantize_per_tensor) : 1,
         ns.call_module(nnqr.Conv2d): 1,
         # dequantization of input happens in the standalone module
         # dequantization for output happens in parent module
         ns.call_method("dequantize") : 1,
     }
     self._test_standalone_module(
         interface_config,
         prepare_count_check,
         standalone_prepare_count_check,
         convert_count_check,
         standalone_convert_count_check,
         qconfig=qconfig,
         backend_config_dict=custom_backend_config_dict)