class DynamicModuleAPITest(QuantizationTestCase):
    """API tests for the dynamically quantized ``nnqd.Linear`` module."""

    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(
        batch_size=st.integers(1, 5),
        in_features=st.integers(16, 32),
        out_features=st.integers(4, 8),
        use_bias=st.booleans(),
        use_default_observer=st.booleans(),
    )
    def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
        """test API functionality for nn.quantized.dynamic.Linear"""
        W = torch.rand(out_features, in_features).float()
        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
        W_q = torch.quantize_linear(W, W_scale, W_zp, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        B = torch.rand(out_features).float() if use_bias else None
        qlinear = nnqd.Linear(in_features, out_features)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X)

        qlinear.set_weight(W_q)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_weight
        # B is already None when use_bias is False.
        qlinear.bias = B
        Z_dq = qlinear(X)

        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.fbgemm_linear_dynamic(X, W_pack, B)
        self.assertEqual(Z_ref, Z_dq)

        # Test serialization of dynamic quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B)
        with tempfile.TemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnqd.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.fbgemm_linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_weight),
                         linear_unpack(loaded_qlinear._packed_weight))
        if use_bias:
            self.assertEqual(qlinear.bias, loaded_qlinear.bias)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_weight'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_weight'))
        self.assertTrue(hasattr(qlinear, 'weight'))
        self.assertTrue(hasattr(loaded_qlinear, 'weight'))
        self.assertEqual(qlinear.weight(), loaded_qlinear.weight())
        self.assertEqual(
            qlinear.weight(),
            torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight))
        # Run the *reloaded* module: re-running `qlinear` here would not
        # verify that the state_dict round trip preserved its behavior.
        Z_dq2 = loaded_qlinear(X)
        self.assertEqual(Z_dq, Z_dq2)

        # test serialization of module directly
        with tempfile.TemporaryFile() as f:
            torch.save(qlinear, f)
            f.seek(0)
            loaded = torch.load(f)
        # This check is disabled pending an issue in PyTorch serialization:
        # https://github.com/pytorch/pytorch/issues/24045
        # self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear, list(zip([X], [Z_ref])), check_save_load=True)

        # Test from_float
        float_linear = torch.nn.Linear(in_features, out_features).float()
        if use_default_observer:
            float_linear.qconfig = torch.quantization.default_dynamic_qconfig
        prepare_dynamic(float_linear)
        float_linear(X.float())
        quantized_float_linear = nnqd.Linear.from_float(float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X)

        # Smoke test extra_repr
        str(quantized_float_linear)
