예제 #1
0
    def test_conv_api(self, use_bias, use_fused):
        """Check the quantized conv module against the functional op.

        A quantized ``Conv2d`` (``ConvReLU2d`` when ``use_fused`` is set) is
        given a known weight, bias and output scale / zero point, and its
        forward result must equal ``qF.conv2d`` called with the same values.
        The test then covers the state_dict round trip, the error raised by
        ``torch.save`` on the whole module, scripting, and conversion from a
        float ``torch.nn.Conv2d``.

        Args:
            use_bias: build every module with a bias term when true.
            use_fused: exercise ``ConvReLU2d`` instead of ``Conv2d``.
        """
        scale, zero_point = 1.0 / 255, 128
        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3

        # One set of constructor arguments shared by every conv built below
        # (module under test, reloaded module, float reference module).
        conv_kwargs = dict(in_channels=iC,
                           out_channels=oC,
                           kernel_size=(kH, kW),
                           stride=1,
                           padding=0,
                           dilation=1,
                           groups=g,
                           bias=use_bias,
                           padding_mode='zeros')
        module_cls = ConvReLU2d if use_fused else Conv2d

        # Random draws happen in the order input -> weight -> bias.
        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        qX = torch.quantize_per_tensor(
            X, scale=scale, zero_point=128, dtype=torch.quint8)

        float_weight = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
        qw = torch.quantize_per_tensor(
            float_weight, scale=scale, zero_point=0, dtype=torch.qint8)

        bias = None
        if use_bias:
            bias = torch.randn(oC, dtype=torch.float32)

        qconv = module_cls(**conv_kwargs)
        # Install the known weight/bias and do a first forward pass as a
        # smoke test; scale and zero_point are only overridden afterwards.
        qconv.set_weight_bias(qw, bias)
        qconv(qX)

        qconv.scale = scale
        qconv.zero_point = zero_point

        # Expected members are present.
        for attr in ('_packed_params', 'scale', 'zero_point'):
            self.assertTrue(hasattr(qconv, attr))

        # Accessors return what was just stored.
        self.assertEqual(qw, qconv.weight())
        self.assertEqual(bias, qconv.bias())
        self.assertEqual(scale, qconv.scale)
        self.assertEqual(zero_point, qconv.zero_point)

        # Forward pass versus the functional reference.
        actual = qconv(qX)
        expected = qF.conv2d(qX,
                             qw,
                             bias=bias,
                             scale=scale,
                             zero_point=zero_point,
                             stride=1,
                             padding=0,
                             dilation=1,
                             groups=g,
                             dtype=torch.quint8)
        if use_fused:
            # Apply the ReLU to the reference by hand: masking with
            # ``expected < zero_point`` is not available for quantized
            # tensors yet, so each element is visited individually.
            n_batch, n_chan, out_h, out_w = expected.size()
            for n in range(n_batch):
                for c in range(n_chan):
                    for y in range(out_h):
                        for x in range(out_w):
                            below = (expected[n][c][y][x].int_repr()
                                     < zero_point)
                            if below:
                                # Writing 0. stores the zero_point value.
                                expected[n][c][y][x] = 0.

        self.assertEqual(expected, actual, message="Tensors are not equal.")

        # state_dict serialization of the quantized module.
        state = qconv.state_dict()
        self.assertEqual(state['weight'], qw)
        if use_bias:
            self.assertEqual(state['bias'], bias)
        buffer = io.BytesIO()
        torch.save(state, buffer)
        buffer.seek(0)
        restored_state = torch.load(buffer)
        for key in state:
            self.assertEqual(restored_state[key], state[key])

        # Load the saved state into a freshly constructed module.
        restored = module_cls(**conv_kwargs)
        restored.load_state_dict(restored_state)
        self.assertEqual(restored._weight_bias(), qconv._weight_bias())
        if use_bias:
            self.assertEqual(restored.bias(), qconv.bias())
        self.assertEqual(restored.scale, qconv.scale)
        self.assertEqual(restored.zero_point, qconv.zero_point)
        self.assertTrue(dir(restored) == dir(qconv))
        for attr in ('_packed_params', '_weight_bias'):
            self.assertTrue(hasattr(qconv, attr))
            self.assertTrue(hasattr(restored, attr))
        self.assertEqual(restored._weight_bias(), qconv._weight_bias())
        self.assertEqual(restored.weight(), qw)
        restored_result = restored(qX)
        self.assertEqual(restored_result, expected)

        # Saving the whole module with ``torch.save`` is currently broken by
        # https://github.com/pytorch/pytorch/issues/24045, so only the error
        # raised on save is checked here. The intended round-trip check is:
        #
        #     buffer = io.BytesIO()
        #     torch.save(qconv, buffer)
        #     buffer.seek(0)
        #     loaded_conv = torch.load(buffer)
        #     self.assertEqual(qconv.bias(), loaded_conv.bias())
        #     self.assertEqual(qconv.scale, loaded_conv.scale)
        #     self.assertEqual(qconv.zero_point, loaded_conv.zero_point)
        with self.assertRaisesRegex(
                RuntimeError, r'torch.save\(\) is not currently supported'):
            torch.save(qconv, io.BytesIO())

        # Scripting, including a save/load of the scripted module.
        self.checkScriptable(qconv,
                             [(qX, expected)],
                             check_save_load=True)

        # from_float path: observe a float conv, then convert it.
        float_conv = torch.nn.Conv2d(**conv_kwargs).float()
        float_conv.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_conv, inplace=True)
        float_conv(X.float())
        # The conv is wrapped in a Sequential before conversion and the
        # quantized module is read back from index 0.
        quantized_float_conv = torch.nn.Sequential(float_conv)
        torch.quantization.convert(quantized_float_conv, inplace=True)

        # Smoke test: the converted module runs on quantized input.
        quantized_float_conv(qX)
        if use_bias:
            self.assertEqual(quantized_float_conv[0].bias(), float_conv.bias)
        # Smoke test for extra_repr.
        str(quantized_float_conv)
