Ejemplo n.º 1
0
    def test_conv1d_api(self, batch_size, in_channels_per_group, length,
                        out_channels_per_group, groups, kernel, stride, pad,
                        dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                        Y_scale, Y_zero_point, use_bias, use_fused,
                        use_channelwise):
        """Check a quantized Conv1d (optionally fused with ReLU) module.

        Builds a quantized module (``nnq.Conv1d`` or ``nnq_fused.ConvReLU1d``)
        and a matching float reference (``nn.Conv1d``, optionally wrapped in
        ``nni.ConvReLU1d``) from the same constructor arguments. Both are
        handed to ``self._test_conv_api_impl`` together with the quantization
        parameters.

        NOTE(review): the arguments are presumably supplied by a
        property-based/parameterizing decorator outside this view; confirm.

        Args:
            batch_size, in_channels_per_group, length, out_channels_per_group,
            groups, kernel, stride, pad, dilation: scalar convolution geometry.
                ``kernel``/``stride``/``pad``/``dilation`` are wrapped into
                1-tuples before use.
            X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
            Y_zero_point: quantization parameters forwarded unchanged to
                ``_test_conv_api_impl``.
            use_bias: whether both modules are created with a bias.
            use_fused: build the ConvReLU1d variants instead of plain Conv1d.
            use_channelwise: forwarded to ``_test_conv_api_impl``. It is forced
                to False when the active quantized engine is 'qnnpack'.
        """
        # Tests the correctness of the conv1d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (length,)
        kernel_size = (kernel,)
        stride = (stride,)
        pad = (pad,)
        dilation = (dilation,)
        if torch.backends.quantized.engine == 'qnnpack':
            # Channel-wise weights are disabled under qnnpack.
            use_channelwise = False

        if use_fused:
            module_name = "QuantizedConvReLU1d"
            qconv_cls = nnq_fused.ConvReLU1d
        else:
            module_name = "QuantizedConv1d"
            qconv_cls = nnq.Conv1d

        # Shared positional constructor arguments, so the quantized module and
        # its float reference cannot drift apart.
        conv_args = (in_channels, out_channels, kernel, stride, pad,
                     dilation, groups, use_bias)
        qconv_module = qconv_cls(*conv_args, padding_mode="zeros")

        conv_module = nn.Conv1d(*conv_args, padding_mode="zeros")
        if use_fused:
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU1d(conv_module, relu_module)
        conv_module = conv_module.float()

        self._test_conv_api_impl(module_name, qconv_module, conv_module,
                                 batch_size, in_channels_per_group,
                                 input_feature_map_size,
                                 out_channels_per_group, groups, kernel_size,
                                 stride, pad, dilation, X_scale, X_zero_point,
                                 W_scale, W_zero_point, Y_scale, Y_zero_point,
                                 use_bias, use_fused, use_channelwise)
Ejemplo n.º 2
0
    def test_conv1d_api(self):
        """Check quantized Conv1d / ConvReLU1d modules against float references.

        Sweeps every combination of padding mode ("zeros", "reflect"), bias,
        ReLU fusion and channel-wise weight quantization over one fixed
        convolution geometry. For each combination it builds the quantized
        module and a matching float reference from identical constructor
        arguments, then delegates the checks to ``self._test_conv_api_impl``.
        """
        # Fixed convolution geometry shared by every combination.
        # Tests the correctness of the conv1d module.
        batch_size = 2
        in_channels_per_group = 2
        length = 8
        out_channels_per_group = 2
        groups = 3
        kernel = 3
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (length,)
        kernel_size = (kernel,)
        stride = (2,)
        pad = (1,)
        dilation = (1,)
        # Quantization parameters forwarded unchanged to _test_conv_api_impl.
        X_scale = 1.3
        X_zero_point = 2
        Y_scale = 5.0
        Y_zero_point = 4
        # use_fused -> (quantized module class, expected module name)
        class_map = {
            True: (nniq.ConvReLU1d, "QuantizedConvReLU1d"),
            False: (nnq.Conv1d, "QuantizedConv1d"),
        }

        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_fused
            [True, False],  # use_channelwise
        )
        for pad_mode, use_bias, use_fused, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                # Channel-wise weights are disabled under qnnpack.
                use_channelwise = False
            # Fresh lists per combination, as in the original, so a callee
            # that modifies them in place cannot leak state between iterations.
            W_scale = [0.5]
            W_zero_point = [3]

            qconv_cls, module_name = class_map[use_fused]
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv1d(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode)
            if use_fused:
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU1d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
                Y_zero_point, use_bias, use_fused, use_channelwise)