Example #1
    def test_qbatch_norm_relu(self):
        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}

        class BNRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super(BNRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)
                self.relu = torch.nn.ReLU(inplace=inplace)

            def forward(self, x):
                return self.relu(self.bn(x))

        class BNFuncRelu(torch.nn.Module):
            def __init__(self, dim):
                super(BNFuncRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), False)

        class BNFuncInplaceRelu(torch.nn.Module):
            def __init__(self, dim):
                super(BNFuncInplaceRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), True)

        options = itertools.product(self.static_quant_types, [2, 3])
        quantized_nodes = {
            2: ns.call_module(nniq.BNReLU2d),
            3: ns.call_module(nniq.BNReLU3d),
        }
        for quant_type, dim in options:
            for instance in [BNRelu(dim, True), BNRelu(dim, False),
                             BNFuncRelu(dim), BNFuncInplaceRelu(dim)]:
                self.checkGraphModeFxOp(
                    instance, self.img_data_dict[dim], quant_type,
                    quantized_nodes[dim])
Example #2
    def _test_norm_impl(self,
                        float_module,
                        float_op,
                        op_args,
                        data,
                        quantized_module,
                        quantized_op,
                        skip_op_arg_for_functional=False):
        ''' Test for a normalization op. float_op can be a torch op or a functional op;
        op_args is a list of positional arguments for the module/op.
        '''
        class M(torch.nn.Module):
            def __init__(self, is_module):
                super(M, self).__init__()
                self.is_module = is_module
                if self.is_module:
                    self.op = float_module(*op_args)
                else:
                    self.op = float_op

            def forward(self, input):
                if self.is_module:
                    return self.op(input)
                else:
                    args = [input]
                    if not skip_op_arg_for_functional:
                        args += op_args
                    return self.op(*args)

        options = itertools.product([True, False], self.static_quant_types)
        quantized_nodes = {
            # is_module
            True: ns.call_module(quantized_module),
            False: ns.call_function(quantized_op),
        }

        for is_module, quant_type in options:
            self.checkGraphModeFxOp(M(is_module), data, quant_type,
                                    quantized_nodes[is_module])
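For context, here is a sketch of how a helper like this might be invoked for layer norm. The test name and argument values are illustrative assumptions, not taken from this listing; nn.LayerNorm, F.layer_norm, nnq.LayerNorm, and torch.ops.quantized.layer_norm are the standard PyTorch names.

    def test_layer_norm(self):
        # hypothetical caller of _test_norm_impl; the normalized_shape passed
        # as op_args matches the trailing dims of the input data
        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
        self._test_norm_impl(
            nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data,
            nnq.LayerNorm, torch.ops.quantized.layer_norm)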
Example #3
 def test_match_activations_mod(self):
     m = nn.Sequential(
         torch.quantization.QuantStub(),
         nn.Conv2d(1, 1, 1),
         nn.Conv2d(1, 1, 1),
     ).eval()
     expected_occurrence = {
         ns.call_module(OutputLogger): 2,
     }
     self._test_match_activations(
         m, (torch.randn(2, 1, 2, 2),),
         prepared_expected_node_occurrence=expected_occurrence,
         results_len=2)
Example #4
    def test_quantized_conv(self):
        conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}

        class Conv(torch.nn.Module):
            def __init__(self, dim):
                super(Conv, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return self.conv(x)

        options = itertools.product([1, 2, 3], self.static_quant_types)
        quantized_nodes = {
            # dim
            1: ns.call_module(nnq.Conv1d),
            2: ns.call_module(nnq.Conv2d),
            3: ns.call_module(nnq.Conv3d),
        }
        for dim, quant_type in options:
            model = self.checkGraphModeFxOp(
                Conv(dim), self.img_data_dict[dim], quant_type,
                quantized_nodes[dim])
Example #5
    def test_match_activations_fun(self):
        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.w1 = nn.Parameter(torch.Tensor(4, 4))
                self.b1 = nn.Parameter(torch.zeros(4))
                self.w2 = nn.Parameter(torch.Tensor(4, 4))
                self.b2 = nn.Parameter(torch.zeros(4))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
                torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))

            def forward(self, x):
                x = F.linear(x, self.w1, self.b1)
                x = F.linear(x, self.w2, self.b2)
                x = F.relu(x)
                return x

        m = M().eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        mp(torch.randn(4, 4))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_ns, mq_ns = prepare_model_outputs('fp32_prepared', mp, 'int8', mq,
                                             OutputLogger)

        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=expected_occurrence)

        # TODO(before land): test both scripted and non-scripted
        mp_ns = torch.jit.script(mp_ns)
        mq_ns = torch.jit.script(mq_ns)

        # calibrate
        input_fp32 = torch.randn(4, 4)
        mp_ns(input_fp32)
        mq_ns(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations('fp32_prepared', mp_ns,
                                                    'int8', mq_ns,
                                                    OutputLogger)
        self.assertTrue(len(act_compare_dict) == 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
Example #6
    def test_addmm(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(5, 5)
                self.bias = torch.randn(5)

            def forward(self, x):
                return torch.addmm(self.bias, x, self.weight)

        m = M().eval()
        prepared = prepare_fx(m, {"": self.qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # weight
            ns.call_module(torch.ao.quantization.MinMaxObserver):
            1,
            # activation
            ns.call_module(torch.ao.quantization.HistogramObserver):
            2,
        }
        self.checkGraphModuleNodes(prepared,
                                   expected_node_occurrence=node_occurrence)
        quantized = _convert_fx_do_not_use(
            prepared,
            is_reference=True,
            backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # input activation, output activation and weight
            ns.call_function(torch.quantize_per_tensor):
            3,
            ns.call_function(torch.addmm):
            1,
            ns.call_method("dequantize"):
            3,
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
Example #7
    def test_conv_relu_module(self):
        conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}

        conv1d_input = torch.rand(1, 3, 10)
        conv2d_input = torch.rand(1, 3, 10, 10)
        conv3d_input = torch.rand(1, 3, 10, 10, 10)
        conv_input = {1: conv1d_input, 2: conv2d_input, 3: conv3d_input}

        class ConvNdModule(torch.nn.Module):
            def __init__(self, dim, has_relu=False, f_relu=False):
                super().__init__()
                self.conv = conv_module[dim](3, 3, 3).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.conv(x))

        # just testing conv2d since conv1d and conv3d are not supported in fx2trt
        for dim, has_relu, f_relu in itertools.product([2], [True, False], [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # an extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
            }
            self._test_module(
                ConvNdModule(dim, has_relu, f_relu),
                [conv_input[dim]],
                [((1, *conv_input[dim].shape[1:]),
                  (5, *conv_input[dim].shape[1:]),
                  (10, *conv_input[dim].shape[1:]))],
                no_convert=no_convert)
Example #8
 def test_linear_fp16_shadow_activations(self):
     for should_log_inputs in (True, False):
         qconfig_dict = {'': torch.quantization.float16_static_qconfig}
         m = LinearReluFunctional().eval()
         num_loggers = 4 if should_log_inputs else 2
         expected_occurrence = {
             ns.call_module(OutputLogger): num_loggers,
         }
         res2 = self._test_match_shadow_activations(
             m, (torch.randn(4, 4), ),
             prepared_expected_node_occurrence=expected_occurrence,
             results_len=1,
             qconfig_dict=qconfig_dict,
             should_log_inputs=should_log_inputs)
Example #9
    def test_cat(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.cat([x, x], 1)

        m = M().eval()
        prepared = prepare_fx(m, {"": self.qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        self.assertTrue(len(dict(prepared.named_children())) == 1)
        quantized = _convert_fx_do_not_use(
            prepared,
            is_reference=True,
            backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_function(torch.cat): 1,
            ns.call_method("dequantize"): 2,
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
Example #10
    def test_clamp(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu6 = torch.nn.ReLU6()
                self.relu6_ = torch.nn.ReLU6(True)
                self.hardtanh = torch.nn.Hardtanh()
                self.hardtanh_ = torch.nn.Hardtanh(inplace=True)

            def forward(self, x):
                x = self.conv(x)
                x = self.relu6(x)
                self.relu6_(x)
                x = F.relu6(x)
                x = torch.clamp(x, -3, 3)
                x = x.clamp(-2.5, 2.5)
                # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
                x = self.hardtanh(x)
                self.hardtanh_(x)
                x = F.hardtanh(x)
                F.hardtanh_(x)
                return x

        data = (torch.rand((1, 2, 5, 5), dtype=torch.float), )
        # list of nodes that should occur in order
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_function(F.hardtanh_),
            ns.call_method('dequantize')
        ]
        for quant_type in self.static_quant_types:
            m = self.checkGraphModeFxOp(M(),
                                        data,
                                        quant_type,
                                        expected_node_list=node_list)
Example #11
    def test_embedding(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

            def forward(self, indices):
                return self.emb(indices)

        model = M().eval()
        indices = torch.randint(low=0, high=10, size=(20,))

        quantized_node = ns.call_module(nnq.Embedding)
        configs = [
            (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
            (None, ns.call_module(nn.Embedding)),
            (default_qconfig, ns.call_module(nn.Embedding)),
        ]

        for qconfig, node in configs:
            qconfig_dict = {"": qconfig}
            m = prepare_fx(model, qconfig_dict)
            m = convert_fx(m)
            self._compare_script_and_mobile(m, input=indices)
Example #12
    def test_linear_relu_module(self):
        class LinearModule(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.linear(x))

        linear_input = torch.rand(8, 5)

        shape_ranges = [
            ((1, 5),
             (5, 5),
             (10, 5))
        ]
        for has_relu, f_relu, is_qat in itertools.product([True, False], [True, False], [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # an extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
            }
            self._test_module(
                LinearModule(has_relu, f_relu),
                [linear_input],
                shape_ranges,
                no_convert=no_convert,
                is_qat=is_qat)
Example #13
    def test_linear(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
            ),
            weight=torch.quantization.default_weight_observer
        )
        prepared = prepare_fx(m, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
        # calibration
        prepared(linear_module_input)
        quantized = convert_fx(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
        # lower to trt
        trt_mod = lower_to_trt(
            quantized,
            linear_module_input,
            [((1, *linear_module_input.shape[1:]),
              (5, *linear_module_input.shape[1:]),
              (10, *linear_module_input.shape[1:]))])
        # make sure it runs
        trt_mod(linear_module_input.cuda())
Example #14
 def _test_match_activations_fun_impl(self, prepare_fn=prepare_fx):
     m = LinearReluLinearFunctional().eval()
     qconfig_dict = None
     if prepare_fn == prepare_qat_fx:
         qconfig_dict = {
             '': torch.quantization.get_default_qat_qconfig('fbgemm')
         }
     expected_occurrence = {
         ns.call_module(OutputLogger): 2,
     }
     self._test_match_activations(
         m, (torch.randn(4, 4), ),
         prepared_expected_node_occurrence=expected_occurrence,
         results_len=2,
         prepare_fn=prepare_fn,
         qconfig_dict=qconfig_dict)
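Given the `_impl` suffix and the `prepare_fn` parameter, this helper is presumably driven by thin wrapper tests along these lines (a hypothetical sketch, not part of the listing; the wrapper names are made up):

    def test_match_activations_fun_ptq(self):
        # post-training quantization path
        self._test_match_activations_fun_impl(prepare_fn=prepare_fx)

    def test_match_activations_fun_qat(self):
        # quantization-aware training path
        self._test_match_activations_fun_impl(prepare_fn=prepare_qat_fx)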
Example #15
 def _test_match_activations_mod_impl(self, prepare_fn=prepare_fx):
     m = nn.Sequential(
         torch.quantization.QuantStub(),
         nn.Conv2d(1, 1, 1),
         nn.Conv2d(1, 1, 1),
     ).eval()
     qconfig_dict = None
     if prepare_fn == prepare_qat_fx:
         qconfig_dict = {
             '': torch.quantization.get_default_qat_qconfig('fbgemm')
         }
     expected_occurrence = {
         ns.call_module(OutputLogger): 2,
     }
     self._test_match_activations(
         m, (torch.randn(2, 1, 2, 2), ),
         prepared_expected_node_occurrence=expected_occurrence,
         results_len=2,
         qconfig_dict=qconfig_dict,
         prepare_fn=prepare_fn)
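The module variant above would presumably be exercised the same way, for example (hypothetical wrapper name):

    def test_match_activations_mod_qat(self):
        # drive the module-based helper through the QAT prepare path
        self._test_match_activations_mod_impl(prepare_fn=prepare_qat_fx)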
Example #16
 def test_standalone_module_float_interface(self):
     float_interface_config = {
         "input_quantized_idxs": [],  # float input
         "output_quantized_idxs": [],  # float output
     }
     interface_config = float_interface_config
      # observers for the input and output of the first conv; observers for the
      # standalone module will be inserted in the standalone module itself
     prepare_count_check = {
         ns.call_module(torch.ao.quantization.HistogramObserver): 2
     }
     # for input and output of conv in the standalone module
     standalone_prepare_count_check = {
         ns.call_module(torch.ao.quantization.HistogramObserver): 2
     }
     convert_count_check = {
         # input and output of reference conv
         ns.call_function(torch.quantize_per_tensor):
         2,
         ns.call_module(nnqr.Conv2d):
         1,
         ns.call_method("dequantize"):
         2,
     }
     standalone_convert_count_check = {
          # the standalone module will take float as input and output,
          # so we'll see quantize and dequantize in the module
         ns.call_function(torch.quantize_per_tensor):
         2,
         ns.call_module(nnqr.Conv2d):
         1,
         ns.call_method("dequantize"):
         2,
     }
     self._test_standalone_module(interface_config, prepare_count_check,
                                  standalone_prepare_count_check,
                                  convert_count_check,
                                  standalone_convert_count_check)
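`_test_standalone_module` itself is not shown in this listing. It presumably registers the submodule named "standalone" as a standalone module when calling prepare_fx, roughly as in the sketch below; the tuple layout is inferred from the concrete registration in Example #27 and should be treated as an assumption (model and qconfig are placeholders):

    # hypothetical sketch of the registration assumed to happen inside the helper
    prepare_custom_config_dict = {
        "standalone_module_name": [("standalone", None, interface_config, None)]
    }
    prepared = prepare_fx(model, {"": qconfig},
                          prepare_custom_config_dict=prepare_custom_config_dict)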
Example #17
    def test_match_activations_mod(self):
        m = nn.Sequential(
            torch.quantization.QuantStub(),
            nn.Conv2d(1, 1, 1),
            nn.Conv2d(1, 1, 1),
        ).eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        mp(torch.randn(2, 1, 2, 2))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_ns, mq_ns = prepare_model_outputs('fp32_prepared', mp, 'int8', mq,
                                             OutputLogger)

        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=expected_occurrence)

        # TODO(before land): test both scripted and non-scripted
        mp_ns = torch.jit.script(mp_ns)
        mq_ns = torch.jit.script(mq_ns)

        # calibrate
        input_fp32 = torch.randn(2, 1, 2, 2)
        mp_ns(input_fp32)
        mq_ns(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations('fp32_prepared', mp_ns,
                                                    'int8', mq_ns,
                                                    OutputLogger)
        self.assertTrue(len(act_compare_dict) == 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
Example #18
    def _test_quantized_binary_op_impl(self, binary_op, ibinary_op,
                                       quantized_op):
        class Op(torch.nn.Module):
            def __init__(self, is_inplace, is_scalar):
                super(Op, self).__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
                self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
                self.is_scalar = is_scalar
                self.op = ibinary_op if is_inplace else binary_op

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                x = self.op(x, y)
                return x

        # TODO: decide whether we want to quantize or not
        # in this case
        # class NonQuantizedOp(torch.nn.Module):
        #     def __init__(self, is_inplace, is_scalar):
        #         super(NonQuantizedOp, self).__init__()
        #         self.is_scalar = is_scalar
        #         self.op = ibinary_op if is_inplace else binary_op

        #     def forward(self, x, y):
        #         y = 3 if self.is_scalar else y
        #         x = self.op(x, y)
        #         return x

        data = (torch.randn(1, 1, 1, 1, dtype=torch.float),
                torch.randn(1, 1, 1, 1, dtype=torch.float))
        quantized_node = ns.call_function(quantized_op)
        options = itertools.product([True, False], [True, False])
        quant_type = QuantType.STATIC
        for is_inplace, is_scalar in options:
            self.checkGraphModeFxOp(Op(is_inplace, is_scalar), data,
                                    quant_type, quantized_node)
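A plausible caller for this parametrized helper, shown only as an illustration; it assumes the standard library `operator` module is imported at module level and uses torch.ops.quantized.add as the target op:

    def test_quantized_add(self):
        # hypothetical wrapper passing a (binary_op, ibinary_op, quantized_op) triple
        self._test_quantized_binary_op_impl(
            operator.add, operator.iadd, torch.ops.quantized.add)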
Example #19
    def _test_quantized_binary_op_relu_impl(self, binary_op, ibinary_op,
                                            quantized_op):
        class OpRelu(torch.nn.Module):
            def __init__(self, is_inplace, is_functional_relu, is_inplace_relu,
                         is_scalar):
                super(OpRelu, self).__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
                self.op = ibinary_op if is_inplace else binary_op
                self.is_functional_relu = is_functional_relu
                self.is_inplace_relu = is_inplace_relu
                self.is_scalar = is_scalar

                if self.is_functional_relu:
                    self.relu = F.relu
                else:
                    self.relu = torch.nn.ReLU(self.is_inplace_relu)

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                x = self.op(x, y)
                x = self.relu(x, self.is_inplace_relu) if \
                    self.is_functional_relu else self.relu(x)
                return x

        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),
                torch.rand((1, 2, 5, 5), dtype=torch.float))
        quantized_node = ns.call_function(quantized_op)
        options = itertools.product([True, False], [True, False],
                                    [True, False], [True, False],
                                    self.static_quant_types)
        for is_inplace_op, is_functional_relu, is_inplace_relu, is_scalar, quant_type in options:
            self.checkGraphModeFxOp(
                OpRelu(is_inplace_op, is_functional_relu, is_inplace_relu,
                       is_scalar), data, quant_type, quantized_node)
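Similarly, the op-plus-ReLU helper would presumably be driven with the fused quantized op (an illustrative assumption, using torch.ops.quantized.add_relu):

    def test_quantized_add_relu(self):
        # hypothetical wrapper for the fused add + relu pattern
        self._test_quantized_binary_op_relu_impl(
            operator.add, operator.iadd, torch.ops.quantized.add_relu)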
Example #20
    def test_selective_equalization(self):
        """ Tests that we are able to run numeric suite on the equalized model
        and construct a valid equalization_qconfig_dict equalizing only the top
        4 layers with the highest quantization errors.
        """

        torch.manual_seed(1)

        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
                self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

            def forward(self, x):
                x = self.bot(x)
                x = torch.add(x, 5)
                x = self.top(x)
                return x

        float_model = M().eval()
        # Hard coded so that the top layer has a higher quantization error
        x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                          [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                          [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                          [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                          [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

        # Quantize the float model
        prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
        prepared_model(x)
        quantized_model = convert_fx(copy.deepcopy(prepared_model))

        # Get the SQNR between the float and quantized model
        layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model), quantized_model, x)

        # Construct the equalization_qconfig_dict equalizing layers with the highest
        # quantization errors
        selective_equalization_qconfig_dict = get_equalization_qconfig_dict(layer_to_sqnr_dict, 1)

        # Create the selectively equalized model
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            equalization_qconfig_dict=selective_equalization_qconfig_dict,
        )
        prepared_model(x)
        equalized_model = convert_fx(prepared_model)

        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
Example #21
    def test_input_weight_equalization_graphs(self):
        """ Tests that the modified model for equalization has the same graph
        structure as the model without equalization (before and after
        quantization).
        """

        linear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        linearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_method('dequantize')
        ]

        linearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_method('dequantize')
        ]

        functionalLinearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        conv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        conv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        convRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_method('dequantize')
        ]

        convReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConvRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_method('dequantize')
        ]

        functionalConvReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        tests = [(SingleLayerLinearModel, linear_node_list),
                 (LinearAddModel, linearAdd_node_list),
                 (TwoLayerLinearModel, linear2_node_list),
                 (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
                 (FunctionalLinearAddModel, functionalLinearAdd_node_list),
                 (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
                 (LinearReluModel, linearRelu_node_list),
                 (LinearReluLinearModel, linearReluLinear_node_list),
                 (FunctionalLinearReluModel, functionalLinearRelu_node_list),
                 (FunctionalLinearReluLinearModel, functionalLinearReluLinear_node_list),
                 (ConvModel, conv_node_list),
                 (TwoLayerConvModel, conv2_node_list),
                 (SingleLayerFunctionalConvModel, functionalConv_node_list),
                 (TwoLayerFunctionalConvModel, functionalConv2_node_list),
                 (ConvReluModel, convRelu_node_list),
                 (ConvReluConvModel, convReluConv_node_list),
                 (FunctionalConvReluModel, functionalConvRelu_node_list),
                 (FunctionalConvReluConvModel, functionalConvReluConv_node_list)]

        for (M, node_list) in tests:
            m = M().eval()
            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
            equalized_quantized_model = convert_fx(prepared)

            # Check the order of nodes in the graph
            self.checkGraphModuleNodes(equalized_quantized_model, expected_node_list=node_list)
Example #22
    def test_input_weight_equalization_prepare(self):
        """ Tests that graphs created after prepare_fx is as expected
        """

        single_nn_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 2,
        }

        two_nn_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 3,
        }

        single_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(_WeightEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 3,
        }

        two_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 5,
        }

        fp_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 6,
        }

        tests = [(SingleLayerLinearModel, single_nn_layer_node_occurrence),
                 (TwoLayerLinearModel, two_nn_layer_node_occurrence),
                 (TwoLayerFunctionalLinearModel, two_F_layer_node_occurrence),
                 (FunctionalLinearAddModel, fp_F_layer_node_occurrence),
                 (LinearReluModel, single_nn_layer_node_occurrence),
                 (LinearReluLinearModel, two_nn_layer_node_occurrence),
                 (FunctionalLinearReluModel, single_F_layer_node_occurrence),
                 (FunctionalLinearReluLinearModel, two_F_layer_node_occurrence),
                 (ConvModel, single_nn_layer_node_occurrence),
                 (TwoLayerConvModel, two_nn_layer_node_occurrence),
                 (TwoLayerFunctionalConvModel, two_F_layer_node_occurrence),
                 (ConvReluModel, single_nn_layer_node_occurrence),
                 (ConvReluConvModel, two_nn_layer_node_occurrence),
                 (FunctionalConvReluModel, single_F_layer_node_occurrence),
                 (FunctionalConvReluConvModel, two_F_layer_node_occurrence)]

        for (M, node_occurrence) in tests:
            m = M().eval()
            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
            self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
Example #23
    def test_general_value_ops(self):
        """ A test that checks correct patterns are produced for
        all supported general value ops like aten::avg_pool2d \
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.avg_pool1d = torch.nn.AvgPool1d(3)
                self.avg_pool2d = torch.nn.AvgPool2d(3)
                self.avg_pool3d = torch.nn.AvgPool3d(3)
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d(
                    (1, 1, 1))
                self.leaky_relu = torch.nn.LeakyReLU()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.avg_pool1d(x)
                x = self.avg_pool2d(x)
                x = self.avg_pool3d(x)
                x = self.adaptive_avg_pool1d(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.adaptive_avg_pool3d(x)
                x = F.avg_pool1d(x, 3)
                x = F.avg_pool2d(x, 3)
                x = F.avg_pool3d(x, 3)
                x = F.adaptive_avg_pool1d(x, (1))
                x = F.adaptive_avg_pool2d(x, (1, 1))
                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
                x = torch.mean(x)
                x = torch.mean(x, [2, 3], False)
                x = x.mean()
                x = x.mean([2, 3], True)
                x = F.interpolate(x, 4, mode='nearest')
                x = F.interpolate(x, 4, mode='linear')
                x = self.leaky_relu(x)
                x = F.leaky_relu(x)
                x = F.leaky_relu(x, inplace=True)
                x = x.leaky_relu()
                x.leaky_relu_()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x = F.hardsigmoid(x, inplace=True)
                x = x.hardsigmoid()
                x.hardsigmoid_()
                x = self.sigmoid(x)
                x = torch.sigmoid(x)
                # F.sigmoid is deprecated
                x = x.sigmoid()
                x.sigmoid_()
                x = self.tanh(x)
                # F.tanh is deprecated
                x = torch.tanh(x)
                x = x.tanh()
                x.tanh_()
                x = self.conv(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(original, qconfig_dict)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)
Example #24
    def test_general_shape_ops(self):
        """ A test that checks dequantize will be swapped for
        all supported general shape ops like aten::flatten
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
                self.dropout = torch.nn.Dropout()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                # add_scalar
                x = x + 3
                # mul_scalar
                x = x * 3
                # add_scalar_out
                x += 3
                # mul_scalar_out
                x *= 3
                # add_scalar_relu
                x = x + 3
                x = F.relu(x)
                # add_scalar_relu_out
                x += 3
                x = F.relu(x)
                # mul_scalar_relu
                x = x * 3
                x = F.relu(x)
                # mul_scalar_relu_out
                x *= 3
                x = F.relu(x)
                x = self.maxpool1d(x)
                x = self.maxpool2d(x)
                x = self.maxpool3d(x)
                x = torch.flatten(x)
                x = torch.max(x)
                x = torch.min(x)
                x = x.reshape([-1])
                x = x.resize_(1, 1, x.numel())
                x = x.view(-1)
                # prim::ListConstruct
                xs = [x, x]
                # prim::ListUnpack
                x, y = xs
                # prim::TupleConstruct
                xs = (x, x)
                # prim::TupleUnpack
                x, y = xs
                x = x.transpose(1, 2)
                x = x.contiguous()
                x, y = torch.chunk(x, 2)
                x = F.dropout(x)
                x = self.dropout(x)
                x, _ = torch.sort(x)
                x = x.permute(0, 2, 3, 1)
                x = x.repeat_interleave(3, 1)
                x = torch.repeat_interleave(x, 3, 1)
                x = self.relu(x)
                x = F.relu(x)
                x = F.relu(x, inplace=True)
                x = x.relu()
                x.relu_()
                x = x.squeeze(0)
                x.squeeze_(0)
                x = torch.squeeze(x, 0)
                x = x.unsqueeze(0)
                x.unsqueeze_(0)
                x = torch.unsqueeze(x, 0)
                x = x.detach()
                x.detach_()
                x = x.repeat(4, 2)
                y = []
                y.append(x)
                z = torch.stack(y, 0)
                z = [z, z]
                x, _ = z
                x = self.conv2(x)
                return x

        data = torch.rand(1, 3, 10, 10)
        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(original, qconfig_dict)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers and also successfully fused two quantized::conv2d
        # patterns
        # one quantize_per_tensor for input
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)
Example #25
    def test_linear(self):
        class ModuleLinear(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super(ModuleLinear, self).__init__()
                self.linear = torch.nn.Linear(30, 4).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.linear(x))

        class FuncLinear(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super(FuncLinear, self).__init__()
                self.w = torch.randn(4, 30)
                self.b = torch.randn(4)
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(F.linear(x, self.w, self.b))

        data = (torch.rand((1, 30), dtype=torch.float), )
        options = itertools.product(
            [(ModuleLinear(has_relu=False), True)],
            # TODO: enable after raw `tensor` is supported in fx
            # (FuncLinear(has_relu=False), False)],
            self.all_quant_types)
        quantized_nodes = {
            # is_module
            True: {
                # quant_type:
                QuantType.DYNAMIC:
                ns.call_module(nnqd.Linear),
                QuantType.STATIC:
                ns.call_module(nnq.Linear),
                # note that we are checking the final result
                QuantType.QAT:
                ns.call_module(nnq.Linear),
            },
            False: {
                # quant_type:
                QuantType.DYNAMIC:
                ns.call_function(torch.ops.quantized.linear_dynamic),
                QuantType.STATIC:
                ns.call_function(torch.ops.quantized.linear),
                QuantType.QAT:
                ns.call_function(torch.ops.quantized.linear),
            }
        }
        for (model, is_module), quant_type in options:
            self.checkGraphModeFxOp(model, data, quant_type,
                                    quantized_nodes[is_module][quant_type])

        for f_relu, quant_type in itertools.product(
            [True, False], [QuantType.STATIC, QuantType.QAT]):
            for model, quantized_node in [(ModuleLinear(has_relu=True,
                                                        f_relu=f_relu),
                                           ns.call_module(nniq.LinearReLU))]:
                # TODO: support functional linear + relu fusion
                # (FuncLinear(has_relu=True, f_relu=f_relu), ns.call_function(torch.ops.quantized.linear_relu))]:
                self.checkGraphModeFxOp(model, data, quant_type,
                                        quantized_node)
Example #26
 def test_standalone_module_quantized_interface(self):
     quantized_interface_config = {
         "input_quantized_idxs": [0],  # quantized input
         "output_quantized_idxs": [0],  # quantized output
     }
     interface_config = quantized_interface_config
     # TODO: input_quantized_idxs only supports quint8, we can remove this
     # custom_backend_config_dict after
     # the `input_quantized_idxs` supports more complicated
     # configurations, as a first step we can change it to use a dictionary from
     # index to dtype
     qconfig = torch.ao.quantization.QConfig(
         activation=torch.ao.quantization.observer.HistogramObserver.with_args(
             qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
         ),
         weight=torch.ao.quantization.default_weight_observer
     )
     weighted_op_quint8_dtype_config = {
         # optional, input activation dtype
         "input_dtype": torch.quint8,
         # optional, weight dtype
         "weight_dtype": torch.qint8,
         # optional, bias dtype
         "bias_dtype": torch.float,
         # optional, output activation dtype
         "output_dtype": torch.quint8
     }
     conv_module_config = {
         "pattern": torch.nn.Conv2d,
         "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
         "dtype_configs": [
             weighted_op_quint8_dtype_config,
         ],
         "root_module": torch.nn.Conv2d,
         "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
     }
     custom_backend_config_dict = {
         "configs": [conv_module_config]
     }
     # observer for input and output of first conv
     prepare_count_check = {
         ns.call_module(torch.ao.quantization.HistogramObserver): 2
     }
     # for output of conv in the standalone module
     standalone_prepare_count_check = {
         ns.call_module(torch.ao.quantization.HistogramObserver): 1
     }
     convert_count_check = {
         # quantizing input/output for reference conv
         ns.call_function(torch.quantize_per_tensor) : 2,
         ns.call_module(nnqr.Conv2d) : 1,
         # dequantize the input of reference conv and
         # dequantizing output of standalone module
         ns.call_method("dequantize") : 2,
     }
     standalone_convert_count_check = {
         # quantization of input happens in parent module
         # quantization of output happens in the standalone module
         ns.call_function(torch.quantize_per_tensor) : 1,
         ns.call_module(nnqr.Conv2d): 1,
         # dequantization of input happens in the standalone module
         # dequantization for output happens in parent module
         ns.call_method("dequantize") : 1,
     }
     self._test_standalone_module(
         interface_config,
         prepare_count_check,
         standalone_prepare_count_check,
         convert_count_check,
         standalone_convert_count_check,
         qconfig=qconfig,
         backend_config_dict=custom_backend_config_dict)
Example #27
    def test_conv_add_standalone_module(self):
        class Standalone(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x, y):
                return self.relu(self.conv(x) + y)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.standalone = Standalone()

            def forward(self, x, y):
                y = self.conv(x)
                return self.standalone(x, y)

        from torch.ao.quantization.fx.backend_config_dict.observation_type import ObservationType
        weighted_op_quint8_dtype_config = {
            # optional, input activation dtype
            # TODO: change back to torch.qint8 after input_quantized_idxs and output_quantized_idxs
            # are more flexible
            "input_dtype": torch.quint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.quint8
        }

        conv_add_config = {
            "pattern": (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode)),
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_quint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }

        conv_config = {
            "pattern": torch.nn.Conv2d,
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_quint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }

        m = M().eval()
        backend_config_dict = {
            "configs": [
                conv_add_config,
                conv_config,
            ]
        }
        prepare_custom_config_dict = {
            "standalone_module_name": [("standalone", None, {"input_quantized_idxs": [0, 1]}, None)]
        }
        # TODO: use self.qconfig after input_quantized_idxs and output_quantized_idxs
        # are more flexible
        qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
            ),
            weight=torch.ao.quantization.default_weight_observer
        )
        m = prepare_fx(
            m,
            {"": qconfig},
            prepare_custom_config_dict=prepare_custom_config_dict,
            backend_config_dict=backend_config_dict)
        node_occurrence = {
            # for input and output of conv, where input is used twice, once in conv and
            # once in standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            # output of the standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 1,
        }
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
        m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_module(nn.Conv2d): 1,
            ns.call_method("dequantize"): 1,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_module(nn.Conv2d): 1,
            ns.call_module(torch.nn.ReLU): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
Example #28
    def test_input_weight_equalization_prepare(self):
        """ Tests that graphs created after prepare_fx is as expected
        """
        qconfig_dict = {
            "":
            None,
            "object_type": [(nn.Linear, default_qconfig),
                            (nn.functional.linear, default_qconfig)]
        }

        default_equalization_qconfig_dict = {
            "":
            default_qconfig,
            "object_type":
            [(nn.Linear, default_equalization_qconfig),
             (nn.functional.linear, default_equalization_qconfig)]
        }

        # Basic test with one linear layer
        class LinearModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1, 1)

            def forward(self, x):
                return self.linear(x)

        linear_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 2,
        }

        # Test with two linear layers
        class Linear2Module(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = nn.Linear(1, 1)
                self.linear2 = nn.Linear(1, 1)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        linear2_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 3,
        }

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return nn.functional.linear(x, self.w, self.b)

        # Test where we have two functional linear layers
        class FunctionalLinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = Linear()
                self.linear2 = Linear()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        functionalLinear_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 5,
        }

        # Test where we have a Linear layer followed by an fp32 operation
        # (torch.add, which does not get a qconfig here) and then another
        # Linear layer; the Conv2d attribute below is never used in forward
        class FunctionalLinearFP32Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = Linear()
                self.conv = nn.Conv2d(3, 3, 1, 1)
                self.linear2 = Linear()

            def forward(self, x):
                x = self.linear1(x)
                x = torch.add(x, 5)
                x = self.linear2(x)
                return x

        fp32_equalization_qconfig_dict = {
            "": None,
            "object_type": [
                (nn.Linear, default_equalization_qconfig),
                (nn.functional.linear, default_equalization_qconfig),
            ],
        }

        functionalLinearFP32_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 6,
        }

        tests = [(LinearModule, default_equalization_qconfig_dict,
                  linear_node_occurrence),
                 (Linear2Module, fp32_equalization_qconfig_dict,
                  linear2_node_occurrence),
                 (FunctionalLinearModule, default_equalization_qconfig_dict,
                  functionalLinear_node_occurrence),
                 (FunctionalLinearFP32Module, fp32_equalization_qconfig_dict,
                  functionalLinearFP32_node_occurrence)]

        for (M, equalization_qconfig_dict, node_occurrence) in tests:
            m = M().eval()
            prepared = prepare_fx(
                m,
                qconfig_dict,
                equalization_qconfig_dict=equalization_qconfig_dict)
            self.checkGraphModuleNodes(
                prepared, expected_node_occurrence=node_occurrence)
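For a quick manual check outside the parametrized loop, the same preparation can be run on a single model and the graph printed; this is just a sketch that mirrors the call above (LinearModule, qconfig_dict, and default_equalization_qconfig_dict come from this test, and the prepare_fx signature matches the older API used here, without example_inputs).

prepared = prepare_fx(
    LinearModule().eval(),
    qconfig_dict,
    equalization_qconfig_dict=default_equalization_qconfig_dict)
# The inserted equalization and min-max observers show up as call_module nodes.
print(prepared.graph)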
Example #29
    def test_input_weight_equalization_branching(self):
        """ Tests that graphs containing branches are prepared correctly.
        Specifically, equalization observers should not be inserted in front of
        branches in which both initial layers in the branches plan to be
        quantized.
        """

        # Tests that we do not add an equalization observer because both
        # initial nodes in the branch are layers that need to be equalized.
        # Note that this should print out 2 warning messages about not being
        # able to equalize linear1 and linear2 because they are part of a branch
        class TestBranchingWithoutEqualizationModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(5, 5)
                self.linear2 = nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear1(x)
                z = self.linear2(x)
                return torch.add(y, z)

        no_eq_branching_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 0,
            ns.call_module(MinMaxObserver): 3,
        }

        m = TestBranchingWithoutEqualizationModel().eval()
        example_inputs = (torch.rand(1, 5), )
        prepared = prepare_fx(
            m,
            specific_qconfig_dict,
            example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(
            prepared, expected_node_occurrence=no_eq_branching_node_occurrence)

        # Tests that we will add an equalization observer because there is only
        # one initial node in the branch that needs to be equalized
        class TestBranchingWithEqualizationModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear1(x)
                z = torch.add(x, 5)
                return torch.add(y, z)

        eq_branching_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 2,
        }

        m = TestBranchingWithEqualizationModel().eval()
        example_inputs = (torch.randn(1, 5), )
        prepared = prepare_fx(
            m,
            specific_qconfig_dict,
            example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(
            prepared, expected_node_occurrence=eq_branching_node_occurrence)
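The rule these two models exercise can be phrased roughly as: a branch point is only equalized when at most one of its users is an equalization-eligible layer. Below is a minimal sketch of such a check (not the actual implementation in torch.ao.quantization; it only considers call_module users, and `modules` is assumed to be dict(gm.named_modules())).

import torch.nn as nn

def branch_allows_equalization(node, modules, eligible_types=(nn.Linear,)):
    """Return True if at most one user of `node` is an equalizable layer."""
    eligible_users = [
        user for user in node.users
        if user.op == "call_module"
        and isinstance(modules.get(user.target), eligible_types)
    ]
    return len(eligible_users) <= 1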
Example #30
    def test_user_module(self):
        """
        For user-defined modules:
        1. weight extraction should not crash
        2. unshadowed activations should have loggers; the loggers will only
             log if the output dtype is in the allowlist
        3. shadowed activations should not have loggers
             (since the I/O dtype is unknown)
        """
        class UserModule(nn.Module):
            def forward(self, x):
                return x

        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1, 1)
                self.user_module = UserModule()

            def forward(self, x):
                x = self.linear(x)
                x = self.user_module(x)
                return x

        m = M().eval()

        # quantize without tracing through UserModule
        qconfig_dict = {'': torch.quantization.default_qconfig}
        prepare_custom_config_dict = {
            'non_traceable_module_name': ['user_module']
        }
        mp = prepare_fx(m, qconfig_dict, prepare_custom_config_dict)
        mp(torch.randn(1, 1, 1))
        mq = convert_fx(copy.deepcopy(mp))

        # weight extraction should not crash
        weights = _extract_weights_impl('fp32_prepared', mp, 'int8', mq)

        # unshadowed activations should have loggers

        # add loggers, without retracing
        # note: converting again because we cannot copy a quantized linear
        mp_ns, mq_ns = _add_loggers_impl('fp32_prepared',
                                         copy.deepcopy(mp),
                                         'int8',
                                         convert_fx(copy.deepcopy(mp)),
                                         OutputLogger,
                                         should_log_inputs=True)
        # both fp32 and int8 models should have 4 loggers each, 2 for I/O
        # of linear, and 2 for I/O of user_module
        unshadowed_expected_occurrence = {
            ns.call_module(OutputLogger): 4,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=unshadowed_expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=unshadowed_expected_occurrence)

        # shadowed activations should only have loggers for nodes where
        # the types are known and we can do a dtype cast

        # add shadow loggers, without retracing
        mp_shadows_mq_ns = _add_shadow_loggers_impl('fp32_prepared',
                                                    mp,
                                                    'int8',
                                                    mq,
                                                    OutputLogger,
                                                    should_log_inputs=True)
        # 2 loggers for I/O of linear, 0 loggers for I/O of user_module
        shadowed_expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_shadows_mq_ns,
            expected_node_occurrence=shadowed_expected_occurrence)
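After the logger-instrumented models are built, they still need to be run on data so the OutputLogger modules actually record something. A hedged usage sketch (the calibration input is arbitrary, matching the shape used earlier in this test):

datum = torch.randn(1, 1, 1)
mp_ns(datum)             # fp32 model with unshadowed loggers
mq_ns(datum)             # int8 model with unshadowed loggers
mp_shadows_mq_ns(datum)  # fp32 model shadowing the int8 model
# How the recorded activations are read back out (e.g. via the numeric suite's
# extraction helpers) depends on the torch.ao.ns API version in use.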