Example #1
 def test_weight_only_activation_only_fakequant(self):
     for qengine in ["fbgemm", "qnnpack"]:
         if qengine not in torch.backends.quantized.supported_engines:
             continue
         if qengine == 'qnnpack':
             if IS_PPC or TEST_WITH_UBSAN:
                 continue
         with override_quantized_engine(qengine):
             torch.manual_seed(67)
             calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
             eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
             # Use a list, not a set: set iteration order is arbitrary, and each
             # qconfig must stay paired with its entry in SQNRTarget below.
             qconfig_list = [torch.quantization.default_weight_only_quant_qconfig,
                             torch.quantization.default_activation_only_quant_qconfig]
             SQNRTarget = [35, 45]
             for idx, qconfig in enumerate(qconfig_list):
                 my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
                 my_model.eval()
                 out_ref = my_model(eval_data)
                 fq_model = torch.quantization.QuantWrapper(my_model)
                 fq_model.train()
                 fq_model.qconfig = qconfig
                 torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
                 torch.quantization.prepare_qat(fq_model, inplace=True)
                 fq_model.eval()
                 fq_model.apply(torch.quantization.disable_fake_quant)
                 fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                 fq_model(calib_data)
                 fq_model.apply(torch.quantization.enable_fake_quant)
                 fq_model.apply(torch.quantization.disable_observer)
                 out_fq = fq_model(eval_data)
                 SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
                 self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float')
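Note: this example and the two that follow all score the quantized output with the same SQNR expression. As a reference, here is that metric as a standalone helper (a minimal sketch; the function name is ours, not part of the test suite):

    import torch

    def sqnr_db(reference, test):
        # Signal-to-quantization-noise ratio in decibels:
        # 20 * log10(||reference|| / ||reference - test||); higher means closer.
        return 20 * torch.log10(torch.norm(reference) / torch.norm(reference - test))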
Example #2
 def test_float_quant_compare_per_tensor(self):
     for qengine in ["fbgemm", "qnnpack"]:
         if qengine not in torch.backends.quantized.supported_engines:
             continue
         if qengine == 'qnnpack':
             if IS_PPC or TEST_WITH_UBSAN:
                 continue
         with override_quantized_engine(qengine):
             torch.manual_seed(42)
             my_model = ModelMultipleOps().to(torch.float32)
             my_model.eval()
             calib_data = torch.rand(1024, 3, 15, 15, dtype=torch.float32)
             eval_data = torch.rand(1, 3, 15, 15, dtype=torch.float32)
             out_ref = my_model(eval_data)
             qModel = torch.quantization.QuantWrapper(my_model)
             qModel.eval()
             qModel.qconfig = torch.quantization.default_qconfig
             torch.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
             torch.quantization.prepare(qModel, inplace=True)
             qModel(calib_data)
             torch.quantization.convert(qModel, inplace=True)
             out_q = qModel(eval_data)
             SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
             # Quantized model output should be numerically close to the float
             # model output. A 30 dB SQNR target bounds the quantization-noise
             # power at 1e-3 of the signal power (10 ** (-30 / 10) == 1e-3).
             self.assertGreater(SQNRdB, 30, msg='Quantized model numerics diverge from float, expect SQNR > 30 dB')
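The 30 dB threshold can be read as a power ratio. A quick standalone check of the arithmetic:

    # 30 dB SQNR means the noise power is 10 ** (-30 / 10) = 1e-3 of the signal
    # power; in norm terms the error is about 3.2% (10 ** (-30 / 20)) of the output.
    noise_power_ratio = 10 ** (-30 / 10)
    error_norm_ratio = 10 ** (-30 / 20)
    assert abs(noise_power_ratio - 1e-3) < 1e-15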
Example #3
 def test_fake_quant_true_quant_compare(self):
     for qengine in ["fbgemm", "qnnpack"]:
         if qengine not in torch.backends.quantized.supported_engines:
             continue
         if qengine == 'qnnpack':
             if IS_PPC or TEST_WITH_UBSAN:
                 continue
         with override_quantized_engine(qengine):
             torch.manual_seed(67)
             my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
             calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
             eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
             my_model.eval()
             out_ref = my_model(eval_data)
             fq_model = torch.quantization.QuantWrapper(my_model)
             fq_model.train()
             fq_model.qconfig = torch.quantization.default_qat_qconfig
             torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
             torch.quantization.prepare_qat(fq_model, inplace=True)
             fq_model.eval()
             fq_model.apply(torch.quantization.disable_fake_quant)
             fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
             fq_model(calib_data)
             fq_model.apply(torch.quantization.enable_fake_quant)
             fq_model.apply(torch.quantization.disable_observer)
             out_fq = fq_model(eval_data)
             SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
             # Quantized model output should be close to floating point model output numerically
             # Setting target SQNR to be 35 dB
             self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
             # Convert in place so out_q below comes from the truly quantized model.
             torch.quantization.convert(fq_model, inplace=True)
             out_q = fq_model(eval_data)
             SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
             self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')
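Fake quantization simulates the quantize -> dequantize round trip in floating point, which is why the fake-quant and true-quant outputs above are expected to agree to within 60 dB. A minimal sketch of the underlying op (values are illustrative):

    import torch

    x = torch.randn(4)
    # Round to the 8-bit grid and back: clamp(round(x / scale) + zero_point,
    # quant_min, quant_max), then map the result back to float.
    x_fq = torch.fake_quantize_per_tensor_affine(x, scale=0.1, zero_point=128,
                                                 quant_min=0, quant_max=255)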
Example #4
    def test_conv3d_api(
        self,
        batch_size,
        in_channels_per_group,
        D,
        H,
        W,
        out_channels_per_group,
        groups,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        pad_d,
        pad_h,
        pad_w,
        dilation,
        X_scale,
        X_zero_point,
        W_scale,
        W_zero_point,
        Y_scale,
        Y_zero_point,
        use_bias,
        use_channelwise,
        qengine,
    ):
        # Tests the correctness of the conv3d function.
        # Currently conv3d only supports the FBGEMM engine.

        if qengine not in torch.backends.quantized.supported_engines:
            return

        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)

        with override_quantized_engine(qengine):
            qconv_fn = qF.conv3d
            conv_fn = F.conv3d
            self._test_conv_api_impl(
                qconv_fn, conv_fn, batch_size, in_channels_per_group,
                input_feature_map_size, out_channels_per_group, groups,
                kernel_size, stride, padding, dilation, X_scale, X_zero_point,
                W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
                use_channelwise)
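For reference, a minimal direct call of the functional op that _test_conv_api_impl exercises here (illustrative shapes and quantization parameters; assumes the fbgemm engine is available, since conv3d has no qnnpack support in these tests):

    import torch
    import torch.nn.quantized.functional as qF

    x = torch.quantize_per_tensor(torch.randn(1, 3, 4, 8, 8), scale=0.1,
                                  zero_point=128, dtype=torch.quint8)
    w = torch.quantize_per_tensor(torch.randn(6, 3, 3, 3, 3), scale=0.05,
                                  zero_point=0, dtype=torch.qint8)
    y = qF.conv3d(x, w, bias=None, scale=0.2, zero_point=128)  # quint8 output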
Example #5
    def test_conv2d_api(
        self,
        batch_size,
        in_channels_per_group,
        H,
        W,
        out_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        X_scale,
        X_zero_point,
        W_scale,
        W_zero_point,
        Y_scale,
        Y_zero_point,
        use_bias,
        use_channelwise,
        qengine,
    ):
        # Tests the correctness of the conv2d function.

        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            use_channelwise = False

        input_feature_map_size = (H, W)
        kernel_size = (kernel_h, kernel_w)
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilation = (dilation, dilation)

        with override_quantized_engine(qengine):
            qconv_fn = qF.conv2d
            conv_fn = F.conv2d
            self._test_conv_api_impl(
                qconv_fn, conv_fn, batch_size, in_channels_per_group,
                input_feature_map_size, out_channels_per_group, groups,
                kernel_size, stride, padding, dilation, X_scale, X_zero_point,
                W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
                use_channelwise)
Example #6
    def test_conv3d_api(
        self,
        batch_size,
        in_channels_per_group,
        D,
        H,
        W,
        out_channels_per_group,
        groups,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        pad_d,
        pad_h,
        pad_w,
        dilation,
        X_scale,
        X_zero_point,
        W_scale,
        W_zero_point,
        Y_scale,
        Y_zero_point,
        use_bias,
        use_channelwise,
        use_fused,
        qengine,
    ):
        # Tests the correctness of the conv3d module.
        if qengine not in torch.backends.quantized.supported_engines:
            return

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)

        with override_quantized_engine(qengine):
            if use_fused:
                module_name = "QuantizedConvReLU3d"
                qconv_module = nnq_fused.ConvReLU3d(in_channels,
                                                    out_channels,
                                                    kernel_size,
                                                    stride,
                                                    padding,
                                                    dilation,
                                                    groups,
                                                    use_bias,
                                                    padding_mode="zeros")
            else:
                module_name = "QuantizedConv3d"
                qconv_module = nnq.Conv3d(in_channels,
                                          out_channels,
                                          kernel_size,
                                          stride,
                                          padding,
                                          dilation,
                                          groups,
                                          use_bias,
                                          padding_mode="zeros")

            conv_module = nn.Conv3d(in_channels,
                                    out_channels,
                                    kernel_size,
                                    stride,
                                    padding,
                                    dilation,
                                    groups,
                                    use_bias,
                                    padding_mode="zeros")
            if use_fused:
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU3d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)
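A minimal standalone use of the quantized module under test, with default-initialized parameters (illustrative shapes; assumes the fbgemm engine):

    import torch
    import torch.nn.quantized as nnq

    m = nnq.Conv3d(in_channels=3, out_channels=6, kernel_size=3)
    x = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8, 8), scale=0.1,
                                  zero_point=128, dtype=torch.quint8)
    y = m(x)  # requantized with m.scale / m.zero_point (1.0 / 0 by default)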
Example #7
    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_fused, per_channel, qengine):
        """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu"""
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            per_channel = False
        with override_quantized_engine(qengine):
            W = torch.rand(out_features, in_features).float()
            if per_channel:
                scale_tensor = torch.ones(out_features, dtype=torch.double)
                zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
                for i in range(len(scale_tensor)):
                    scale_tensor[i] = (i + 1.0) / 255.0
                W_q = torch.quantize_per_channel(W,
                                                 scales=scale_tensor,
                                                 zero_points=zero_point_tensor,
                                                 axis=0,
                                                 dtype=torch.qint8)
            else:
                W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

            X = torch.rand(batch_size, in_features).float()
            X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
            B = torch.rand(out_features).float() if use_bias else None
            scale = 0.5
            zero_point = 3
            if use_fused:
                qlinear = nnq_fused.LinearReLU(in_features, out_features)
            else:
                qlinear = nnq.Linear(in_features, out_features)

            # Run module with default-initialized parameters.
            # This tests that the constructor is correct.
            qlinear(X_q)

            qlinear.set_weight_bias(W_q, B)
            # Simple round-trip test for the weight()/set_weight_bias() API.
            self.assertEqual(qlinear.weight(), W_q)
            W_pack = qlinear._packed_params._packed_params

            qlinear.scale = float(scale)
            qlinear.zero_point = int(zero_point)
            Z_q = qlinear(X_q)
            # Check if the module implementation matches calling the
            # ops directly
            if use_fused:
                Z_ref = torch.ops.quantized.linear_relu(
                    X_q, W_pack, scale, zero_point)

                self.assertTrue('QuantizedLinearReLU' in str(qlinear))
            else:
                Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale,
                                                   zero_point)

                self.assertTrue('QuantizedLinear' in str(qlinear))
            self.assertEqual(Z_ref, Z_q)

            # Test serialization of quantized Linear Module using state_dict
            model_dict = qlinear.state_dict()
            self.assertEqual(model_dict['_packed_params.weight'], W_q)
            if use_bias:
                self.assertEqual(model_dict['_packed_params.bias'], B)
            b = io.BytesIO()
            torch.save(model_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in model_dict:
                self.assertEqual(model_dict[key], loaded_dict[key])
            if use_fused:
                loaded_qlinear = nnq_fused.LinearReLU(in_features,
                                                      out_features)
            else:
                loaded_qlinear = nnq.Linear(in_features, out_features)
            loaded_qlinear.load_state_dict(loaded_dict)

            linear_unpack = torch.ops.quantized.linear_unpack
            self.assertEqual(
                linear_unpack(qlinear._packed_params._packed_params),
                linear_unpack(loaded_qlinear._packed_params._packed_params))
            if use_bias:
                self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
            self.assertEqual(qlinear.scale, loaded_qlinear.scale)
            self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
            self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
            self.assertTrue(hasattr(qlinear, '_packed_params'))
            self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
            self.assertTrue(hasattr(qlinear, '_weight_bias'))
            self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
            self.assertEqual(qlinear._weight_bias(),
                             loaded_qlinear._weight_bias())
            self.assertEqual(
                qlinear._weight_bias(),
                torch.ops.quantized.linear_unpack(
                    qlinear._packed_params._packed_params))
            Z_q2 = loaded_qlinear(X_q)
            self.assertEqual(Z_q, Z_q2)

            # The below check is meant to ensure that `torch.save` and `torch.load`
            # serialization works, however it is currently broken by the following:
            # https://github.com/pytorch/pytorch/issues/24045
            #
            # Instead, we currently check that the proper exception is thrown on save.
            # <start code>
            # b = io.BytesIO()
            # torch.save(qlinear, b)
            # b.seek(0)
            # loaded = torch.load(b)
            # self.assertEqual(qlinear.weight(), loaded.weight())
            # self.assertEqual(qlinear.scale, loaded.scale)
            # self.assertEqual(qlinear.zero_point, loaded.zero_point)
            # <end code>
            with self.assertRaisesRegex(
                    RuntimeError,
                    r'torch.save\(\) is not currently supported'):
                b = io.BytesIO()
                torch.save(qlinear, b)

            # Test JIT
            self.checkScriptable(qlinear,
                                 list(zip([X_q], [Z_ref])),
                                 check_save_load=True)

            # Test from_float.
            float_linear = torch.nn.Linear(in_features, out_features).float()
            float_linear.qconfig = torch.quantization.default_qconfig
            torch.quantization.prepare(float_linear, inplace=True)
            float_linear(X.float())
            # Sequential allows swapping using "convert".
            quantized_float_linear = torch.nn.Sequential(float_linear)
            quantized_float_linear = torch.quantization.convert(
                quantized_float_linear, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X_q)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
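The per_channel branch above assigns each output feature its own (scale, zero_point) pair. A minimal sketch of the two weight-quantization schemes side by side (illustrative values):

    import torch

    w = torch.randn(4, 8)
    # One scale/zero_point for the whole tensor:
    w_q = torch.quantize_per_tensor(w, scale=0.1, zero_point=4, dtype=torch.qint8)
    # One scale/zero_point per row (axis=0, i.e. per output feature):
    w_qc = torch.quantize_per_channel(
        w,
        scales=torch.tensor([0.05, 0.1, 0.15, 0.2], dtype=torch.double),
        zero_points=torch.zeros(4, dtype=torch.long),
        axis=0,
        dtype=torch.qint8)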
Example #8
    def test_conv_api(self, use_bias, use_fused, per_channel, qengine):
        """Tests the correctness of the conv module.
        The correctness is defined against the functional implementation.
        """
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            per_channel = False
        with override_quantized_engine(qengine):
            N, iC, H, W = 10, 10, 10, 3
            oC, g, kH, kW = 16, 1, 3, 3
            scale, zero_point = 1.0 / 255, 128

            X = torch.randn(N, iC, H, W, dtype=torch.float32)
            qX = torch.quantize_per_tensor(X,
                                           scale=scale,
                                           zero_point=128,
                                           dtype=torch.quint8)

            w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

            if per_channel:
                scale_tensor = torch.ones(oC, dtype=torch.double)
                zero_point_tensor = torch.zeros(oC, dtype=torch.long)
                for i in range(len(scale_tensor)):
                    scale_tensor[i] = (i + 1.0) / 255.0

                qw = torch.quantize_per_channel(w,
                                                scales=scale_tensor,
                                                zero_points=zero_point_tensor,
                                                axis=0,
                                                dtype=torch.qint8)
            else:
                qw = torch.quantize_per_tensor(w,
                                               scale=scale,
                                               zero_point=0,
                                               dtype=torch.qint8)

            b = torch.randn(oC, dtype=torch.float32) if use_bias else None

            if use_fused:
                conv_under_test = ConvReLU2d(in_channels=iC,
                                             out_channels=oC,
                                             kernel_size=(kH, kW),
                                             stride=1,
                                             padding=0,
                                             dilation=1,
                                             groups=g,
                                             bias=use_bias,
                                             padding_mode='zeros')
            else:
                conv_under_test = Conv2d(in_channels=iC,
                                         out_channels=oC,
                                         kernel_size=(kH, kW),
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         bias=use_bias,
                                         padding_mode='zeros')
            # Run module with default-initialized parameters.
            # This tests that the constructor is correct.
            conv_under_test.set_weight_bias(qw, b)
            conv_under_test(qX)

            conv_under_test.scale = scale
            conv_under_test.zero_point = zero_point

            # Test members
            self.assertTrue(hasattr(conv_under_test, '_packed_params'))
            self.assertTrue(hasattr(conv_under_test, 'scale'))
            self.assertTrue(hasattr(conv_under_test, 'zero_point'))

            # Test properties
            self.assertEqual(qw, conv_under_test.weight())
            if use_bias:
                self.assertEqual(b, conv_under_test.bias())
            self.assertEqual(scale, conv_under_test.scale)
            self.assertEqual(zero_point, conv_under_test.zero_point)

            # Test forward
            result_under_test = conv_under_test(qX)
            result_reference = qF.conv2d(qX,
                                         qw,
                                         bias=b,
                                         scale=scale,
                                         zero_point=zero_point,
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         dtype=torch.quint8)
            if use_fused:
                # result_reference < zero_point doesn't work for qtensor yet
                # result_reference[result_reference < zero_point] = zero_point
                MB, OC, OH, OW = result_reference.size()
                for i in range(MB):
                    for j in range(OC):
                        for h in range(OH):
                            # use `k`, not `w`, to avoid shadowing the weight tensor
                            for k in range(OW):
                                if result_reference[i][j][h][k].int_repr() < zero_point:
                                    # assign 0., which quantizes back to zero_point
                                    result_reference[i][j][h][k] = 0.

            self.assertEqual(result_reference,
                             result_under_test,
                             msg="Tensors are not equal.")

            # Test serialization of quantized Conv Module using state_dict
            model_dict = conv_under_test.state_dict()
            self.assertEqual(model_dict['weight'], qw)
            if use_bias:
                self.assertEqual(model_dict['bias'], b)
            buf = io.BytesIO()  # fresh name: `b` above holds the bias tensor
            torch.save(model_dict, buf)
            buf.seek(0)
            loaded_dict = torch.load(buf)
            for key in model_dict:
                self.assertEqual(loaded_dict[key], model_dict[key])
            if use_fused:
                loaded_conv_under_test = ConvReLU2d(in_channels=iC,
                                                    out_channels=oC,
                                                    kernel_size=(kH, kW),
                                                    stride=1,
                                                    padding=0,
                                                    dilation=1,
                                                    groups=g,
                                                    bias=use_bias,
                                                    padding_mode='zeros')

                self.assertTrue(
                    'QuantizedConvReLU2d' in str(loaded_conv_under_test))
            else:
                loaded_conv_under_test = Conv2d(in_channels=iC,
                                                out_channels=oC,
                                                kernel_size=(kH, kW),
                                                stride=1,
                                                padding=0,
                                                dilation=1,
                                                groups=g,
                                                bias=use_bias,
                                                padding_mode='zeros')
                self.assertTrue(
                    'QuantizedConv2d' in str(loaded_conv_under_test))
            loaded_conv_under_test.load_state_dict(loaded_dict)
            self.assertEqual(loaded_conv_under_test._weight_bias(),
                             conv_under_test._weight_bias())
            if use_bias:
                self.assertEqual(loaded_conv_under_test.bias(),
                                 conv_under_test.bias())
            self.assertEqual(loaded_conv_under_test.scale,
                             conv_under_test.scale)
            self.assertEqual(loaded_conv_under_test.zero_point,
                             conv_under_test.zero_point)
            self.assertTrue(
                dir(loaded_conv_under_test) == dir(conv_under_test))
            self.assertTrue(hasattr(conv_under_test, '_packed_params'))
            self.assertTrue(hasattr(loaded_conv_under_test, '_packed_params'))
            self.assertTrue(hasattr(conv_under_test, '_weight_bias'))
            self.assertTrue(hasattr(loaded_conv_under_test, '_weight_bias'))
            self.assertEqual(loaded_conv_under_test._weight_bias(),
                             conv_under_test._weight_bias())
            self.assertEqual(loaded_conv_under_test.weight(), qw)
            loaded_result = loaded_conv_under_test(qX)
            self.assertEqual(loaded_result, result_reference)

            # The below check is meant to ensure that `torch.save` and `torch.load`
            # serialization works, however it is currently broken by the following:
            # https://github.com/pytorch/pytorch/issues/24045
            #
            # Instead, we currently check that the proper exception is thrown on save.
            # <start code>
            # b = io.BytesIO()
            # torch.save(conv_under_test, b)
            # b.seek(0)
            # loaded_conv = torch.load(b)
            #
            # self.assertEqual(conv_under_test.bias(), loaded_conv.bias())
            # self.assertEqual(conv_under_test.scale, loaded_conv.scale)
            # self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)
            # <end code>
            with self.assertRaisesRegex(
                    RuntimeError,
                    r'torch.save\(\) is not currently supported'):
                buf = io.BytesIO()
                torch.save(conv_under_test, buf)

            # JIT testing
            self.checkScriptable(conv_under_test,
                                 list(zip([qX], [result_reference])),
                                 check_save_load=True)

            # Test from_float
            float_conv = torch.nn.Conv2d(in_channels=iC,
                                         out_channels=oC,
                                         kernel_size=(kH, kW),
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         bias=use_bias,
                                         padding_mode='zeros').float()
            float_conv.qconfig = torch.quantization.default_qconfig
            torch.quantization.prepare(float_conv, inplace=True)
            float_conv(X.float())
            quantized_float_conv = torch.nn.Sequential(float_conv)
            torch.quantization.convert(quantized_float_conv, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_conv(qX)
            if use_bias:
                self.assertEqual(quantized_float_conv[0].bias(),
                                 float_conv.bias)
            # Smoke test extra_repr
            self.assertTrue('QuantizedConv2d' in str(quantized_float_conv))
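The nested loop in the use_fused branch emulates ReLU on a quantized tensor by pushing every entry whose integer representation is below the zero point up to the zero point. A sketch of the same effect via a dequantize/requantize round trip (not how the test does it, but equivalent in effect):

    import torch

    q = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.1, zero_point=128,
                                  dtype=torch.quint8)
    # relu in float clamps at 0., and 0. requantizes exactly to zero_point.
    q_relu = torch.quantize_per_tensor(q.dequantize().relu(), scale=0.1,
                                       zero_point=128, dtype=torch.quint8)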
Example #9
    def test_conv_api(self, use_bias, per_channel, qengine):
        """Tests the correctness of the conv module.
        The correctness is defined against the functional implementation.
        """

        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            per_channel = False
        with override_quantized_engine(qengine):
            N, iC, H, W = 10, 10, 10, 3
            oC, g, kH, kW = 16, 1, 3, 3
            scale, zero_point = 1.0 / 255, 128
            stride = (1, 1)
            i_padding = (0, 0)
            dilation = (1, 1)

            X = torch.randn(N, iC, H, W, dtype=torch.float32)
            qX = torch.quantize_per_tensor(X,
                                           scale=scale,
                                           zero_point=128,
                                           dtype=torch.quint8)

            w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

            if per_channel:
                scale_tensor = torch.ones(oC, dtype=torch.double)
                zero_point_tensor = torch.zeros(oC, dtype=torch.long)
                for i in range(len(scale_tensor)):
                    scale_tensor[i] = (i + 1.0) / 255.0

                qw = torch.quantize_per_channel(w,
                                                scales=scale_tensor,
                                                zero_points=zero_point_tensor,
                                                axis=0,
                                                dtype=torch.qint8)
            else:
                qw = torch.quantize_per_tensor(w,
                                               scale=scale,
                                               zero_point=0,
                                               dtype=torch.qint8)

            b = torch.randn(oC, dtype=torch.float32) if use_bias else None
            q_filters_ref = torch.ops.quantized.conv_prepack(
                qw, b, stride, i_padding, dilation, g)

            ref_result = torch.ops.quantized.conv2d(qX, q_filters_ref, stride,
                                                    i_padding, dilation, g,
                                                    scale, zero_point)

            q_result = torch.nn.quantized.functional.conv2d(
                qX,
                qw,
                bias=b,
                scale=scale,
                zero_point=zero_point,
                stride=stride,
                padding=i_padding,
                dilation=dilation,
                groups=g,
                dtype=torch.quint8)

            self.assertEqual(ref_result, q_result)
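As a closing cross-check of these comparisons, the quantized functional output should also track a plain float convolution on the dequantized operands, to within roughly the output quantization step (illustrative values; assumes a supported engine is active):

    import torch
    import torch.nn.functional as F
    import torch.nn.quantized.functional as qF

    x = torch.randn(1, 3, 8, 8)
    wt = torch.randn(4, 3, 3, 3)
    qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=128, dtype=torch.quint8)
    qwt = torch.quantize_per_tensor(wt, scale=0.05, zero_point=0, dtype=torch.qint8)

    y_float = F.conv2d(qx.dequantize(), qwt.dequantize())
    y_q = qF.conv2d(qx, qwt, bias=None, scale=0.1, zero_point=128)
    # y_q.dequantize() matches y_float to within about the output scale (0.1).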