def test_conv_api(self, use_bias, use_fused):
    """Tests the correctness of the conv module.

    The correctness is defined against the functional implementation
    (``qF.conv2d``). The test also covers ``state_dict`` round-tripping,
    the expected failure of whole-module ``torch.save``, scriptability and
    conversion from a float ``torch.nn.Conv2d``.

    Args:
        use_bias: when true, a random float bias is created and the modules
            are constructed with ``bias=True``.
        use_fused: when true, ``ConvReLU2d`` is tested instead of ``Conv2d``
            and the functional reference is clamped at ``zero_point``.
    """
    N, iC, H, W = 10, 10, 10, 3
    oC, g, kH, kW = 16, 1, 3, 3
    scale, zero_point = 1.0 / 255, 128

    # Constructor arguments shared by every conv module built in this test.
    conv_kwargs = dict(in_channels=iC,
                       out_channels=oC,
                       kernel_size=(kH, kW),
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=g,
                       bias=use_bias,
                       padding_mode='zeros')
    conv_cls = ConvReLU2d if use_fused else Conv2d

    # The random draws stay in the order input, weight, bias.
    X = torch.randn(N, iC, H, W, dtype=torch.float32)
    qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                   dtype=torch.quint8)
    weight = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
    qw = torch.quantize_per_tensor(weight, scale=scale, zero_point=0,
                                   dtype=torch.qint8)
    bias = torch.randn(oC, dtype=torch.float32) if use_bias else None

    conv_under_test = conv_cls(**conv_kwargs)

    # Install the test weight/bias and run the module once; this exercises
    # the constructor and the forward path before any assertion is made.
    conv_under_test.set_weight_bias(qw, bias)
    conv_under_test(qX)

    conv_under_test.scale = scale
    conv_under_test.zero_point = zero_point

    # Test members
    self.assertTrue(hasattr(conv_under_test, '_packed_params'))
    self.assertTrue(hasattr(conv_under_test, 'scale'))
    self.assertTrue(hasattr(conv_under_test, 'zero_point'))

    # Test properties
    self.assertEqual(qw, conv_under_test.weight())
    self.assertEqual(bias, conv_under_test.bias())
    self.assertEqual(scale, conv_under_test.scale)
    self.assertEqual(zero_point, conv_under_test.zero_point)

    # Test forward
    result_under_test = conv_under_test(qX)
    result_reference = qF.conv2d(qX, qw, bias=bias,
                                 scale=scale, zero_point=zero_point,
                                 stride=1, padding=0,
                                 dilation=1, groups=g, dtype=torch.quint8)
    if use_fused:
        # result_reference < zero_point doesn't work for qtensor yet
        # result_reference[result_reference < zero_point] = zero_point
        MB, OC, OH, OW = result_reference.size()
        for n_idx in range(MB):
            for c_idx in range(OC):
                for h_idx in range(OH):
                    for w_idx in range(OW):
                        elem = result_reference[n_idx][c_idx][h_idx][w_idx]
                        if elem.int_repr() < zero_point:
                            # assign 0. that gets converted to zero_point
                            result_reference[n_idx][c_idx][h_idx][w_idx] = 0.

    self.assertEqual(result_reference, result_under_test,
                     message="Tensors are not equal.")

    # Test serialization of quantized Conv Module using state_dict
    model_dict = conv_under_test.state_dict()
    self.assertEqual(model_dict['weight'], qw)
    if use_bias:
        self.assertEqual(model_dict['bias'], bias)
    buffer = io.BytesIO()
    torch.save(model_dict, buffer)
    buffer.seek(0)
    loaded_dict = torch.load(buffer)
    for key in model_dict:
        self.assertEqual(loaded_dict[key], model_dict[key])

    loaded_conv_under_test = conv_cls(**conv_kwargs)
    loaded_conv_under_test.load_state_dict(loaded_dict)
    self.assertEqual(loaded_conv_under_test._weight_bias(),
                     conv_under_test._weight_bias())
    if use_bias:
        self.assertEqual(loaded_conv_under_test.bias(),
                         conv_under_test.bias())
    self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
    self.assertEqual(loaded_conv_under_test.zero_point,
                     conv_under_test.zero_point)
    self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
    self.assertTrue(hasattr(conv_under_test, '_packed_params'))
    self.assertTrue(hasattr(loaded_conv_under_test, '_packed_params'))
    self.assertTrue(hasattr(conv_under_test, '_weight_bias'))
    self.assertTrue(hasattr(loaded_conv_under_test, '_weight_bias'))
    self.assertEqual(loaded_conv_under_test.weight(), qw)
    loaded_result = loaded_conv_under_test(qX)
    self.assertEqual(loaded_result, result_reference)

    # The below check is meant to ensure that `torch.save` and `torch.load`
    # serialization works, however it is currently broken by the following:
    # https://github.com/pytorch/pytorch/issues/24045
    #
    # Instead, we currently check that the proper exception is thrown on
    # save.
    # <start code>
    # b = io.BytesIO()
    # torch.save(conv_under_test, b)
    # b.seek(0)
    # loaded_conv = torch.load(b)
    #
    # self.assertEqual(conv_under_test.bias(), loaded_conv.bias())
    # self.assertEqual(conv_under_test.scale, loaded_conv.scale)
    # self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)
    # <end code>
    buffer = io.BytesIO()
    with self.assertRaisesRegex(
            RuntimeError, r'torch.save\(\) is not currently supported'):
        torch.save(conv_under_test, buffer)

    # JIT testing
    self.checkScriptable(conv_under_test, [(qX, result_reference)],
                         check_save_load=True)

    # Test from_float
    float_conv = torch.nn.Conv2d(**conv_kwargs).float()
    float_conv.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(float_conv, inplace=True)
    # Run one float batch through the prepared module before converting.
    float_conv(X.float())
    quantized_float_conv = torch.nn.Sequential(float_conv)
    torch.quantization.convert(quantized_float_conv, inplace=True)

    # Smoke test to make sure the module actually runs
    quantized_float_conv(qX)
    if use_bias:
        self.assertEqual(quantized_float_conv[0].bias(), float_conv.bias)
    # Smoke test extra_repr
    str(quantized_float_conv)
