Example 1
    def test_conv_api(self):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation.
        """

        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        X = X.permute([0, 2, 3, 1]).contiguous()
        qX = torch.quantize_linear(X, scale=scale, zero_point=128, dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
        w = w.permute([0, 2, 3, 1]).contiguous()
        qw = torch.quantize_linear(w, scale=scale, zero_point=0, dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32)
        qb = torch.quantize_linear(b, scale=1.0 / 1024, zero_point=0, dtype=torch.qint32)

        conv_under_test = Conv2d(in_channels=iC,
                                 out_channels=oC,
                                 kernel_size=(kH, kW),
                                 stride=1,
                                 padding=0,
                                 dilation=1,
                                 groups=g,
                                 bias=True,
                                 padding_mode='zeros')
        conv_under_test.weight = qw
        conv_under_test.bias = qb
        conv_under_test.scale = scale
        conv_under_test.zero_point = zero_point

        # Test members
        self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(conv_under_test, '_scale'))
        self.assertTrue(hasattr(conv_under_test, '_zero_point'))

        # Test properties
        # self.assertEqual(qw, conv_under_test.weight)
        self.assertEqual(qb, conv_under_test.bias)
        self.assertEqual(scale, conv_under_test.scale)
        self.assertEqual(zero_point, conv_under_test.zero_point)

        # Test forward
        result_under_test = conv_under_test(qX)
        result_reference = qF.conv2d(qX, qw, bias=qb,
                                     scale=scale, zero_point=zero_point,
                                     stride=1, padding=0,
                                     dilation=1, groups=g,
                                     prepacked=False, dtype=torch.quint8)

        self.assertEqual(result_reference, result_under_test,
                         message="Tensors are not equal.")
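
The scale and zero_point used above define the usual affine mapping between float values and the stored 8-bit integers. A minimal sketch of that mapping, written against torch.quantize_per_tensor (the current name for the torch.quantize_linear call used in this older snippet):

import torch

scale, zero_point = 1.0 / 255, 128
x = torch.randn(4)

# Stores round(x / scale) + zero_point as uint8 values, clamped to [0, 255].
q = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point,
                              dtype=torch.quint8)
print(q.int_repr())    # raw uint8 storage
print(q.dequantize())  # (int_repr - zero_point) * scale, approximately x
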
Example 2
 def forward(self, input):
     return qF.conv2d(input=input,
                      weight=self._packed_weight,
                      bias=self.bias,
                      stride=self.stride,
                      padding=self.padding,
                      dilation=self.dilation,
                      groups=self.groups,
                      padding_mode=self.padding_mode,
                      scale=self.scale,
                      zero_point=self.zero_point,
                      dtype=torch.quint8,
                      prepacked=True)
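
The forward above passes a _packed_weight and prepacked=True: the weight is reordered into fbgemm's packed layout once, when it is assigned, so each forward call skips the repacking cost. A hedged sketch of what that packing step looks like, assuming the legacy fbgemm ops that Examples 5 and 7 call directly:

import torch

g = 1  # groups
w = torch.randn(16, 3, 3, 3)
qw = torch.quantize_per_tensor(w, scale=1.0 / 255, zero_point=0,
                               dtype=torch.qint8)
# fbgemm expects the filters in channels-last order (see w_RSCK in Example 5).
qw_nhwc = qw.permute([0, 2, 3, 1]).contiguous()
# Pack once; the module keeps the handle and reuses it on every forward.
packed = torch.ops.quantized.fbgemm_conv_prepack(qw_nhwc, g)
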
Example 3
    def test_conv_api(self, use_bias, use_fused):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation.
        """

        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        X = X.permute([0, 2, 3, 1]).contiguous()
        qX = torch.quantize_linear(X,
                                   scale=scale,
                                   zero_point=128,
                                   dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

        qw = torch.quantize_linear(w,
                                   scale=scale,
                                   zero_point=0,
                                   dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None
        qb = torch.quantize_linear(
            b, scale=1.0 / 1024, zero_point=0,
            dtype=torch.qint32) if use_bias else None

        if use_fused:
            conv_under_test = ConvReLU2d(in_channels=iC,
                                         out_channels=oC,
                                         kernel_size=(kH, kW),
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         bias=use_bias,
                                         padding_mode='zeros')
        else:
            conv_under_test = Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros')
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        conv_under_test(qX)

        conv_under_test.set_weight(qw)
        conv_under_test.bias = qb
        conv_under_test.scale = scale
        conv_under_test.zero_point = zero_point

        # Test members
        self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(conv_under_test, 'scale'))
        self.assertTrue(hasattr(conv_under_test, 'zero_point'))

        # Test properties
        self.assertEqual(qw, conv_under_test.weight())
        self.assertEqual(qb, conv_under_test.bias)
        self.assertEqual(scale, conv_under_test.scale)
        self.assertEqual(zero_point, conv_under_test.zero_point)

        # Test forward
        result_under_test = conv_under_test(qX)
        result_reference = qF.conv2d(qX,
                                     qw,
                                     bias=qb,
                                     scale=scale,
                                     zero_point=zero_point,
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     dtype=torch.quint8)
        if use_fused:
            # result_reference < zero_point doesn't work for qtensor yet
            # result_reference[result_reference < zero_point] = zero_point
            MB, OC, OH, OW = result_reference.size()
            for i in range(MB):
                for j in range(OC):
                    for h in range(OH):
                        for w in range(OW):
                            ref_val = result_reference[i][j][h][w]
                            if ref_val.int_repr() < zero_point:
                                # Assign 0.0, which re-quantizes to zero_point.
                                result_reference[i][j][h][w] = 0.

        self.assertEqual(result_reference,
                         result_under_test,
                         message="Tensors are not equal.")

        # Test serialization of quantized Conv Module using state_dict
        model_dict = conv_under_test.state_dict()
        self.assertEqual(model_dict['weight'], qw)
        if use_bias:
            self.assertEqual(model_dict['bias'], qb)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(loaded_dict[key], model_dict[key])
        if use_fused:
            loaded_conv_under_test = ConvReLU2d(in_channels=iC,
                                                out_channels=oC,
                                                kernel_size=(kH, kW),
                                                stride=1,
                                                padding=0,
                                                dilation=1,
                                                groups=g,
                                                bias=use_bias,
                                                padding_mode='zeros')
        else:
            loaded_conv_under_test = Conv2d(in_channels=iC,
                                            out_channels=oC,
                                            kernel_size=(kH, kW),
                                            stride=1,
                                            padding=0,
                                            dilation=1,
                                            groups=g,
                                            bias=use_bias,
                                            padding_mode='zeros')
        loaded_conv_under_test.load_state_dict(loaded_dict)
        self.assertEqual(loaded_conv_under_test.weight(),
                         conv_under_test.weight())
        if use_bias:
            self.assertEqual(loaded_conv_under_test.bias, conv_under_test.bias)
        self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
        self.assertEqual(loaded_conv_under_test.zero_point,
                         conv_under_test.zero_point)
        self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
        self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(loaded_conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(conv_under_test, 'weight'))
        self.assertTrue(hasattr(loaded_conv_under_test, 'weight'))
        self.assertEqual(loaded_conv_under_test.weight(),
                         conv_under_test.weight())
        self.assertEqual(loaded_conv_under_test.weight(), qw)
        loaded_result = loaded_conv_under_test(qX)
        self.assertEqual(loaded_result, result_reference)

        with tempfile.NamedTemporaryFile() as f:
            torch.save(conv_under_test, f)
            f.seek(0)
            loaded_conv = torch.load(f)

        self.assertEqual(conv_under_test.bias, loaded_conv.bias)
        self.assertEqual(conv_under_test.scale, loaded_conv.scale)
        self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)

        # JIT testing
        self.checkScriptable(conv_under_test,
                             list(zip([qX], [result_reference])),
                             check_save_load=True)

        # Test from_float
        float_conv = torch.nn.Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros').float()
        float_conv.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_conv)
        float_conv(X.float())
        quantized_float_conv = torch.nn.Sequential(float_conv)
        torch.quantization.convert(quantized_float_conv)

        # Smoke test to make sure the module actually runs
        quantized_float_conv(qX)
        # Check that bias is quantized based on output scale
        if use_bias:
            qbias = torch.quantize_linear(
                float_conv.bias, quantized_float_conv[0].scale / 2**16, 0,
                torch.qint32)
            self.assertEqual(quantized_float_conv[0].bias.dequantize(),
                             qbias.dequantize())
        # Smoke test extra_repr
        str(quantized_float_conv)
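
The scale and zero_point set on the module drive the requantization that qF.conv2d performs: the int32 accumulator built from the quint8 activations and qint8 weights is rescaled into the output's parameters (and, as the from_float check above asserts, the bias is quantized with scale output_scale / 2**16). A back-of-the-envelope sketch with made-up numbers, not the test's tensors:

# Illustrative numbers only; this shows the relationship, not the test's values.
s_x, s_w = 1.0 / 255, 1.0 / 255       # activation and weight scales
s_out, zp_out = 1.0 / 255, 128        # output qparams set on the module
acc_int32 = 12345                     # one accumulator entry: sum((x_int - zp_x) * w_int)
multiplier = (s_x * s_w) / s_out
out_int = round(acc_int32 * multiplier) + zp_out
out_int = max(0, min(255, out_int))   # clamp to the quint8 range
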
Example 4
    def test_conv_api(self, use_bias, use_fused):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation.
        """

        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        qX = torch.quantize_per_tensor(X,
                                       scale=scale,
                                       zero_point=128,
                                       dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

        qw = torch.quantize_per_tensor(w,
                                       scale=scale,
                                       zero_point=0,
                                       dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None

        if use_fused:
            conv_under_test = ConvReLU2d(in_channels=iC,
                                         out_channels=oC,
                                         kernel_size=(kH, kW),
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         bias=use_bias,
                                         padding_mode='zeros')
        else:
            conv_under_test = Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros')
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        conv_under_test.set_weight_bias(qw, b)
        conv_under_test(qX)

        conv_under_test.scale = scale
        conv_under_test.zero_point = zero_point

        # Test members
        self.assertTrue(hasattr(conv_under_test, '_packed_params'))
        self.assertTrue(hasattr(conv_under_test, 'scale'))
        self.assertTrue(hasattr(conv_under_test, 'zero_point'))

        # Test properties
        self.assertEqual(qw, conv_under_test.weight())
        self.assertEqual(b, conv_under_test.bias())
        self.assertEqual(scale, conv_under_test.scale)
        self.assertEqual(zero_point, conv_under_test.zero_point)

        # Test forward
        result_under_test = conv_under_test(qX)
        result_reference = qF.conv2d(qX,
                                     qw,
                                     bias=b,
                                     scale=scale,
                                     zero_point=zero_point,
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     dtype=torch.quint8)
        if use_fused:
            # result_reference < zero_point doesn't work for qtensor yet
            # result_reference[result_reference < zero_point] = zero_point
            MB, OC, OH, OW = result_reference.size()
            for i in range(MB):
                for j in range(OC):
                    for h in range(OH):
                        for w in range(OW):
                            ref_val = result_reference[i][j][h][w]
                            if ref_val.int_repr() < zero_point:
                                # Assign 0.0, which re-quantizes to zero_point.
                                result_reference[i][j][h][w] = 0.

        self.assertEqual(result_reference,
                         result_under_test,
                         message="Tensors are not equal.")

        # Test serialization of quantized Conv Module using state_dict
        model_dict = conv_under_test.state_dict()
        self.assertEqual(model_dict['weight'], qw)
        if use_bias:
            self.assertEqual(model_dict['bias'], b)
        buf = io.BytesIO()
        torch.save(model_dict, buf)
        buf.seek(0)
        loaded_dict = torch.load(buf)
        for key in model_dict:
            self.assertEqual(loaded_dict[key], model_dict[key])
        if use_fused:
            loaded_conv_under_test = ConvReLU2d(in_channels=iC,
                                                out_channels=oC,
                                                kernel_size=(kH, kW),
                                                stride=1,
                                                padding=0,
                                                dilation=1,
                                                groups=g,
                                                bias=use_bias,
                                                padding_mode='zeros')
        else:
            loaded_conv_under_test = Conv2d(in_channels=iC,
                                            out_channels=oC,
                                            kernel_size=(kH, kW),
                                            stride=1,
                                            padding=0,
                                            dilation=1,
                                            groups=g,
                                            bias=use_bias,
                                            padding_mode='zeros')
        loaded_conv_under_test.load_state_dict(loaded_dict)
        self.assertEqual(loaded_conv_under_test._weight_bias(),
                         conv_under_test._weight_bias())
        if use_bias:
            self.assertEqual(loaded_conv_under_test.bias(),
                             conv_under_test.bias())
        self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
        self.assertEqual(loaded_conv_under_test.zero_point,
                         conv_under_test.zero_point)
        self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
        self.assertTrue(hasattr(conv_under_test, '_packed_params'))
        self.assertTrue(hasattr(loaded_conv_under_test, '_packed_params'))
        self.assertTrue(hasattr(conv_under_test, '_weight_bias'))
        self.assertTrue(hasattr(loaded_conv_under_test, '_weight_bias'))
        self.assertEqual(loaded_conv_under_test._weight_bias(),
                         conv_under_test._weight_bias())
        self.assertEqual(loaded_conv_under_test.weight(), qw)
        loaded_result = loaded_conv_under_test(qX)
        self.assertEqual(loaded_result, result_reference)

        # The check below is meant to ensure that `torch.save` and `torch.load`
        # serialization works; however, it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on save.
        # <start code>
        # b = io.BytesIO()
        # torch.save(conv_under_test, b)
        # b.seek(0)
        # loaded_conv = torch.load(b)
        #
        # self.assertEqual(conv_under_test.bias(), loaded_conv.bias())
        # self.assertEqual(conv_under_test.scale, loaded_conv.scale)
        # self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)
        # <end code>
        with self.assertRaisesRegex(
                RuntimeError, r'torch.save\(\) is not currently supported'):
            buf = io.BytesIO()
            torch.save(conv_under_test, buf)

        # JIT testing
        self.checkScriptable(conv_under_test,
                             list(zip([qX], [result_reference])),
                             check_save_load=True)

        # Test from_float
        float_conv = torch.nn.Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros').float()
        float_conv.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_conv, inplace=True)
        float_conv(X.float())
        quantized_float_conv = torch.nn.Sequential(float_conv)
        torch.quantization.convert(quantized_float_conv, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_conv(qX)
        if use_bias:
            self.assertEqual(quantized_float_conv[0].bias(), float_conv.bias)
        # Smoke test extra_repr
        str(quantized_float_conv)
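
The quadruple loop that clamps result_reference in these fused tests works around elementwise comparisons not being implemented for quantized tensors at the time. A vectorized sketch of the same effect, under the assumption that the reference keeps the same scale and zero_point: dequantize, apply ReLU in float, and re-quantize, so every value whose integer representation is below zero_point collapses back to zero_point.

import torch

scale, zero_point = 1.0 / 255, 128
ref = torch.quantize_per_tensor(torch.randn(2, 3, 4, 4), scale, zero_point,
                                torch.quint8)
# ReLU maps negative floats to 0.0, and 0.0 re-quantizes to zero_point,
# which is exactly what the per-element loop in the tests does.
ref_relu = torch.quantize_per_tensor(torch.relu(ref.dequantize()),
                                     scale, zero_point, torch.quint8)
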
Example 5
    def test_conv_api(self, Q, padH, padW, sH, sW, dH, dW, prepacked):
        """Tests the correctness of the conv functional.

        The correctness is defined by matching the behavior of the
        `quantized._ops` implementation.
        """
        # Random inputs
        X, (scale, zero_point), (qmin, qmax), torch_type = Q
        (inputs, filters, bias, groups) = X

        iC, oC = inputs.shape[1], filters.shape[0]

        iH, iW = inputs.shape[2:]
        kH, kW = filters.shape[2:]
        assume(kH // 2 >= padH)
        assume(kW // 2 >= padW)
        oH = _conv_output_shape(iH, kH, padH, sH, dH)
        assume(oH > 0)
        oW = _conv_output_shape(iW, kW, padW, sW, dW)
        assume(oW > 0)

        inputs = torch.from_numpy(inputs).to(torch.float)
        filters = torch.from_numpy(filters).to(torch.float)
        bias = torch.from_numpy(bias).to(torch.float)

        kernel_size = (kH, kW)
        stride = (sH, sW)
        i_padding = (padH, padW)
        dilation = (dH, dW)

        # Quantized inputs
        i_NHWC = inputs.permute([0, 2, 3, 1]).contiguous()
        w_RSCK = filters.permute([0, 2, 3, 1]).contiguous()

        q_inputs = torch.quantize_linear(i_NHWC, scale, zero_point,
                                         torch.quint8)
        q_filters = torch.quantize_linear(w_RSCK, scale, zero_point,
                                          torch.qint8)
        q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(
            q_filters, groups)
        q_bias = torch.quantize_linear(bias, scale, zero_point, torch.qint32)

        # Reference op
        ref_op = torch.ops.quantized.fbgemm_conv2d

        # Results check
        try:
            ref_result = ref_op(q_inputs, q_filters_ref, q_bias, stride,
                                i_padding, dilation, groups, scale, zero_point)
        except RuntimeError as e:
            e_msg = str(e).split("\n")[0].split("(")[0].strip()
            np.testing.assert_raises_regex(type(e),
                                           e_msg,
                                           qF.conv2d,
                                           q_inputs,
                                           q_filters_ref,
                                           bias=q_bias,
                                           scale=scale,
                                           zero_point=zero_point,
                                           stride=stride,
                                           padding=i_padding,
                                           dilation=dilation,
                                           groups=groups,
                                           prepacked=True,
                                           dtype=torch_type)
        else:
            if prepacked:
                q_filters = torch.ops.quantized.fbgemm_conv_prepack(
                    q_filters, groups)
            q_result = qF.conv2d(q_inputs,
                                 q_filters,
                                 bias=q_bias,
                                 scale=scale,
                                 zero_point=zero_point,
                                 stride=stride,
                                 padding=i_padding,
                                 dilation=dilation,
                                 groups=groups,
                                 prepacked=prepacked,
                                 dtype=torch_type)

            np.testing.assert_equal(ref_result.int_repr().numpy(),
                                    q_result.int_repr().numpy())
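
Examples 5 and 7 filter their generated cases through _conv_output_shape, a helper from the test file that is not shown here. Presumably it evaluates the standard convolution output-size formula; a sketch under that assumption:

def conv_output_shape(size, kernel, pad, stride, dilation):
    # floor((size + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

# e.g. a 10-wide input, 3-wide kernel, no padding, stride 1, dilation 1 -> 8
assert conv_output_shape(10, 3, 0, 1, 1) == 8
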
Example 6
    def test_conv_api(self, use_bias, use_fused):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation.
        """

        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        X = X.permute([0, 2, 3, 1]).contiguous()
        qX = torch.quantize_linear(X,
                                   scale=scale,
                                   zero_point=128,
                                   dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

        qw = torch.quantize_linear(w,
                                   scale=scale,
                                   zero_point=0,
                                   dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None
        qb = torch.quantize_linear(
            b, scale=1.0 / 1024, zero_point=0,
            dtype=torch.qint32) if use_bias else None

        if use_fused:
            conv_under_test = ConvReLU2d(in_channels=iC,
                                         out_channels=oC,
                                         kernel_size=(kH, kW),
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         bias=use_bias,
                                         padding_mode='zeros')
        else:
            conv_under_test = Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros')
        conv_under_test.weight = qw
        conv_under_test.bias = qb
        conv_under_test.scale = torch.tensor([scale], dtype=torch.double)
        conv_under_test.zero_point = torch.tensor([zero_point],
                                                  dtype=torch.long)

        # Test members
        self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(conv_under_test, 'scale'))
        self.assertTrue(hasattr(conv_under_test, 'zero_point'))

        # Test properties
        self.assertEqual(qw, conv_under_test.weight)
        self.assertEqual(qb, conv_under_test.bias)
        self.assertEqual(scale, conv_under_test.scale)
        self.assertEqual(zero_point, conv_under_test.zero_point)

        # Test forward
        result_under_test = conv_under_test(qX)
        result_reference = qF.conv2d(qX,
                                     qw,
                                     bias=qb,
                                     scale=scale,
                                     zero_point=zero_point,
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     dtype=torch.quint8)
        if use_fused:
            # result_reference < zero_point doesn't work for qtensor yet
            # result_reference[result_reference < zero_point] = zero_point
            MB, OC, OH, OW = result_reference.size()
            for i in range(MB):
                for j in range(OC):
                    for h in range(OH):
                        for w in range(OW):
                            ref_val = result_reference[i][j][h][w]
                            if ref_val.int_repr() < zero_point:
                                # Assign 0.0, which re-quantizes to zero_point.
                                result_reference[i][j][h][w] = 0.

        self.assertEqual(result_reference,
                         result_under_test,
                         message="Tensors are not equal.")
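
The older snippets (Examples 1, 5, 6 and 7) permute tensors with .permute([0, 2, 3, 1]).contiguous() before quantizing; the variable names in Example 5 (i_NHWC, w_RSCK) show why: the fbgemm-backed ops they call expect channels-last buffers, while Example 4, the newer variant, passes NCHW tensors directly. A small sketch of the layout round-trip:

import torch

x_nchw = torch.randn(1, 4, 5, 5)                     # N, C, H, W
x_nhwc = x_nchw.permute([0, 2, 3, 1]).contiguous()   # N, H, W, C for fbgemm
x_back = x_nhwc.permute([0, 3, 1, 2]).contiguous()   # back to N, C, H, W
assert torch.equal(x_nchw, x_back)
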
Example 7
    def test_conv_api(self, X, padH, padW, sH, sW, dH, dW):
        """Tests the correctness of the conv functional.

        The correctness is defined by matching the behavior of the
        `quantized._ops` implementation.
        """
        # Random inputs
        # X, (scale, zero_point, torch_type) = X
        (inputs, filters, bias, groups) = X
        inputs, (inputs_scale, inputs_zero_point, inputs_qtype) = inputs
        filters, (filters_scale, filters_zero_point, filters_qtype) = filters
        bias, (bias_scale, bias_zero_point, bias_qtype) = bias

        scale, zero_point = inputs_scale, inputs_zero_point
        torch_type = inputs_qtype

        iC, oC = inputs.shape[1], filters.shape[0]

        iH, iW = inputs.shape[2:]
        kH, kW = filters.shape[2:]
        assume(kH // 2 >= padH)
        assume(kW // 2 >= padW)
        oH = _conv_output_shape(iH, kH, padH, sH, dH)
        assume(oH > 0)
        oW = _conv_output_shape(iW, kW, padW, sW, dW)
        assume(oW > 0)

        inputs = torch.from_numpy(inputs).to(torch.float)
        filters = torch.from_numpy(filters).to(torch.float)
        bias = torch.from_numpy(bias).to(torch.float)

        kernel_size = (kH, kW)
        stride = (sH, sW)
        i_padding = (padH, padW)
        dilation = (dH, dW)

        # Quantized inputs
        q_inputs = torch.quantize_linear(inputs, inputs_scale,
                                         inputs_zero_point, inputs_qtype)
        q_filters = torch.quantize_linear(filters, filters_scale,
                                          filters_zero_point, filters_qtype)
        q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(
            q_filters.permute([0, 2, 3, 1]), groups)
        q_bias = torch.quantize_linear(bias, bias_scale, bias_zero_point,
                                       bias_qtype)

        # Reference op
        ref_op = torch.ops.quantized.fbgemm_conv2d

        # Results check
        try:
            ref_result = ref_op(q_inputs.permute([0, 2, 3, 1]), q_filters_ref,
                                q_bias, stride,
                                i_padding, dilation,
                                groups, scale, zero_point).permute([0, 3, 1, 2])
        except RuntimeError as e:
            e_msg = str(e).split("\n")[0].split("(")[0].strip()
            np.testing.assert_raises_regex(
                type(e), e_msg, qF.conv2d,
                q_inputs, q_filters, bias=q_bias,
                scale=scale, zero_point=zero_point,
                stride=stride, padding=i_padding, dilation=dilation,
                groups=groups, dtype=torch_type)
        else:
            q_result = qF.conv2d(q_inputs,
                                 q_filters,
                                 bias=q_bias, scale=scale,
                                 zero_point=zero_point,
                                 stride=stride, padding=i_padding,
                                 dilation=dilation, groups=groups,
                                 dtype=torch_type)

            self.assertEqual(ref_result, q_result)
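
Examples 5 and 7 share the same error-parity pattern: when the fbgemm reference op rejects a generated configuration, the functional under test must raise the same error; otherwise the two results must agree. Stripped of the conv-specific arguments, the pattern looks roughly like the sketch below (check_parity is a hypothetical helper, not part of the test suite):

import numpy as np

def check_parity(ref_fn, fn_under_test, *args, **kwargs):
    # Hypothetical helper illustrating the try/except/else structure above.
    try:
        expected = ref_fn(*args, **kwargs)
    except RuntimeError as e:
        msg = str(e).split("\n")[0].split("(")[0].strip()
        np.testing.assert_raises_regex(type(e), msg, fn_under_test,
                                       *args, **kwargs)
    else:
        result = fn_under_test(*args, **kwargs)
        # The tests above compare the raw integer representations.
        np.testing.assert_equal(expected.int_repr().numpy(),
                                result.int_repr().numpy())
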