예제 #2
0
    def test_conv_api(self, use_bias, use_fused):
        """Check the quantized conv module against the functional op.

        A quantized ``Conv2d`` (``ConvReLU2d`` when ``use_fused`` is set) is
        given a known quantized weight, quantized bias and output scale /
        zero point, and its forward result must equal ``qF.conv2d`` called
        with the same values. The test then covers the state_dict round
        trip, whole-module ``torch.save`` / ``torch.load``, scripting, and
        conversion from a float ``torch.nn.Conv2d``.

        Args:
            use_bias: build every module with a bias term when true.
            use_fused: exercise ``ConvReLU2d`` instead of ``Conv2d``.
        """
        scale, zero_point = 1.0 / 255, 128
        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3

        # One set of constructor arguments shared by every conv built below
        # (module under test, reloaded module, float reference module).
        conv_kwargs = dict(in_channels=iC,
                           out_channels=oC,
                           kernel_size=(kH, kW),
                           stride=1,
                           padding=0,
                           dilation=1,
                           groups=g,
                           bias=use_bias,
                           padding_mode='zeros')
        module_cls = ConvReLU2d if use_fused else Conv2d

        # Random draws happen in the order input -> weight -> bias.
        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        # Move dim 1 to the last position and make the result contiguous.
        X = X.permute([0, 2, 3, 1]).contiguous()
        qX = torch.quantize_linear(
            X, scale=scale, zero_point=128, dtype=torch.quint8)

        float_weight = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
        qw = torch.quantize_linear(
            float_weight, scale=scale, zero_point=0, dtype=torch.qint8)

        float_bias = None
        qb = None
        if use_bias:
            float_bias = torch.randn(oC, dtype=torch.float32)
            qb = torch.quantize_linear(
                float_bias, scale=1.0 / 1024, zero_point=0,
                dtype=torch.qint32)

        qconv = module_cls(**conv_kwargs)
        # First forward pass uses whatever the constructor set up, which
        # exercises the default-initialized parameters.
        qconv(qX)

        qconv.set_weight(qw)
        qconv.bias = qb
        qconv.scale = scale
        qconv.zero_point = zero_point

        # Expected members are present.
        for attr in ('_packed_weight', 'scale', 'zero_point'):
            self.assertTrue(hasattr(qconv, attr))

        # Accessors return what was just stored.
        self.assertEqual(qw, qconv.weight())
        self.assertEqual(qb, qconv.bias)
        self.assertEqual(scale, qconv.scale)
        self.assertEqual(zero_point, qconv.zero_point)

        # Forward pass versus the functional reference.
        actual = qconv(qX)
        expected = qF.conv2d(qX,
                             qw,
                             bias=qb,
                             scale=scale,
                             zero_point=zero_point,
                             stride=1,
                             padding=0,
                             dilation=1,
                             groups=g,
                             dtype=torch.quint8)
        if use_fused:
            # Apply the ReLU to the reference by hand: masking with
            # ``expected < zero_point`` is not available for quantized
            # tensors yet, so each element is visited individually.
            n_batch, n_chan, out_h, out_w = expected.size()
            for n in range(n_batch):
                for c in range(n_chan):
                    for y in range(out_h):
                        for x in range(out_w):
                            below = (expected[n][c][y][x].int_repr()
                                     < zero_point)
                            if below:
                                # Writing 0. stores the zero_point value.
                                expected[n][c][y][x] = 0.

        self.assertEqual(expected, actual, message="Tensors are not equal.")

        # state_dict serialization of the quantized module.
        state = qconv.state_dict()
        self.assertEqual(state['weight'], qw)
        if use_bias:
            self.assertEqual(state['bias'], qb)
        with tempfile.NamedTemporaryFile() as tmp:
            torch.save(state, tmp)
            tmp.seek(0)
            restored_state = torch.load(tmp)
        for key in state:
            self.assertEqual(restored_state[key], state[key])

        # Load the saved state into a freshly constructed module.
        restored = module_cls(**conv_kwargs)
        restored.load_state_dict(restored_state)
        self.assertEqual(restored.weight(), qconv.weight())
        if use_bias:
            self.assertEqual(restored.bias, qconv.bias)
        self.assertEqual(restored.scale, qconv.scale)
        self.assertEqual(restored.zero_point, qconv.zero_point)
        self.assertTrue(dir(restored) == dir(qconv))
        for attr in ('_packed_weight', 'weight'):
            self.assertTrue(hasattr(qconv, attr))
            self.assertTrue(hasattr(restored, attr))
        self.assertEqual(restored.weight(), qconv.weight())
        self.assertEqual(restored.weight(), qw)
        restored_result = restored(qX)
        self.assertEqual(restored_result, expected)

        # Whole-module save/load round trip.
        with tempfile.NamedTemporaryFile() as tmp:
            torch.save(qconv, tmp)
            tmp.seek(0)
            loaded_conv = torch.load(tmp)

        self.assertEqual(qconv.bias, loaded_conv.bias)
        self.assertEqual(qconv.scale, loaded_conv.scale)
        self.assertEqual(qconv.zero_point, loaded_conv.zero_point)

        # Scripting, including a save/load of the scripted module.
        self.checkScriptable(qconv,
                             [(qX, expected)],
                             check_save_load=True)

        # from_float path: observe a float conv, then convert it.
        float_conv = torch.nn.Conv2d(**conv_kwargs).float()
        float_conv.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_conv)
        float_conv(X.float())
        # The conv is wrapped in a Sequential before conversion and the
        # quantized module is read back from index 0.
        quantized_float_conv = torch.nn.Sequential(float_conv)
        torch.quantization.convert(quantized_float_conv)

        # Smoke test: the converted module runs on quantized input.
        quantized_float_conv(qX)
        # The converted bias must match the float bias quantized with a
        # scale derived from the module's output scale.
        if use_bias:
            bias_scale = quantized_float_conv[0].scale / 2**16
            qbias = torch.quantize_linear(
                float_conv.bias, bias_scale, 0, torch.qint32)
            self.assertEqual(quantized_float_conv[0].bias.dequantize(),
                             qbias.dequantize())
        # Smoke test for extra_repr.
        str(quantized_float_conv)