def test_conv_api(self, use_bias, use_fused):
    """Tests the correctness of the conv module.

    The correctness is defined against the functional implementation
    (``qF.conv2d``). The test also covers ``state_dict`` and whole-module
    ``torch.save``/``torch.load`` round-tripping, scriptability and
    conversion from a float ``torch.nn.Conv2d``.

    Args:
        use_bias: when true, a random bias is created, quantized to
            ``qint32`` and installed on the module under test.
        use_fused: when true, ``ConvReLU2d`` is tested instead of ``Conv2d``
            and the functional reference is clamped at ``zero_point``.
    """
    N, iC, H, W = 10, 10, 10, 3
    oC, g, kH, kW = 16, 1, 3, 3
    scale, zero_point = 1.0 / 255, 128

    # Constructor arguments shared by every conv module built in this test.
    conv_kwargs = dict(in_channels=iC,
                       out_channels=oC,
                       kernel_size=(kH, kW),
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=g,
                       bias=use_bias,
                       padding_mode='zeros')
    conv_cls = ConvReLU2d if use_fused else Conv2d

    # The random draws stay in the order input, weight, bias.
    X = torch.randn(N, iC, H, W, dtype=torch.float32)
    # Move dim 1 to the last position and make the result contiguous.
    X = X.permute([0, 2, 3, 1]).contiguous()
    qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point,
                               dtype=torch.quint8)
    weight = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
    qw = torch.quantize_linear(weight, scale=scale, zero_point=0,
                               dtype=torch.qint8)
    bias = torch.randn(oC, dtype=torch.float32) if use_bias else None
    qb = torch.quantize_linear(
        bias, scale=1.0 / 1024, zero_point=0,
        dtype=torch.qint32) if use_bias else None

    conv_under_test = conv_cls(**conv_kwargs)

    # Run module with default-initialized parameters.
    # This tests that the constructor is correct.
    conv_under_test(qX)

    conv_under_test.set_weight(qw)
    conv_under_test.bias = qb
    conv_under_test.scale = scale
    conv_under_test.zero_point = zero_point

    # Test members
    self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
    self.assertTrue(hasattr(conv_under_test, 'scale'))
    self.assertTrue(hasattr(conv_under_test, 'zero_point'))

    # Test properties
    self.assertEqual(qw, conv_under_test.weight())
    self.assertEqual(qb, conv_under_test.bias)
    self.assertEqual(scale, conv_under_test.scale)
    self.assertEqual(zero_point, conv_under_test.zero_point)

    # Test forward
    result_under_test = conv_under_test(qX)
    result_reference = qF.conv2d(qX, qw, bias=qb,
                                 scale=scale, zero_point=zero_point,
                                 stride=1, padding=0,
                                 dilation=1, groups=g, dtype=torch.quint8)
    if use_fused:
        # result_reference < zero_point doesn't work for qtensor yet
        # result_reference[result_reference < zero_point] = zero_point
        MB, OC, OH, OW = result_reference.size()
        for n_idx in range(MB):
            for c_idx in range(OC):
                for h_idx in range(OH):
                    for w_idx in range(OW):
                        elem = result_reference[n_idx][c_idx][h_idx][w_idx]
                        if elem.int_repr() < zero_point:
                            # assign 0. that gets converted to zero_point
                            result_reference[n_idx][c_idx][h_idx][w_idx] = 0.

    self.assertEqual(result_reference, result_under_test,
                     message="Tensors are not equal.")

    # Test serialization of quantized Conv Module using state_dict
    model_dict = conv_under_test.state_dict()
    self.assertEqual(model_dict['weight'], qw)
    if use_bias:
        self.assertEqual(model_dict['bias'], qb)
    with tempfile.NamedTemporaryFile() as f:
        torch.save(model_dict, f)
        f.seek(0)
        loaded_dict = torch.load(f)
    for key in model_dict:
        self.assertEqual(loaded_dict[key], model_dict[key])

    loaded_conv_under_test = conv_cls(**conv_kwargs)
    loaded_conv_under_test.load_state_dict(loaded_dict)
    self.assertEqual(loaded_conv_under_test.weight(),
                     conv_under_test.weight())
    if use_bias:
        self.assertEqual(loaded_conv_under_test.bias, conv_under_test.bias)
    self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
    self.assertEqual(loaded_conv_under_test.zero_point,
                     conv_under_test.zero_point)
    self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
    self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
    self.assertTrue(hasattr(loaded_conv_under_test, '_packed_weight'))
    self.assertTrue(hasattr(conv_under_test, 'weight'))
    self.assertTrue(hasattr(loaded_conv_under_test, 'weight'))
    self.assertEqual(loaded_conv_under_test.weight(), qw)
    loaded_result = loaded_conv_under_test(qX)
    self.assertEqual(loaded_result, result_reference)

    # Test serialization of the whole module with torch.save / torch.load
    with tempfile.NamedTemporaryFile() as f:
        torch.save(conv_under_test, f)
        f.seek(0)
        loaded_conv = torch.load(f)

    self.assertEqual(conv_under_test.bias, loaded_conv.bias)
    self.assertEqual(conv_under_test.scale, loaded_conv.scale)
    self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)

    # JIT testing
    self.checkScriptable(conv_under_test, [(qX, result_reference)],
                         check_save_load=True)

    # Test from_float
    float_conv = torch.nn.Conv2d(**conv_kwargs).float()
    float_conv.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(float_conv)
    # Run one float batch through the prepared module before converting.
    float_conv(X.float())
    quantized_float_conv = torch.nn.Sequential(float_conv)
    torch.quantization.convert(quantized_float_conv)

    # Smoke test to make sure the module actually runs
    quantized_float_conv(qX)
    # Check that bias is quantized based on output scale
    if use_bias:
        qbias = torch.quantize_linear(
            float_conv.bias, quantized_float_conv[0].scale / 2**16, 0,
            torch.qint32)
        self.assertEqual(quantized_float_conv[0].bias.dequantize(),
                         qbias.dequantize())
    # Smoke test extra_repr
    str(quantized_float_conv)