class ModuleAPITest(QuantizationTestCase):
    """API tests for the statically quantized modules in ``torch.nn.quantized``."""

    def test_relu(self):
        relu_module = nnq.ReLU()
        relu6_module = nnq.ReLU6()

        x = torch.arange(-10, 10, dtype=torch.float)
        y_ref = torch.relu(x)
        y6_ref = torch.nn.modules.ReLU6()(x)

        qx = torch.quantize_linear(x, 1.0, 0, dtype=torch.qint32)
        qy = relu_module(qx)
        qy6 = relu6_module(qx)

        self.assertEqual(y_ref, qy.dequantize(),
                         message="ReLU module API failed")
        self.assertEqual(y6_ref, qy6.dequantize(),
                         message="ReLU6 module API failed")

    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(
        batch_size=st.integers(1, 5),
        in_features=st.integers(16, 32),
        out_features=st.integers(4, 8),
        use_bias=st.booleans(),
        use_fused=st.booleans(),
    )
    def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_fused):
        """test API functionality for nn.quantized.linear and nn._intrinsic.quantized.linear_relu"""
        W = torch.rand(out_features, in_features).float()
        W_q = torch.quantize_linear(W, 0.1, 4, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_linear(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        # Bias scale is the product of the weight and input scales.
        B_q = torch.quantize_linear(B, W_q.q_scale() * X_q.q_scale(), 0, torch.qint32) if use_bias else None
        scale = 0.5
        zero_point = 3
        linear_cls = nnq_fused.LinearReLU if use_fused else nnq.Linear
        qlinear = linear_cls(in_features, out_features)

        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight(W_q)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_weight
        # B_q is already None when use_bias is False.
        qlinear.bias = B_q
        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)
        # Check if the module implementation matches calling the
        # ops directly
        if use_fused:
            Z_ref = torch.ops.quantized.fbgemm_linear_relu(
                X_q, W_pack, B_q, scale, zero_point)
        else:
            Z_ref = torch.ops.quantized.fbgemm_linear(X_q, W_pack, B_q, scale, zero_point)
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B_q)
        with tempfile.TemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = linear_cls(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.fbgemm_linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_weight),
                         linear_unpack(loaded_qlinear._packed_weight))
        if use_bias:
            self.assertEqual(qlinear.bias, loaded_qlinear.bias)
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_weight'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_weight'))
        self.assertTrue(hasattr(qlinear, 'weight'))
        self.assertTrue(hasattr(loaded_qlinear, 'weight'))
        self.assertEqual(qlinear.weight(), loaded_qlinear.weight())
        self.assertEqual(
            qlinear.weight(),
            torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        # test serialization of module directly
        with tempfile.TemporaryFile() as f:
            torch.save(qlinear, f)
            f.seek(0)
            loaded = torch.load(f)
        # This check is disabled pending an issue in PyTorch serialization:
        # https://github.com/pytorch/pytorch/issues/24045
        # self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.bias, loaded.bias)
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear, list(zip([X_q], [Z_ref])), check_save_load=True)

        # Test from_float.
        float_linear = torch.nn.Linear(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear)
        float_linear(X.float())
        # Sequential allows swapping using "convert".
        quantized_float_linear = torch.nn.Sequential(float_linear)
        torch.quantization.convert(quantized_float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        str(quantized_float_linear)

    def test_quant_dequant_api(self):
        r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
        scale, zero_point, dtype = 1.0, 2, torch.qint8
        # testing Quantize API
        qr = torch.quantize_linear(r, scale, zero_point, dtype)
        quant_m = nnq.Quantize(scale, zero_point, dtype)
        qr2 = quant_m(r)
        self.assertEqual(qr, qr2)
        # testing Dequantize API
        rqr = qr.dequantize()
        dequant_m = nnq.DeQuantize()
        rqr2 = dequant_m(qr2)
        self.assertEqual(rqr, rqr2)

    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(
        use_bias=st.booleans(),
        use_fused=st.booleans(),
    )
    def test_conv_api(self, use_bias, use_fused):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation.
        """
        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        X = X.permute([0, 2, 3, 1]).contiguous()
        qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point, dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
        qw = torch.quantize_linear(w, scale=scale, zero_point=0, dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None
        qb = torch.quantize_linear(
            b, scale=1.0 / 1024, zero_point=0, dtype=torch.qint32) if use_bias else None

        # Shared constructor arguments for the quantized module under test,
        # the module reloaded from its state_dict, and the float module used
        # by the from_float check.
        conv_kwargs = dict(in_channels=iC, out_channels=oC, kernel_size=(kH, kW),
                           stride=1, padding=0, dilation=1, groups=g,
                           bias=use_bias, padding_mode='zeros')
        conv_cls = ConvReLU2d if use_fused else Conv2d
        conv_under_test = conv_cls(**conv_kwargs)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        conv_under_test(qX)

        conv_under_test.set_weight(qw)
        conv_under_test.bias = qb
        conv_under_test.scale = scale
        conv_under_test.zero_point = zero_point

        # Test members
        self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(conv_under_test, 'scale'))
        self.assertTrue(hasattr(conv_under_test, 'zero_point'))

        # Test properties
        self.assertEqual(qw, conv_under_test.weight())
        self.assertEqual(qb, conv_under_test.bias)
        self.assertEqual(scale, conv_under_test.scale)
        self.assertEqual(zero_point, conv_under_test.zero_point)

        # Test forward
        result_under_test = conv_under_test(qX)
        result_reference = qF.conv2d(qX, qw, bias=qb,
                                     scale=scale, zero_point=zero_point,
                                     stride=1, padding=0,
                                     dilation=1, groups=g, dtype=torch.quint8)
        if use_fused:
            # result_reference < zero_point doesn't work for qtensor yet
            # result_reference[result_reference < zero_point] = zero_point
            MB, OC, OH, OW = result_reference.size()
            # The loop indices must not be named `w`: that would shadow the
            # float weight tensor `w` defined above.
            for n in range(MB):
                for oc in range(OC):
                    for oh in range(OH):
                        for ow in range(OW):
                            if result_reference[n][oc][oh][ow].int_repr() < zero_point:
                                # assign 0. that gets converted to zero_point
                                result_reference[n][oc][oh][ow] = 0.

        self.assertEqual(result_reference, result_under_test,
                         message="Tensors are not equal.")

        # Test serialization of quantized Conv Module using state_dict
        model_dict = conv_under_test.state_dict()
        self.assertEqual(model_dict['weight'], qw)
        if use_bias:
            self.assertEqual(model_dict['bias'], qb)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(loaded_dict[key], model_dict[key])
        loaded_conv_under_test = conv_cls(**conv_kwargs)
        loaded_conv_under_test.load_state_dict(loaded_dict)
        self.assertEqual(loaded_conv_under_test.weight(), conv_under_test.weight())
        if use_bias:
            self.assertEqual(loaded_conv_under_test.bias, conv_under_test.bias)
        self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
        self.assertEqual(loaded_conv_under_test.zero_point, conv_under_test.zero_point)
        self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
        self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(loaded_conv_under_test, '_packed_weight'))
        self.assertTrue(hasattr(conv_under_test, 'weight'))
        self.assertTrue(hasattr(loaded_conv_under_test, 'weight'))
        self.assertEqual(loaded_conv_under_test.weight(), conv_under_test.weight())
        self.assertEqual(loaded_conv_under_test.weight(), qw)
        loaded_result = loaded_conv_under_test(qX)
        self.assertEqual(loaded_result, result_reference)

        with tempfile.NamedTemporaryFile() as f:
            torch.save(conv_under_test, f)
            f.seek(0)
            loaded_conv = torch.load(f)

        self.assertEqual(conv_under_test.bias, loaded_conv.bias)
        self.assertEqual(conv_under_test.scale, loaded_conv.scale)
        self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)

        # JIT testing
        self.checkScriptable(conv_under_test, list(zip([qX], [result_reference])), check_save_load=True)

        # Test from_float
        float_conv = torch.nn.Conv2d(**conv_kwargs).float()
        float_conv.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_conv)
        float_conv(X.float())
        quantized_float_conv = torch.nn.Sequential(float_conv)
        torch.quantization.convert(quantized_float_conv)

        # Smoke test to make sure the module actually runs
        quantized_float_conv(qX)
        # Check that bias is quantized based on output scale
        if use_bias:
            qbias = torch.quantize_linear(
                float_conv.bias, quantized_float_conv[0].scale / 2**16, 0, torch.qint32)
            self.assertEqual(quantized_float_conv[0].bias.dequantize(), qbias.dequantize())
        # Smoke test extra_repr
        str(quantized_float_conv)

    def test_pool_api(self):
        """Tests the correctness of the pool module.

        The correctness is defined against the functional implementation.
        """
        N, C, H, W = 10, 10, 10, 3
        kwargs = {
            'kernel_size': 2,
            'stride': None,
            'padding': 0,
            'dilation': 1
        }

        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, C, H, W, dtype=torch.float32)
        qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point,
                                   dtype=torch.quint8)
        qX_expect = torch.nn.functional.max_pool2d(qX, **kwargs)

        pool_under_test = torch.nn.quantized.MaxPool2d(**kwargs)
        qX_hat = pool_under_test(qX)
        self.assertEqual(qX_expect, qX_hat)

        # JIT Testing: feed the quantized input so the scripted/traced module
        # is exercised on the same quantized path as qX_expect (the float X
        # would only test the float max_pool2d).
        self.checkScriptable(pool_under_test, list(zip([qX], [qX_expect])))
class FunctionalAPITest(QuantizationTestCase):
    """Tests for ``torch.nn.quantized.functional`` (imported here as ``qF``)."""

    def test_relu_api(self):
        X = torch.arange(-5, 5, dtype=torch.float)
        scale = 2.0
        zero_point = 1
        qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qY = torch.relu(qX)
        qY_hat = qF.relu(qX)
        self.assertEqual(qY, qY_hat)

    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(
        use_bias=st.booleans(),
    )
    def test_conv_api(self, use_bias):
        """Tests the correctness of the functional quantized conv2d.

        The result of ``torch.nn.quantized.functional.conv2d`` is compared
        against calling the ``fbgemm_conv_prepack`` / ``fbgemm_conv2d`` ops
        directly.
        """
        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128
        stride = (1, 1)
        i_padding = (0, 0)
        dilation = (1, 1)

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        # NOTE(review): after this permute X has shape (N, H, W, iC); it is
        # then used as an NCHW input, which only lines up because iC == H
        # here. Confirm this is intended before changing iC or H.
        X = X.permute([0, 2, 3, 1]).contiguous()
        qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point, dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

        qw = torch.quantize_linear(w, scale=scale, zero_point=0, dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None
        q_bias = torch.quantize_linear(
            b, scale=1.0 / 1024, zero_point=0, dtype=torch.qint32) if use_bias else None

        # Reference path: prepack the (permuted) weight and call the fbgemm
        # conv op directly.
        q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(
            qw.permute([0, 2, 3, 1]), stride, i_padding, dilation, g)
        # The reference op is given the bias requantized to scale * scale
        # (input scale times weight scale, which are equal here).
        requantized_bias = torch.quantize_linear(
            q_bias.dequantize(), scale * scale, 0, torch.qint32) if use_bias else None
        ref_result = torch.ops.quantized.fbgemm_conv2d(
            qX.permute([0, 2, 3, 1]), q_filters_ref, requantized_bias,
            stride, i_padding, dilation, g, scale, zero_point).permute([0, 3, 1, 2])

        q_result = torch.nn.quantized.functional.conv2d(
            qX, qw, bias=q_bias, scale=scale, zero_point=zero_point,
            stride=stride, padding=i_padding, dilation=dilation,
            groups=g, dtype=torch.quint8)

        self.assertEqual(ref_result, q_result)
class DynamicModuleAPITest(QuantizationTestCase):
    """API tests for the dynamically quantized ``nnqd.Linear`` module."""

    @no_deadline
    @unittest.skipIf(
        not torch.fbgemm_is_cpu_supported(),
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support avx2 or newer.",
    )
    @given(
        batch_size=st.integers(1, 5),
        in_features=st.integers(16, 32),
        out_features=st.integers(4, 8),
        use_bias=st.booleans(),
        use_default_observer=st.booleans(),
    )
    def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
        """test API functionality for nn.quantized.dynamic.Linear"""
        W = torch.rand(out_features, in_features).float()
        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
        W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        B = torch.rand(out_features).float() if use_bias else None
        qlinear = nnqd.Linear(in_features, out_features)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure weight()/set_weight_bias() API
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_params
        Z_dq = qlinear(X)

        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack)
        self.assertEqual(Z_ref, Z_dq)

        # Test serialization of dynamic quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B)
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnqd.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params),
                         linear_unpack(loaded_qlinear._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params))
        # Run the *reloaded* module: re-running `qlinear` here would not
        # verify that the state_dict round trip preserved its behavior.
        Z_dq2 = loaded_qlinear(X)
        self.assertEqual(Z_dq, Z_dq2)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on save.
        # <start code>
        # b = io.BytesIO()
        # torch.save(qlinear, b)
        # b.seek(0)
        # loaded = torch.load(b)
        # self.assertEqual(qlinear.weight(), loaded.weight())
        # self.assertEqual(qlinear.zero_point, loaded.zero_point)
        # <end code>
        with self.assertRaisesRegex(RuntimeError, r'torch.save\(\) is not currently supported'):
            b = io.BytesIO()
            torch.save(qlinear, b)

        # Test JIT
        self.checkScriptable(qlinear, list(zip([X], [Z_ref])), check_save_load=True)

        # Test from_float
        float_linear = torch.nn.Linear(in_features, out_features).float()
        if use_default_observer:
            float_linear.qconfig = torch.quantization.default_dynamic_qconfig
        prepare_dynamic(float_linear)
        float_linear(X.float())
        quantized_float_linear = nnqd.Linear.from_float(float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X)

        # Smoke test extra_repr
        str(quantized_float_linear)
ModForWrapping, \ test_only_eval_fn, test_only_train_fn, \ prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \ TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \ AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel from hypothesis import given from hypothesis import strategies as st from hypothesis_utils import no_deadline import io import copy @unittest.skipIf( not torch.fbgemm_is_cpu_supported(), " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" " with instruction set support avx2 or newer.", ) class PostTrainingQuantTest(QuantizationTestCase): def test_single_layer(self): r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped to nnq.Linear which is the quantized version of the module """ model = SingleLayerLinearModel() prepare(model) # Check if observers and quant/dequant nodes are inserted self.checkNoPrepModules(model) self.checkHasPrepModules(model.fc1) self.checkObservers(model)
qa = torch.quantize_linear(a, scale=scale, zero_point=zero_point, dtype=torch_type) a_hat = qa.dequantize() a_pool = F.max_pool2d(a_hat, kernel_size=k, stride=s, padding=p, dilation=d) qa_pool_hat = q_max_pool(qa, kernel_size=k, stride=s, padding=p, dilation=d) a_pool_hat = qa_pool_hat.dequantize() np.testing.assert_equal(a_pool.numpy(), a_pool_hat.numpy()) @unittest.skipIf( TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(), " Quantized Linear requires FBGEMM. FBGEMM does not play" " well with UBSAN at the moment, so we skip the test if" " we are in a UBSAN environment.", ) class TestQuantizedLinear(unittest.TestCase): """Tests the correctness of the quantized::fbgemm_linear op.""" def test_qlinear(self): qlinear_prepack = torch.ops.quantized.fbgemm_linear_prepack qlinear = torch.ops.quantized.fbgemm_linear batch_size = 4 input_channels = 16 output_channels = 8
np.testing.assert_equal(cat_ref.numpy(), cat_q_out.numpy()) # Test the cat on per-channel quantized tensor. ch_axis = 1 scales = torch.from_numpy(np.array([1.0] * X.shape[ch_axis])) scales = scales.to(torch.float64) zero_points = torch.from_numpy(np.array([0] * X.shape[ch_axis])) zero_points = zero_points.to(torch.long) tensors_q[0] = torch.quantize_linear_per_channel( X, scales, zero_points, axis=[ch_axis], dtype=torch_type) with self.assertRaisesRegex(RuntimeError, "supported.*cat"): cat_q = q_cat_op(tensors_q, axis=axis, scale=scale, zero_point=zero_point) @unittest.skipIf( not torch.fbgemm_is_cpu_supported(), " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" " with instruction set support avx2 or newer.", ) class TestQuantizedLinear(unittest.TestCase): """Tests the correctness of the quantized linear and linear_relu op.""" @given(batch_size=st.integers(1, 4), input_channels=st.integers(16, 32), output_channels=st.integers(4, 8), use_bias=st.booleans(), use_relu=st.booleans()) def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, use_relu): qlinear_prepack = torch.ops.quantized.fbgemm_linear_prepack if use_relu: qlinear = torch.ops.quantized.fbgemm_linear_relu else:
import unittest import torch import torch.nn.quantized as nnq from torch.quantization import \ quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules from common_utils import run_tests, TEST_WITH_UBSAN from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \ SkipQuantModel, QuantStubModel, \ ModForFusion, ManualLinearQATModel, ManualConvLinearQATModel, test_only_eval_fn, test_only_train_fn from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \ AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel @unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(), 'Quantization requires FBGEMM. FBGEMM does not play' ' well with UBSAN at the moment, so we skip the test if' ' we are in a UBSAN environment.') class PostTrainingQuantTest(QuantizationTestCase): def test_single_layer(self): r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped to nnq.Linear which is the quantized version of the module """ model = SingleLayerLinearModel() model = prepare(model) # Check if observers and quant/dequant nodes are inserted self.checkNoPrepModules(model) self.checkHasPrepModules(model.fc1) self.checkObservers(model)