예제 #3
0
    def test_conv_api(self, use_bias, use_fused):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation:
        the module's forward output must equal ``qF.conv2d`` called with the
        same quantized input, weight, bias, scale and zero point.

        Args:
            use_bias: when true, a quantized bias is created and the modules
                are constructed with ``bias=True``; otherwise bias is None.
            use_fused: when true, ``ConvReLU2d`` is tested instead of
                ``Conv2d`` and the reference output is clamped accordingly.
        """

        # Input dims, conv dims, and the quantization parameters used for
        # both the input and the module's output.
        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        # Move dim 1 (size iC) to the last position and make it contiguous.
        X = X.permute([0, 2, 3, 1]).contiguous()
        # zero_point literal 128 equals the `zero_point` variable above.
        qX = torch.quantize_linear(X,
                                   scale=scale,
                                   zero_point=128,
                                   dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

        # Weights are quantized to qint8 with zero_point 0.
        qw = torch.quantize_linear(w,
                                   scale=scale,
                                   zero_point=0,
                                   dtype=torch.qint8)

        # Bias (float and its qint32-quantized form) only exists if use_bias.
        b = torch.randn(oC, dtype=torch.float32) if use_bias else None
        qb = torch.quantize_linear(
            b, scale=1.0 /
            1024, zero_point=0, dtype=torch.qint32) if use_bias else None

        # Build the module under test; both variants take identical arguments.
        if use_fused:
            conv_under_test = ConvReLU2d(in_channels=iC,
                                         out_channels=oC,
                                         kernel_size=(kH, kW),
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         bias=use_bias,
                                         padding_mode='zeros')
        else:
            conv_under_test = Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros')
        # Overwrite the module's parameters with the known values; scale and
        # zero_point are assigned as one-element tensors here.
        conv_under_test.weight = qw
        conv_under_test.bias = qb
        conv_under_test.scale = torch.tensor([scale], dtype=torch.double)
        conv_under_test.zero_point = torch.tensor([zero_point],
                                                  dtype=torch.long)

        # Test members
        self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(conv_under_test, 'scale'))
        self.assertTrue(hasattr(conv_under_test, 'zero_point'))

        # Test properties
        # (scale / zero_point are compared as Python scalars against the
        # one-element tensors stored above.)
        self.assertEqual(qw, conv_under_test.weight)
        self.assertEqual(qb, conv_under_test.bias)
        self.assertEqual(scale, conv_under_test.scale)
        self.assertEqual(zero_point, conv_under_test.zero_point)

        # Test forward
        result_under_test = conv_under_test(qX)
        result_reference = qF.conv2d(qX,
                                     qw,
                                     bias=qb,
                                     scale=scale,
                                     zero_point=zero_point,
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     dtype=torch.quint8)
        if use_fused:
            # Emulate the fused ReLU on the reference output element by
            # element.
            # result_reference < zero_point doesn't work for qtensor yet
            # result_reference[result_reference < zero_point] = zero_point
            MB, OC, OH, OW = result_reference.size()
            for i in range(MB):
                for j in range(OC):
                    for h in range(OH):
                        # NOTE: this loop index rebinds `w`, the float weight
                        # tensor defined earlier.
                        for w in range(OW):
                            if result_reference[i][j][h][w].int_repr(
                            ) < zero_point:
                                # assign 0. that gets converted to zero_point
                                result_reference[i][j][h][w] = 0.

        self.assertEqual(result_reference,
                         result_under_test,
                         message="Tensors are not equal.")