    def _test_conv2d_api_case(self, batch_size, in_channels_per_group, H, W,
                              out_channels_per_group, groups, kernel_h, kernel_w,
                              stride_h, stride_w, pad_h, pad_w, dilation, X_scale,
                              X_zero_point, W_scale, W_zero_point, Y_scale,
                              Y_zero_point, use_bias, use_fused, use_channelwise):
        # Tests the correctness of the quantized conv2d module for a single
        # parameter combination.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (H, W)
        kernel_size = (kernel_h, kernel_w)
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilation = (dilation, dilation)
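        # Per-channel (channelwise) weight quantization is not supported by the
        # QNNPACK engine, so fall back to per-tensor quantization there.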
        if torch.backends.quantized.engine == 'qnnpack':
            use_channelwise = False
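        # Build the quantized module under test: the fused Conv2d + ReLU module
        # when use_fused is set, otherwise the plain quantized Conv2d.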
        if use_fused:
            module_name = "QuantizedConvReLU2d"
            qconv_module = nnq_fused.ConvReLU2d(in_channels,
                                                out_channels,
                                                kernel_size,
                                                stride,
                                                padding,
                                                dilation,
                                                groups,
                                                use_bias,
                                                padding_mode="zeros")
        else:
            module_name = "QuantizedConv2d"
            qconv_module = nnq.Conv2d(in_channels,
                                      out_channels,
                                      kernel_size,
                                      stride,
                                      padding,
                                      dilation,
                                      groups,
                                      use_bias,
                                      padding_mode="zeros")

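        # Float module with the same configuration, serving as the reference
        # against which the quantized module is checked.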
        conv_module = nn.Conv2d(in_channels,
                                out_channels,
                                kernel_size,
                                stride,
                                padding,
                                dilation,
                                groups,
                                use_bias,
                                padding_mode="zeros")
        if use_fused:
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU2d(conv_module, relu_module)
        conv_module = conv_module.float()

        self._test_conv_api_impl(
            module_name, qconv_module, conv_module, batch_size,
            in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, stride, padding,
            "zeros",  # pad_mode, matching padding_mode above
            dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
            Y_zero_point, use_bias, use_fused, use_channelwise)

    def test_conv2d_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_fused
            [True, False],  # use_channelwise
        )
        for pad_mode, use_bias, use_fused, use_channelwise in options:
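            # Per-channel (channelwise) weight quantization is not supported by
            # the QNNPACK engine, so force per-tensor quantization there.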
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
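            # Fixed input and convolution geometry shared by all option combinations.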
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Derive the full channel counts and tuple-valued conv parameters.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
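            # Quantization parameters for the input (X), weight (W) and output (Y).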
            X_scale = 1.3
            X_zero_point = 2
            W_scale = [0.5]
            W_zero_point = [3]
            Y_scale = 5.0
            Y_zero_point = 4
            # use_fused -> quantized class
            class_map = {
                True: (nniq.ConvReLU2d, "QuantizedConvReLU2d"),
                False: (nnq.Conv2d, "QuantizedConv2d")
            }

            qconv_cls, module_name = class_map[use_fused]
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

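            # Float reference module with the same configuration; fused with ReLU
            # below when the quantized module under test is ConvReLU2d.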
            conv_module = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
            if use_fused:
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU2d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)