def test_conv3d_api(
        self, batch_size, in_channels_per_group, D, H, W,
        out_channels_per_group, groups, kernel_d, kernel_h, kernel_w,
        stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, dilation, X_scale,
        X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
        use_channelwise, use_fused, qengine,
    ):
        """Test the correctness of the quantized conv3d module for one configuration.

        Builds a quantized ``nnq.Conv3d`` (or ``nniq.ConvReLU3d`` when
        ``use_fused``) together with a float ``nn.Conv3d`` counterpart
        (wrapped in ``nni.ConvReLU3d`` when ``use_fused``) from the per-axis
        geometry arguments. It then hands both to ``self._test_conv_api_impl``
        while ``qengine`` is the active quantized engine.

        ``dilation`` is a single int applied to all three spatial axes.

        NOTE(review): another method named ``test_conv3d_api`` follows this
        one. If both live in the same class, the later definition replaces
        this one and this test is never collected. Consider renaming or
        removing one of them.
        """
        # Nothing to test if this build does not provide the requested engine.
        if qengine not in torch.backends.quantized.supported_engines:
            return

        pad_mode = "zeros"  # 3d doesn't support reflect padding
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)

        with override_quantized_engine(qengine):
            if use_fused:
                module_name = "QuantizedConvReLU3d"
                qconv_cls = nniq.ConvReLU3d
            else:
                module_name = "QuantizedConv3d"
                qconv_cls = nnq.Conv3d
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)

            conv_module = nn.Conv3d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
            if use_fused:
                conv_module = nni.ConvReLU3d(conv_module, nn.ReLU())
            conv_module = conv_module.float()

            # pad_mode sits between padding and dilation, matching the
            # argument order the other test_conv3d_api uses for this helper;
            # omitting it would shift every following argument by one slot.
            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                pad_mode, dilation, X_scale, X_zero_point, W_scale,
                W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
                use_channelwise)
    def test_conv3d_api(self):
        """Test the correctness of the quantized conv3d module under FBGEMM.

        For every combination of ``use_bias``, ``use_fused`` and
        ``use_channelwise``, this test does the following:

        * builds a quantized ``nnq.Conv3d`` (or ``nniq.ConvReLU3d`` when
          fused);
        * builds a float ``nn.Conv3d`` counterpart (wrapped in
          ``nni.ConvReLU3d`` when fused) with a fixed geometry;
        * passes both to ``self._test_conv_api_impl`` with FBGEMM as the
          active quantized engine.

        The geometry is batch 2, 3 groups of 2 input/output channels,
        8x8x8 feature map, 3x3x3 kernel, stride 2, padding 1, dilation 1.

        The test is skipped (returns early) when FBGEMM is not available.
        """
        # The test always runs on FBGEMM. torch rejects selecting an engine
        # that is not in supported_engines, so skip instead of erroring.
        qengine = 'fbgemm'
        if qengine not in torch.backends.quantized.supported_engines:
            return

        # Configuration shared by every combination (hoisted out of the loop).
        batch_size = 2
        in_channels_per_group = 2
        out_channels_per_group = 2
        groups = 3
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (8, 8, 8)  # (D, H, W)
        kernel_size = (3, 3, 3)  # (kernel_d, kernel_h, kernel_w)
        stride = (2, 2, 2)  # (stride_d, stride_h, stride_w)
        padding = (1, 1, 1)  # (pad_d, pad_h, pad_w)
        dilation = (1, 1, 1)
        pad_mode = "zeros"  # 3d doesn't support reflect padding
        X_scale = 1.3
        X_zero_point = 2
        Y_scale = 5.0
        Y_zero_point = 4
        # use_fused -> (quantized class, expected module name)
        class_map = {
            True: (nniq.ConvReLU3d, "QuantizedConvReLU3d"),
            False: (nnq.Conv3d, "QuantizedConv3d"),
        }

        # The engine is pinned to FBGEMM, so per-channel weights are exercised
        # regardless of which engine is globally selected for the test run.
        options = itertools.product(
            [True, False],  # use_bias
            [True, False],  # use_fused
            [True, False],  # use_channelwise
        )
        for use_bias, use_fused, use_channelwise in options:
            # Fresh lists per combination, so any in-place change the helper
            # might make (not visible here) cannot leak into the next one.
            W_scale = [0.5]
            W_zero_point = [3]
            qconv_cls, module_name = class_map[use_fused]

            with override_quantized_engine(qengine):
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                if use_fused:
                    conv_module = nni.ConvReLU3d(conv_module, nn.ReLU())
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale,
                    W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
                    use_channelwise)