def test_conv_api(self, use_bias, use_fused):
    """Tests the correctness of the conv module.

    The correctness is defined against the functional implementation
    (``qF.conv2d``).

    Args:
        use_bias: when true, a random bias is created, quantized to
            ``qint32`` and assigned to the module under test.
        use_fused: when true, ``ConvReLU2d`` is tested instead of ``Conv2d``
            and the functional reference is clamped at ``zero_point``.
    """
    N, iC, H, W = 10, 10, 10, 3
    oC, g, kH, kW = 16, 1, 3, 3
    scale, zero_point = 1.0 / 255, 128

    # Constructor arguments for the module under test.
    conv_kwargs = dict(in_channels=iC,
                       out_channels=oC,
                       kernel_size=(kH, kW),
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=g,
                       bias=use_bias,
                       padding_mode='zeros')
    conv_cls = ConvReLU2d if use_fused else Conv2d

    # The random draws stay in the order input, weight, bias.
    X = torch.randn(N, iC, H, W, dtype=torch.float32)
    # Move dim 1 to the last position and make the result contiguous.
    X = X.permute([0, 2, 3, 1]).contiguous()
    qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point,
                               dtype=torch.quint8)
    weight = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
    qw = torch.quantize_linear(weight, scale=scale, zero_point=0,
                               dtype=torch.qint8)
    bias = torch.randn(oC, dtype=torch.float32) if use_bias else None
    qb = torch.quantize_linear(
        bias, scale=1.0 / 1024, zero_point=0,
        dtype=torch.qint32) if use_bias else None

    conv_under_test = conv_cls(**conv_kwargs)
    conv_under_test.weight = qw
    conv_under_test.bias = qb
    # Scale and zero point are assigned as one-element tensors here.
    conv_under_test.scale = torch.tensor([scale], dtype=torch.double)
    conv_under_test.zero_point = torch.tensor([zero_point], dtype=torch.long)

    # Test members
    self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
    self.assertTrue(hasattr(conv_under_test, 'scale'))
    self.assertTrue(hasattr(conv_under_test, 'zero_point'))

    # Test properties
    self.assertEqual(qw, conv_under_test.weight)
    self.assertEqual(qb, conv_under_test.bias)
    self.assertEqual(scale, conv_under_test.scale)
    self.assertEqual(zero_point, conv_under_test.zero_point)

    # Test forward
    result_under_test = conv_under_test(qX)
    result_reference = qF.conv2d(qX, qw, bias=qb,
                                 scale=scale, zero_point=zero_point,
                                 stride=1, padding=0,
                                 dilation=1, groups=g, dtype=torch.quint8)
    if use_fused:
        # result_reference < zero_point doesn't work for qtensor yet
        # result_reference[result_reference < zero_point] = zero_point
        MB, OC, OH, OW = result_reference.size()
        for n_idx in range(MB):
            for c_idx in range(OC):
                for h_idx in range(OH):
                    for w_idx in range(OW):
                        elem = result_reference[n_idx][c_idx][h_idx][w_idx]
                        if elem.int_repr() < zero_point:
                            # assign 0. that gets converted to zero_point
                            result_reference[n_idx][c_idx][h_idx][w_idx] = 0.

    self.assertEqual(result_reference, result_under_test,
                     message="Tensors are not equal.")