def test_weight_only_activation_only_fakequant(self):
    for qengine in ["fbgemm", "qnnpack"]:
        if qengine not in torch.backends.quantized.supported_engines:
            continue
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                continue
        with override_quantized_engine(qengine):
            torch.manual_seed(67)
            calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
            eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
            # Use a list (not a set) so the iteration order is deterministic
            # and stays aligned with the per-config SQNR targets below.
            qconfigs = [torch.quantization.default_weight_only_quant_qconfig,
                        torch.quantization.default_activation_only_quant_qconfig]
            SQNRTarget = [35, 45]
            for idx, qconfig in enumerate(qconfigs):
                my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
                my_model.eval()
                out_ref = my_model(eval_data)
                fq_model = torch.quantization.QuantWrapper(my_model)
                fq_model.train()
                fq_model.qconfig = qconfig
                torch.quantization.fuse_modules(fq_model.module,
                                                [['conv1', 'bn1', 'relu1']],
                                                inplace=True)
                torch.quantization.prepare_qat(fq_model, inplace=True)
                fq_model.eval()
                fq_model.apply(torch.quantization.disable_fake_quant)
                fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                fq_model(calib_data)
                fq_model.apply(torch.quantization.enable_fake_quant)
                fq_model.apply(torch.quantization.disable_observer)
                out_fq = fq_model(eval_data)
                SQNRdB = 20 * torch.log10(torch.norm(out_ref) /
                                          torch.norm(out_ref - out_fq))
                self.assertGreater(SQNRdB, SQNRTarget[idx],
                                   msg='Quantized model numerics diverge from float')
def test_float_quant_compare_per_tensor(self):
    for qengine in ["fbgemm", "qnnpack"]:
        if qengine not in torch.backends.quantized.supported_engines:
            continue
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                continue
        with override_quantized_engine(qengine):
            torch.manual_seed(42)
            my_model = ModelMultipleOps().to(torch.float32)
            my_model.eval()
            calib_data = torch.rand(1024, 3, 15, 15, dtype=torch.float32)
            eval_data = torch.rand(1, 3, 15, 15, dtype=torch.float32)
            out_ref = my_model(eval_data)
            qModel = torch.quantization.QuantWrapper(my_model)
            qModel.eval()
            qModel.qconfig = torch.quantization.default_qconfig
            torch.quantization.fuse_modules(qModel.module,
                                            [['conv1', 'bn1', 'relu1']],
                                            inplace=True)
            torch.quantization.prepare(qModel, inplace=True)
            qModel(calib_data)
            torch.quantization.convert(qModel, inplace=True)
            out_q = qModel(eval_data)
            SQNRdB = 20 * torch.log10(torch.norm(out_ref) /
                                      torch.norm(out_ref - out_q))
            # Quantized model output should be numerically close to the
            # floating point model output. A target SQNR of 30 dB bounds the
            # noise-to-signal power ratio at 1e-3 of the reference output.
            self.assertGreater(SQNRdB, 30,
                               msg='Quantized model numerics diverge from float, expect SQNR > 30 dB')
def test_fake_quant_true_quant_compare(self):
    for qengine in ["fbgemm", "qnnpack"]:
        if qengine not in torch.backends.quantized.supported_engines:
            continue
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                continue
        with override_quantized_engine(qengine):
            torch.manual_seed(67)
            my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
            calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
            eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
            my_model.eval()
            out_ref = my_model(eval_data)
            fq_model = torch.quantization.QuantWrapper(my_model)
            fq_model.train()
            fq_model.qconfig = torch.quantization.default_qat_qconfig
            torch.quantization.fuse_modules(fq_model.module,
                                            [['conv1', 'bn1', 'relu1']],
                                            inplace=True)
            torch.quantization.prepare_qat(fq_model, inplace=True)
            fq_model.eval()
            fq_model.apply(torch.quantization.disable_fake_quant)
            fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            fq_model(calib_data)
            fq_model.apply(torch.quantization.enable_fake_quant)
            fq_model.apply(torch.quantization.disable_observer)
            out_fq = fq_model(eval_data)
            SQNRdB = 20 * torch.log10(torch.norm(out_ref) /
                                      torch.norm(out_ref - out_fq))
            # Fake-quantized model output should be numerically close to the
            # floating point model output; target SQNR is 35 dB.
            self.assertGreater(SQNRdB, 35,
                               msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
            torch.quantization.convert(fq_model, inplace=True)
            out_q = fq_model(eval_data)
            SQNRdB = 20 * torch.log10(torch.norm(out_fq) /
                                      (torch.norm(out_fq - out_q) + 1e-10))
            self.assertGreater(SQNRdB, 60,
                               msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')
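# The three tests above all measure accuracy with the same inline SQNR
# expression. For reference, a minimal standalone sketch of that computation
# (the helper name `sqnr_db` is illustrative only and is not defined anywhere
# in this suite):
#
# def sqnr_db(ref, actual, eps=1e-10):
#     # Signal-to-quantization-noise ratio in dB:
#     #     20 * log10(||ref|| / ||ref - actual||)
#     # Each additional 20 dB means a 10x smaller relative L2 error; `eps`
#     # guards against division by zero when the two tensors match exactly.
#     return 20 * torch.log10(torch.norm(ref) / (torch.norm(ref - actual) + eps))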
def test_conv3d_api(
    self,
    batch_size,
    in_channels_per_group,
    D,
    H,
    W,
    out_channels_per_group,
    groups,
    kernel_d,
    kernel_h,
    kernel_w,
    stride_d,
    stride_h,
    stride_w,
    pad_d,
    pad_h,
    pad_w,
    dilation,
    X_scale,
    X_zero_point,
    W_scale,
    W_zero_point,
    Y_scale,
    Y_zero_point,
    use_bias,
    use_channelwise,
    qengine,
):
    # Tests the correctness of the conv3d function.
    # Currently conv3d is only supported by the FBGEMM engine.
    if qengine not in torch.backends.quantized.supported_engines:
        return

    input_feature_map_size = (D, H, W)
    kernel_size = (kernel_d, kernel_h, kernel_w)
    stride = (stride_d, stride_h, stride_w)
    padding = (pad_d, pad_h, pad_w)
    dilation = (dilation, dilation, dilation)

    with override_quantized_engine(qengine):
        qconv_fn = qF.conv3d
        conv_fn = F.conv3d
        self._test_conv_api_impl(
            qconv_fn, conv_fn, batch_size, in_channels_per_group,
            input_feature_map_size, out_channels_per_group, groups,
            kernel_size, stride, padding, dilation, X_scale, X_zero_point,
            W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
            use_channelwise)
def test_conv2d_api(
    self,
    batch_size,
    in_channels_per_group,
    H,
    W,
    out_channels_per_group,
    groups,
    kernel_h,
    kernel_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation,
    X_scale,
    X_zero_point,
    W_scale,
    W_zero_point,
    Y_scale,
    Y_zero_point,
    use_bias,
    use_channelwise,
    qengine,
):
    # Tests the correctness of the conv2d function.
    if qengine not in torch.backends.quantized.supported_engines:
        return
    if qengine == 'qnnpack':
        if IS_PPC or TEST_WITH_UBSAN:
            return
        # QNNPACK does not support channelwise quantization here.
        use_channelwise = False

    input_feature_map_size = (H, W)
    kernel_size = (kernel_h, kernel_w)
    stride = (stride_h, stride_w)
    padding = (pad_h, pad_w)
    dilation = (dilation, dilation)

    with override_quantized_engine(qengine):
        qconv_fn = qF.conv2d
        conv_fn = F.conv2d
        self._test_conv_api_impl(
            qconv_fn, conv_fn, batch_size, in_channels_per_group,
            input_feature_map_size, out_channels_per_group, groups,
            kernel_size, stride, padding, dilation, X_scale, X_zero_point,
            W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
            use_channelwise)
def test_conv3d_api(
    self,
    batch_size,
    in_channels_per_group,
    D,
    H,
    W,
    out_channels_per_group,
    groups,
    kernel_d,
    kernel_h,
    kernel_w,
    stride_d,
    stride_h,
    stride_w,
    pad_d,
    pad_h,
    pad_w,
    dilation,
    X_scale,
    X_zero_point,
    W_scale,
    W_zero_point,
    Y_scale,
    Y_zero_point,
    use_bias,
    use_channelwise,
    use_fused,
    qengine,
):
    # Tests the correctness of the conv3d module.
    if qengine not in torch.backends.quantized.supported_engines:
        return

    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups
    input_feature_map_size = (D, H, W)
    kernel_size = (kernel_d, kernel_h, kernel_w)
    stride = (stride_d, stride_h, stride_w)
    padding = (pad_d, pad_h, pad_w)
    dilation = (dilation, dilation, dilation)

    with override_quantized_engine(qengine):
        if use_fused:
            module_name = "QuantizedConvReLU3d"
            qconv_module = nnq_fused.ConvReLU3d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode="zeros")
        else:
            module_name = "QuantizedConv3d"
            qconv_module = nnq.Conv3d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode="zeros")

        conv_module = nn.Conv3d(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, use_bias, padding_mode="zeros")
        if use_fused:
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU3d(conv_module, relu_module)
        conv_module = conv_module.float()

        self._test_conv_api_impl(
            module_name, qconv_module, conv_module, batch_size,
            in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, stride, padding,
            dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
            Y_zero_point, use_bias, use_fused, use_channelwise)
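# For context, the module path exercised above reduces to the following
# minimal sketch (shapes and qparams are illustrative; assumes `import torch`
# and `import torch.nn.quantized as nnq`, plus the "fbgemm" engine, which is
# the only one supporting quantized conv3d here):
#
# x = torch.randn(1, 3, 4, 8, 8, dtype=torch.float32)
# qx = torch.quantize_per_tensor(x, scale=1.0 / 255, zero_point=128,
#                                dtype=torch.quint8)
# qconv = nnq.Conv3d(in_channels=3, out_channels=8, kernel_size=3)
# qconv.scale = 1.0 / 255   # output quantization parameters
# qconv.zero_point = 64
# qy = qconv(qx)            # quantized output tensor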
def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                    use_fused, per_channel, qengine):
    """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu"""
    if qengine not in torch.backends.quantized.supported_engines:
        return
    if qengine == 'qnnpack':
        if IS_PPC or TEST_WITH_UBSAN:
            return
        # QNNPACK does not support per-channel quantization here.
        per_channel = False

    with override_quantized_engine(qengine):
        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        if use_fused:
            qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            qlinear = nnq.Linear(in_features, out_features)

        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_params._packed_params

        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)

        # Check if the module implementation matches calling the
        # ops directly
        if use_fused:
            Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
            self.assertTrue('QuantizedLinearReLU' in str(qlinear))
        else:
            Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)
            self.assertTrue('QuantizedLinear' in str(qlinear))
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['_packed_params.weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['_packed_params.bias'], B)
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        if use_fused:
            loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            loaded_qlinear = nnq.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(
            linear_unpack(qlinear._packed_params._packed_params),
            linear_unpack(loaded_qlinear._packed_params._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(
            qlinear._weight_bias(),
            torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on
        # save.
        #
        # b = io.BytesIO()
        # torch.save(qlinear, b)
        # b.seek(0)
        # loaded = torch.load(b)
        # self.assertEqual(qlinear.weight(), loaded.weight())
        # self.assertEqual(qlinear.scale, loaded.scale)
        # self.assertEqual(qlinear.zero_point, loaded.zero_point)
        with self.assertRaisesRegex(
                RuntimeError,
                r'torch.save\(\) is not currently supported'):
            b = io.BytesIO()
            torch.save(qlinear, b)

        # Test JIT
        self.checkScriptable(qlinear, list(zip([X_q], [Z_ref])),
                             check_save_load=True)

        # Test from_float.
        float_linear = torch.nn.Linear(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear, inplace=True)
        float_linear(X.float())
        # Sequential allows swapping using "convert".
        quantized_float_linear = torch.nn.Sequential(float_linear)
        quantized_float_linear = torch.quantization.convert(
            quantized_float_linear, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
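# The from_float flow at the end of test_linear_api is the standard eager-mode
# post-training quantization recipe. A minimal sketch of the same steps on a
# bare Linear (sizes are illustrative):
#
# model = torch.nn.Sequential(torch.nn.Linear(16, 8)).float()
# model.qconfig = torch.quantization.default_qconfig
# torch.quantization.prepare(model, inplace=True)   # attach observers
# model(torch.rand(4, 16))                          # calibrate on float data
# torch.quantization.convert(model, inplace=True)   # swap in nnq.Linear
# # The converted module now expects quantized input; wrap the model in
# # torch.quantization.QuantWrapper to quantize/dequantize at the boundary.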
def test_conv_api(self, use_bias, use_fused, per_channel, qengine):
    """Tests the correctness of the conv module.

    The correctness is defined against the functional implementation.
    """
    if qengine not in torch.backends.quantized.supported_engines:
        return
    if qengine == 'qnnpack':
        if IS_PPC or TEST_WITH_UBSAN:
            return
        # QNNPACK does not support per-channel quantization here.
        per_channel = False

    with override_quantized_engine(qengine):
        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        qX = torch.quantize_per_tensor(X, scale=scale,
                                       zero_point=zero_point,
                                       dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
        if per_channel:
            scale_tensor = torch.ones(oC, dtype=torch.double)
            zero_point_tensor = torch.zeros(oC, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            qw = torch.quantize_per_channel(w, scales=scale_tensor,
                                            zero_points=zero_point_tensor,
                                            axis=0, dtype=torch.qint8)
        else:
            qw = torch.quantize_per_tensor(w, scale=scale, zero_point=0,
                                           dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None

        if use_fused:
            conv_under_test = ConvReLU2d(in_channels=iC,
                                         out_channels=oC,
                                         kernel_size=(kH, kW),
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         bias=use_bias,
                                         padding_mode='zeros')
        else:
            conv_under_test = Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros')
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        conv_under_test.set_weight_bias(qw, b)
        conv_under_test(qX)

        conv_under_test.scale = scale
        conv_under_test.zero_point = zero_point

        # Test members
        self.assertTrue(hasattr(conv_under_test, '_packed_params'))
        self.assertTrue(hasattr(conv_under_test, 'scale'))
        self.assertTrue(hasattr(conv_under_test, 'zero_point'))

        # Test properties
        self.assertEqual(qw, conv_under_test.weight())
        if use_bias:
            self.assertEqual(b, conv_under_test.bias())
        self.assertEqual(scale, conv_under_test.scale)
        self.assertEqual(zero_point, conv_under_test.zero_point)

        # Test forward
        result_under_test = conv_under_test(qX)
        result_reference = qF.conv2d(qX, qw, bias=b,
                                     scale=scale, zero_point=zero_point,
                                     stride=1, padding=0,
                                     dilation=1, groups=g,
                                     dtype=torch.quint8)
        if use_fused:
            # result_reference < zero_point doesn't work for qtensor yet
            # result_reference[result_reference < zero_point] = zero_point
            MB, OC, OH, OW = result_reference.size()
            for i in range(MB):
                for j in range(OC):
                    for h in range(OH):
                        for ow in range(OW):
                            if result_reference[i][j][h][ow].int_repr() < zero_point:
                                # Assigning 0.0 maps to zero_point in the
                                # quantized tensor.
                                result_reference[i][j][h][ow] = 0.

        self.assertEqual(result_reference, result_under_test,
                         message="Tensors are not equal.")

        # Test serialization of quantized Conv Module using state_dict
        model_dict = conv_under_test.state_dict()
        self.assertEqual(model_dict['weight'], qw)
        if use_bias:
            self.assertEqual(model_dict['bias'], b)
        buf = io.BytesIO()
        torch.save(model_dict, buf)
        buf.seek(0)
        loaded_dict = torch.load(buf)
        for key in model_dict:
            self.assertEqual(loaded_dict[key], model_dict[key])
        if use_fused:
            loaded_conv_under_test = ConvReLU2d(in_channels=iC,
                                                out_channels=oC,
                                                kernel_size=(kH, kW),
                                                stride=1,
                                                padding=0,
                                                dilation=1,
                                                groups=g,
                                                bias=use_bias,
                                                padding_mode='zeros')
            self.assertTrue('QuantizedConvReLU2d' in str(loaded_conv_under_test))
        else:
            loaded_conv_under_test = Conv2d(in_channels=iC,
                                            out_channels=oC,
                                            kernel_size=(kH, kW),
                                            stride=1,
                                            padding=0,
                                            dilation=1,
                                            groups=g,
                                            bias=use_bias,
                                            padding_mode='zeros')
            self.assertTrue('QuantizedConv2d' in str(loaded_conv_under_test))
        loaded_conv_under_test.load_state_dict(loaded_dict)
        self.assertEqual(loaded_conv_under_test._weight_bias(),
                         conv_under_test._weight_bias())
        if use_bias:
            self.assertEqual(loaded_conv_under_test.bias(),
                             conv_under_test.bias())
        self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
        self.assertEqual(loaded_conv_under_test.zero_point,
                         conv_under_test.zero_point)
        self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
        self.assertTrue(hasattr(conv_under_test, '_packed_params'))
        self.assertTrue(hasattr(loaded_conv_under_test, '_packed_params'))
        self.assertTrue(hasattr(conv_under_test, '_weight_bias'))
        self.assertTrue(hasattr(loaded_conv_under_test, '_weight_bias'))
        self.assertEqual(loaded_conv_under_test._weight_bias(),
                         conv_under_test._weight_bias())
        self.assertEqual(loaded_conv_under_test.weight(), qw)
        loaded_result = loaded_conv_under_test(qX)
        self.assertEqual(loaded_result, result_reference)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on
        # save.
        #
        # buf = io.BytesIO()
        # torch.save(conv_under_test, buf)
        # buf.seek(0)
        # loaded_conv = torch.load(buf)
        #
        # self.assertEqual(conv_under_test.bias(), loaded_conv.bias())
        # self.assertEqual(conv_under_test.scale, loaded_conv.scale)
        # self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)
        with self.assertRaisesRegex(
                RuntimeError,
                r'torch.save\(\) is not currently supported'):
            buf = io.BytesIO()
            torch.save(conv_under_test, buf)

        # JIT testing
        self.checkScriptable(conv_under_test,
                             list(zip([qX], [result_reference])),
                             check_save_load=True)

        # Test from_float
        float_conv = torch.nn.Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros').float()
        float_conv.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_conv, inplace=True)
        float_conv(X.float())
        quantized_float_conv = torch.nn.Sequential(float_conv)
        torch.quantization.convert(quantized_float_conv, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_conv(qX)
        if use_bias:
            self.assertEqual(quantized_float_conv[0].bias(), float_conv.bias)
        # Smoke test extra_repr
        self.assertTrue('QuantizedConv2d' in str(quantized_float_conv))
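# test_linear_api and the conv tests above build per-channel quantization
# parameters for the weights by hand. The two weight schemes compare as
# follows (values are illustrative):
#
# w = torch.randn(16, 3, 3, 3)
# # Per-tensor: a single (scale, zero_point) pair for the whole tensor.
# qw_pt = torch.quantize_per_tensor(w, scale=1.0 / 255, zero_point=0,
#                                   dtype=torch.qint8)
# # Per-channel: one (scale, zero_point) pair per output channel (axis=0),
# # which typically tracks per-filter weight ranges more tightly.
# scales = torch.arange(1, 17, dtype=torch.double) / 255.0
# zero_points = torch.zeros(16, dtype=torch.long)
# qw_pc = torch.quantize_per_channel(w, scales, zero_points, axis=0,
#                                    dtype=torch.qint8)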
def test_conv_api(self, use_bias, per_channel, qengine):
    """Tests the correctness of the functional conv2d.

    The correctness is defined against the op-level implementation with
    prepacked weights.
    """
    if qengine not in torch.backends.quantized.supported_engines:
        return
    if qengine == 'qnnpack':
        if IS_PPC or TEST_WITH_UBSAN:
            return
        # QNNPACK does not support per-channel quantization here.
        per_channel = False

    with override_quantized_engine(qengine):
        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128
        stride = (1, 1)
        i_padding = (0, 0)
        dilation = (1, 1)

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        qX = torch.quantize_per_tensor(X, scale=scale,
                                       zero_point=zero_point,
                                       dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
        if per_channel:
            scale_tensor = torch.ones(oC, dtype=torch.double)
            zero_point_tensor = torch.zeros(oC, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            qw = torch.quantize_per_channel(w, scales=scale_tensor,
                                            zero_points=zero_point_tensor,
                                            axis=0, dtype=torch.qint8)
        else:
            qw = torch.quantize_per_tensor(w, scale=scale, zero_point=0,
                                           dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None

        q_filters_ref = torch.ops.quantized.conv_prepack(qw, b, stride,
                                                         i_padding,
                                                         dilation, g)
        ref_result = torch.ops.quantized.conv2d(qX, q_filters_ref, stride,
                                                i_padding, dilation, g,
                                                scale, zero_point)
        q_result = torch.nn.quantized.functional.conv2d(
            qX, qw, bias=b, scale=scale, zero_point=zero_point,
            stride=stride, padding=i_padding, dilation=dilation, groups=g,
            dtype=torch.quint8)
        self.assertEqual(ref_result, q_result)
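# Note on the op-level path used above: `conv_prepack` converts the quantized
# weight into an engine-specific packed layout once, so repeated conv calls
# avoid re-packing. A sketch of reusing one packed weight across inputs
# (`batches` is a hypothetical iterable of quantized input tensors):
#
# packed = torch.ops.quantized.conv_prepack(qw, b, stride, i_padding,
#                                           dilation, g)
# for qx_batch in batches:
#     qy = torch.ops.quantized.conv2d(qx_batch, packed, stride, i_padding,
#                                     dilation, g, scale, zero_point)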