Example #1
class DynamicModuleAPITest(QuantizationTestCase):
    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(
        batch_size=st.integers(1, 5),
        in_features=st.integers(16, 32),
        out_features=st.integers(4, 8),
        use_bias=st.booleans(),
        use_default_observer=st.booleans(),
    )
    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_default_observer):
        """test API functionality for nn.quantized.dynamic.Linear"""
        W = torch.rand(out_features, in_features).float()
        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
        W_q = torch.quantize_linear(W, W_scale, W_zp, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        B = torch.rand(out_features).float() if use_bias else None
        qlinear = nnqd.Linear(in_features, out_features)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X)
        qlinear.set_weight(W_q)

        # Simple round-trip test to ensure the weight()/set_weight() API works
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_weight
        qlinear.bias = B if use_bias else None
        Z_dq = qlinear(X)

        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.fbgemm_linear_dynamic(X, W_pack, B)
        self.assertEqual(Z_ref, Z_dq)

        # Test serialization of dynamic quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B)
        with tempfile.TemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnqd.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.fbgemm_linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_weight),
                         linear_unpack(loaded_qlinear._packed_weight))
        if use_bias:
            self.assertEqual(qlinear.bias, loaded_qlinear.bias)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_weight'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_weight'))
        self.assertTrue(hasattr(qlinear, 'weight'))
        self.assertTrue(hasattr(loaded_qlinear, 'weight'))

        self.assertEqual(qlinear.weight(), loaded_qlinear.weight())
        self.assertEqual(
            qlinear.weight(),
            torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight))
        Z_dq2 = qlinear(X)
        self.assertEqual(Z_dq, Z_dq2)

        # test serialization of module directly
        with tempfile.TemporaryFile() as f:
            torch.save(qlinear, f)
            f.seek(0)
            loaded = torch.load(f)
        # This check is disabled pending an issue in PyTorch serialization:
        # https://github.com/pytorch/pytorch/issues/24045
        # self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear,
                             list(zip([X], [Z_ref])),
                             check_save_load=True)

        # Test from_float
        float_linear = torch.nn.Linear(in_features, out_features).float()
        if use_default_observer:
            float_linear.qconfig = torch.quantization.default_dynamic_qconfig
        prepare_dynamic(float_linear)
        float_linear(X.float())
        quantized_float_linear = nnqd.Linear.from_float(float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X)

        # Smoke test extra_repr
        str(quantized_float_linear)
Example #2
class ModuleAPITest(QuantizationTestCase):
    def test_relu(self):
        relu_module = nnq.ReLU()
        relu6_module = nnq.ReLU6()

        x = torch.arange(-10, 10, dtype=torch.float)
        y_ref = torch.relu(x)
        y6_ref = torch.nn.modules.ReLU6()(x)

        qx = torch.quantize_linear(x, 1.0, 0, dtype=torch.qint32)
        qy = relu_module(qx)
        qy6 = relu6_module(qx)

        self.assertEqual(y_ref,
                         qy.dequantize(),
                         message="ReLU module API failed")
        self.assertEqual(y6_ref,
                         qy6.dequantize(),
                         message="ReLU6 module API failed")

    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(
        batch_size=st.integers(1, 5),
        in_features=st.integers(16, 32),
        out_features=st.integers(4, 8),
        use_bias=st.booleans(),
        use_fused=st.booleans(),
    )
    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_fused):
        """test API functionality for nn.quantized.linear and nn._intrinsic.quantized.linear_relu"""
        W = torch.rand(out_features, in_features).float()
        W_q = torch.quantize_linear(W, 0.1, 4, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_linear(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
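        # Bias is quantized to qint32 with scale = weight_scale * input_scale and
        # zero_point 0, matching the int32 accumulator of the quantized linear op.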
        B_q = torch.quantize_linear(B,
                                    W_q.q_scale() * X_q.q_scale(), 0,
                                    torch.qint32) if use_bias else None
        scale = 0.5
        zero_point = 3
        if use_fused:
            qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            qlinear = nnq.Linear(in_features, out_features)

        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight(W_q)
        # Simple round-trip test to ensure the weight()/set_weight() API works
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_weight
        qlinear.bias = B_q if use_bias else None

        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)
        # Check if the module implementation matches calling the
        # ops directly
        if use_fused:
            Z_ref = torch.ops.quantized.fbgemm_linear_relu(
                X_q, W_pack, B_q, scale, zero_point)
        else:
            Z_ref = torch.ops.quantized.fbgemm_linear(X_q, W_pack, B_q, scale,
                                                      zero_point)
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict

        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B_q)
        with tempfile.TemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        if use_fused:
            loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            loaded_qlinear = nnq.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.fbgemm_linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_weight),
                         linear_unpack(loaded_qlinear._packed_weight))
        if use_bias:
            self.assertEqual(qlinear.bias, loaded_qlinear.bias)
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_weight'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_weight'))
        self.assertTrue(hasattr(qlinear, 'weight'))
        self.assertTrue(hasattr(loaded_qlinear, 'weight'))
        self.assertEqual(qlinear.weight(), loaded_qlinear.weight())
        self.assertEqual(
            qlinear.weight(),
            torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        # test serialization of module directly
        with tempfile.TemporaryFile() as f:
            torch.save(qlinear, f)
            f.seek(0)
            loaded = torch.load(f)
        # This check is disabled pending an issue in PyTorch serialization:
        # https://github.com/pytorch/pytorch/issues/24045
        # self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.bias, loaded.bias)
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear,
                             list(zip([X_q], [Z_ref])),
                             check_save_load=True)

        # Test from_float.
        float_linear = torch.nn.Linear(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear)
        float_linear(X.float())
        # Sequential allows swapping using "convert".
        quantized_float_linear = torch.nn.Sequential(float_linear)
        torch.quantization.convert(quantized_float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        str(quantized_float_linear)

    def test_quant_dequant_api(self):
        r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
        scale, zero_point, dtype = 1.0, 2, torch.qint8
        # testing Quantize API
        qr = torch.quantize_linear(r, scale, zero_point, dtype)
        quant_m = nnq.Quantize(scale, zero_point, dtype)
        qr2 = quant_m(r)
        self.assertEqual(qr, qr2)
        # testing Dequantize API
        rqr = qr.dequantize()
        dequant_m = nnq.DeQuantize()
        rqr2 = dequant_m(qr2)
        self.assertEqual(rqr, rqr2)

    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(
        use_bias=st.booleans(),
        use_fused=st.booleans(),
    )
    def test_conv_api(self, use_bias, use_fused):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation.
        """

        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
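        # Permute the activations to NHWC (channels-last) before quantizing to quint8.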
        X = X.permute([0, 2, 3, 1]).contiguous()
        qX = torch.quantize_linear(X,
                                   scale=scale,
                                   zero_point=128,
                                   dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

        qw = torch.quantize_linear(w,
                                   scale=scale,
                                   zero_point=0,
                                   dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None
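        # The bias (if present) is quantized to qint32 with zero_point 0.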
        qb = torch.quantize_linear(
            b, scale=1.0 / 1024, zero_point=0,
            dtype=torch.qint32) if use_bias else None

        if use_fused:
            conv_under_test = ConvReLU2d(in_channels=iC,
                                         out_channels=oC,
                                         kernel_size=(kH, kW),
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         bias=use_bias,
                                         padding_mode='zeros')
        else:
            conv_under_test = Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros')
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        conv_under_test(qX)

        conv_under_test.set_weight(qw)
        conv_under_test.bias = qb
        conv_under_test.scale = scale
        conv_under_test.zero_point = zero_point

        # Test members
        self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(conv_under_test, 'scale'))
        self.assertTrue(hasattr(conv_under_test, 'zero_point'))

        # Test properties
        self.assertEqual(qw, conv_under_test.weight())
        self.assertEqual(qb, conv_under_test.bias)
        self.assertEqual(scale, conv_under_test.scale)
        self.assertEqual(zero_point, conv_under_test.zero_point)

        # Test forward
        result_under_test = conv_under_test(qX)
        result_reference = qF.conv2d(qX,
                                     qw,
                                     bias=qb,
                                     scale=scale,
                                     zero_point=zero_point,
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     dtype=torch.quint8)
        if use_fused:
            # result_reference < zero_point doesn't work for qtensor yet
            # result_reference[result_reference < zero_point] = zero_point
            MB, OC, OH, OW = result_reference.size()
            for i in range(MB):
                for j in range(OC):
                    for h in range(OH):
                        for w in range(OW):
                            if result_reference[i][j][h][w].int_repr() < zero_point:
                                # Assign 0., which quantizes to zero_point.
                                result_reference[i][j][h][w] = 0.

        self.assertEqual(result_reference,
                         result_under_test,
                         message="Tensors are not equal.")

        # Test serialization of quantized Conv Module using state_dict
        model_dict = conv_under_test.state_dict()
        self.assertEqual(model_dict['weight'], qw)
        if use_bias:
            self.assertEqual(model_dict['bias'], qb)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(loaded_dict[key], model_dict[key])
        if use_fused:
            loaded_conv_under_test = ConvReLU2d(in_channels=iC,
                                                out_channels=oC,
                                                kernel_size=(kH, kW),
                                                stride=1,
                                                padding=0,
                                                dilation=1,
                                                groups=g,
                                                bias=use_bias,
                                                padding_mode='zeros')
        else:
            loaded_conv_under_test = Conv2d(in_channels=iC,
                                            out_channels=oC,
                                            kernel_size=(kH, kW),
                                            stride=1,
                                            padding=0,
                                            dilation=1,
                                            groups=g,
                                            bias=use_bias,
                                            padding_mode='zeros')
        loaded_conv_under_test.load_state_dict(loaded_dict)
        self.assertEqual(loaded_conv_under_test.weight(),
                         conv_under_test.weight())
        if use_bias:
            self.assertEqual(loaded_conv_under_test.bias, conv_under_test.bias)
        self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
        self.assertEqual(loaded_conv_under_test.zero_point,
                         conv_under_test.zero_point)
        self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
        self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(loaded_conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(conv_under_test, 'weight'))
        self.assertTrue(hasattr(loaded_conv_under_test, 'weight'))
        self.assertEqual(loaded_conv_under_test.weight(),
                         conv_under_test.weight())
        self.assertEqual(loaded_conv_under_test.weight(), qw)
        loaded_result = loaded_conv_under_test(qX)
        self.assertEqual(loaded_result, result_reference)

        with tempfile.NamedTemporaryFile() as f:
            torch.save(conv_under_test, f)
            f.seek(0)
            loaded_conv = torch.load(f)

        self.assertEqual(conv_under_test.bias, loaded_conv.bias)
        self.assertEqual(conv_under_test.scale, loaded_conv.scale)
        self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)

        # JIT testing
        self.checkScriptable(conv_under_test,
                             list(zip([qX], [result_reference])),
                             check_save_load=True)

        # Test from_float
        float_conv = torch.nn.Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros').float()
        float_conv.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_conv)
        float_conv(X.float())
        quantized_float_conv = torch.nn.Sequential(float_conv)
        torch.quantization.convert(quantized_float_conv)

        # Smoke test to make sure the module actually runs
        quantized_float_conv(qX)
        # Check that bias is quantized based on output scale
        if use_bias:
            qbias = torch.quantize_linear(
                float_conv.bias, quantized_float_conv[0].scale / 2**16, 0,
                torch.qint32)
            self.assertEqual(quantized_float_conv[0].bias.dequantize(),
                             qbias.dequantize())
        # Smoke test extra_repr
        str(quantized_float_conv)

    def test_pool_api(self):
        """Tests the correctness of the pool module.

        The correctness is defined against the functional implementation.
        """
        N, C, H, W = 10, 10, 10, 3
        kwargs = {
            'kernel_size': 2,
            'stride': None,
            'padding': 0,
            'dilation': 1
        }

        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, C, H, W, dtype=torch.float32)
        qX = torch.quantize_linear(X,
                                   scale=scale,
                                   zero_point=zero_point,
                                   dtype=torch.quint8)
        qX_expect = torch.nn.functional.max_pool2d(qX, **kwargs)

        pool_under_test = torch.nn.quantized.MaxPool2d(**kwargs)
        qX_hat = pool_under_test(qX)
        self.assertEqual(qX_expect, qX_hat)

        # JIT Testing
        self.checkScriptable(pool_under_test, list(zip([X], [qX_expect])))
Example #3
class FunctionalAPITest(QuantizationTestCase):
    def test_relu_api(self):
        X = torch.arange(-5, 5, dtype=torch.float)
        scale = 2.0
        zero_point = 1
        qX = torch.quantize_linear(X,
                                   scale=scale,
                                   zero_point=zero_point,
                                   dtype=torch.quint8)
        qY = torch.relu(qX)
        qY_hat = qF.relu(qX)
        self.assertEqual(qY, qY_hat)

    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(use_bias=st.booleans())
    def test_conv_api(self, use_bias):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation.
        """

        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128
        stride = (1, 1)
        i_padding = (0, 0)
        dilation = (1, 1)

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
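        # Permute the activations to NHWC (channels-last) before quantizing to quint8.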
        X = X.permute([0, 2, 3, 1]).contiguous()
        qX = torch.quantize_linear(X,
                                   scale=scale,
                                   zero_point=128,
                                   dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

        qw = torch.quantize_linear(w,
                                   scale=scale,
                                   zero_point=0,
                                   dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None
        q_bias = torch.quantize_linear(
            b, scale=1.0 / 1024, zero_point=0,
            dtype=torch.qint32) if use_bias else None
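        # Pre-pack the quantized filters (permuted to NHWC) for the fbgemm conv kernel.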
        q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(
            qw.permute([0, 2, 3, 1]), stride, i_padding, dilation, g)

        # Re-quantize the bias with scale = act_scale * weight_scale (both are
        # `scale` here), as the quantized conv op expects.
        requantized_bias = torch.quantize_linear(
            q_bias.dequantize(), scale * scale, 0,
            torch.qint32) if use_bias else None
        ref_result = torch.ops.quantized.fbgemm_conv2d(
            qX.permute([0, 2, 3, 1]), q_filters_ref, requantized_bias, stride,
            i_padding, dilation, g, scale, zero_point).permute([0, 3, 1, 2])

        q_result = torch.nn.quantized.functional.conv2d(qX,
                                                        qw,
                                                        bias=q_bias,
                                                        scale=scale,
                                                        zero_point=zero_point,
                                                        stride=stride,
                                                        padding=i_padding,
                                                        dilation=dilation,
                                                        groups=g,
                                                        dtype=torch.quint8)

        self.assertEqual(ref_result, q_result)
Example #4
class DynamicModuleAPITest(QuantizationTestCase):
    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(
        batch_size=st.integers(1, 5),
        in_features=st.integers(16, 32),
        out_features=st.integers(4, 8),
        use_bias=st.booleans(),
        use_default_observer=st.booleans(),
    )
    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_default_observer):
        """test API functionality for nn.quantized.dynamic.Linear"""
        W = torch.rand(out_features, in_features).float()
        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
        W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        B = torch.rand(out_features).float() if use_bias else None
        qlinear = nnqd.Linear(in_features, out_features)
        # Set the quantized weight and bias, then run the module to make sure
        # the forward pass works.
        qlinear.set_weight_bias(W_q, B)
        qlinear(X)

        # Simple round-trip test to ensure the weight()/set_weight_bias() API works
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_params
        Z_dq = qlinear(X)

        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack)
        self.assertEqual(Z_ref, Z_dq)

        # Test serialization of dynamic quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B)
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnqd.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params),
                         linear_unpack(loaded_qlinear._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))

        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(
            qlinear._weight_bias(),
            torch.ops.quantized.linear_unpack(qlinear._packed_params))
        Z_dq2 = qlinear(X)
        self.assertEqual(Z_dq, Z_dq2)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on save.
        # <start code>
        # b = io.BytesIO()
        # torch.save(qlinear, b)
        # b.seek(0)
        # loaded = torch.load(b)
        # self.assertEqual(qlinear.weight(), loaded.weight())
        # self.assertEqual(qlinear.zero_point, loaded.zero_point)
        # <end code>
        with self.assertRaisesRegex(
                RuntimeError, r'torch.save\(\) is not currently supported'):
            b = io.BytesIO()
            torch.save(qlinear, b)

        # Test JIT
        self.checkScriptable(qlinear,
                             list(zip([X], [Z_ref])),
                             check_save_load=True)

        # Test from_float
        float_linear = torch.nn.Linear(in_features, out_features).float()
        if use_default_observer:
            float_linear.qconfig = torch.quantization.default_dynamic_qconfig
        prepare_dynamic(float_linear)
        float_linear(X.float())
        quantized_float_linear = nnqd.Linear.from_float(float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X)

        # Smoke test extra_repr
        str(quantized_float_linear)
Example #5
    ModForWrapping, \
    test_only_eval_fn, test_only_train_fn, \
    prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \
    TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel

from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \
    AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel

from hypothesis import given
from hypothesis import strategies as st
from hypothesis_utils import no_deadline
import io
import copy

@unittest.skipIf(
    not torch.fbgemm_is_cpu_supported(),
    " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
    " with instruction set support avx2 or newer.",
)
class PostTrainingQuantTest(QuantizationTestCase):
    def test_single_layer(self):
        r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
        to nnq.Linear which is the quantized version of the module
        """
        model = SingleLayerLinearModel()
        prepare(model)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkHasPrepModules(model.fc1)
        self.checkObservers(model)
Example #6
        qa = torch.quantize_linear(a, scale=scale, zero_point=zero_point,
                                   dtype=torch_type)

        a_hat = qa.dequantize()
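        # Float reference: max-pool the dequantized tensor with the same parameters.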
        a_pool = F.max_pool2d(a_hat, kernel_size=k, stride=s, padding=p,
                              dilation=d)

        qa_pool_hat = q_max_pool(qa, kernel_size=k, stride=s, padding=p,
                                 dilation=d)
        a_pool_hat = qa_pool_hat.dequantize()

        np.testing.assert_equal(a_pool.numpy(), a_pool_hat.numpy())


@unittest.skipIf(
    TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
    " Quantized Linear requires FBGEMM. FBGEMM does not play"
    " well with UBSAN at the moment, so we skip the test if"
    " we are in a UBSAN environment.",
)
class TestQuantizedLinear(unittest.TestCase):
    """Tests the correctness of the quantized::fbgemm_linear op."""

    def test_qlinear(self):
        qlinear_prepack = torch.ops.quantized.fbgemm_linear_prepack
        qlinear = torch.ops.quantized.fbgemm_linear

        batch_size = 4
        input_channels = 16
        output_channels = 8
Example #7
        np.testing.assert_equal(cat_ref.numpy(), cat_q_out.numpy())

        # Test the cat on per-channel quantized tensor.
        ch_axis = 1
        scales = torch.from_numpy(np.array([1.0] * X.shape[ch_axis]))
        scales = scales.to(torch.float64)
        zero_points = torch.from_numpy(np.array([0] * X.shape[ch_axis]))
        zero_points = zero_points.to(torch.long)
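        # Replace the first input with a per-channel quantized tensor; cat is
        # expected to reject it below.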
        tensors_q[0] = torch.quantize_linear_per_channel(
            X, scales, zero_points, axis=[ch_axis], dtype=torch_type)
        with self.assertRaisesRegex(RuntimeError, "supported.*cat"):
            cat_q = q_cat_op(tensors_q, axis=axis, scale=scale,
                             zero_point=zero_point)

@unittest.skipIf(
    not torch.fbgemm_is_cpu_supported(),
    " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
    " with instruction set support avx2 or newer.",
)
class TestQuantizedLinear(unittest.TestCase):
    """Tests the correctness of the quantized linear and linear_relu op."""
    @given(batch_size=st.integers(1, 4),
           input_channels=st.integers(16, 32),
           output_channels=st.integers(4, 8),
           use_bias=st.booleans(),
           use_relu=st.booleans())
    def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, use_relu):
        qlinear_prepack = torch.ops.quantized.fbgemm_linear_prepack
        if use_relu:
            qlinear = torch.ops.quantized.fbgemm_linear_relu
        else:
Example #8
import unittest
import torch
import torch.nn.quantized as nnq
from torch.quantization import \
    quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules

from common_utils import run_tests, TEST_WITH_UBSAN
from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
    SkipQuantModel, QuantStubModel, \
    ModForFusion, ManualLinearQATModel, ManualConvLinearQATModel, test_only_eval_fn, test_only_train_fn

from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \
    AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel


@unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
                 'Quantization requires FBGEMM. FBGEMM does not play'
                 ' well with UBSAN at the moment, so we skip the test if'
                 ' we are in a UBSAN environment.')
class PostTrainingQuantTest(QuantizationTestCase):
    def test_single_layer(self):
        r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
        to nnq.Linear which is the quantized version of the module
        """
        model = SingleLayerLinearModel()
        model = prepare(model)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkHasPrepModules(model.fc1)
        self.checkObservers(model)