def test_conv_api(self):
    """Compare the quantized ``Conv2d`` module with the functional conv.

    The module is given hand-built quantized weight and bias tensors, its
    members and properties are inspected, and its forward output is checked
    against ``qF.conv2d`` invoked on the same tensors with
    ``prepacked=False``.
    """
    batch, in_ch, height, width = 10, 10, 10, 3
    out_ch, groups, k_h, k_w = 16, 1, 3, 3
    scale, zero_point = 1.0 / 255, 128

    # Activation: random floats, permuted with [0, 2, 3, 1], then quantized.
    x_float = torch.randn(batch, in_ch, height, width, dtype=torch.float32)
    x_float = x_float.permute([0, 2, 3, 1]).contiguous()
    qX = torch.quantize_linear(x_float, scale=scale, zero_point=128,
                               dtype=torch.quint8)

    # Weight: same permutation as the activation, quantized to qint8.
    w_float = torch.randn(out_ch, in_ch // groups, k_h, k_w,
                          dtype=torch.float32)
    w_float = w_float.permute([0, 2, 3, 1]).contiguous()
    qw = torch.quantize_linear(w_float, scale=scale, zero_point=0,
                               dtype=torch.qint8)

    # Bias: quantized to qint32 with its own fixed scale.
    b_float = torch.randn(out_ch, dtype=torch.float32)
    qb = torch.quantize_linear(b_float, scale=1.0 / 1024, zero_point=0,
                               dtype=torch.qint32)

    conv_kwargs = dict(in_channels=in_ch,
                       out_channels=out_ch,
                       kernel_size=(k_h, k_w),
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=groups,
                       bias=True,
                       padding_mode='zeros')
    conv_under_test = Conv2d(**conv_kwargs)
    conv_under_test.weight = qw
    conv_under_test.bias = qb
    conv_under_test.scale = scale
    conv_under_test.zero_point = zero_point

    # Test members
    for member in ('_packed_weight', '_scale', '_zero_point'):
        self.assertTrue(hasattr(conv_under_test, member))

    # Test properties
    # NOTE: the weight comparison is intentionally left disabled here:
    # self.assertEqual(qw, conv_under_test.weight)
    self.assertEqual(qb, conv_under_test.bias)
    self.assertEqual(scale, conv_under_test.scale)
    self.assertEqual(zero_point, conv_under_test.zero_point)

    # Test forward: the module output must equal the functional output.
    result_under_test = conv_under_test(qX)
    result_reference = qF.conv2d(qX, qw, bias=qb,
                                 scale=scale, zero_point=zero_point,
                                 stride=1, padding=0,
                                 dilation=1, groups=groups,
                                 prepacked=False, dtype=torch.quint8)
    self.assertEqual(result_reference, result_under_test,
                     message="Tensors are not equal.")
def forward(self, input):
    """Apply the quantized 2-D convolution to ``input``.

    Delegates to ``qF.conv2d`` with this module's packed weight, bias,
    convolution hyper-parameters and output quantization parameters. The
    weight is passed as already prepacked and the output dtype is fixed to
    ``torch.quint8``.
    """
    output = qF.conv2d(
        input=input,
        weight=self._packed_weight,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
        groups=self.groups,
        padding_mode=self.padding_mode,
        scale=self.scale,
        zero_point=self.zero_point,
        dtype=torch.quint8,
        prepacked=True,
    )
    return output
def test_conv_api(self, use_bias, use_fused):
    """Exercise the quantized conv module end to end.

    Covers construction, property access, forward correctness against
    ``qF.conv2d``, ``state_dict`` and whole-module serialization through a
    temporary file, scriptability, and conversion from a float module.

    Args:
        use_bias: whether the module (and the reference call) has a bias.
        use_fused: test ``ConvReLU2d`` instead of ``Conv2d``.
    """
    batch, in_ch, height, width = 10, 10, 10, 3
    out_ch, groups, k_h, k_w = 16, 1, 3, 3
    scale, zero_point = 1.0 / 255, 128

    # Activation: random floats, permuted with [0, 2, 3, 1], then quantized.
    x_float = torch.randn(batch, in_ch, height, width, dtype=torch.float32)
    x_float = x_float.permute([0, 2, 3, 1]).contiguous()
    qX = torch.quantize_linear(x_float, scale=scale, zero_point=128,
                               dtype=torch.quint8)

    w_float = torch.randn(out_ch, in_ch // groups, k_h, k_w,
                          dtype=torch.float32)
    qw = torch.quantize_linear(w_float, scale=scale, zero_point=0,
                               dtype=torch.qint8)

    if use_bias:
        b_float = torch.randn(out_ch, dtype=torch.float32)
        qb = torch.quantize_linear(b_float, scale=1.0 / 1024, zero_point=0,
                                   dtype=torch.qint32)
    else:
        qb = None

    # The module under test, its reloaded copy and the float module are all
    # built from the same constructor arguments.
    conv_kwargs = dict(in_channels=in_ch,
                       out_channels=out_ch,
                       kernel_size=(k_h, k_w),
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=groups,
                       bias=use_bias,
                       padding_mode='zeros')
    module_cls = ConvReLU2d if use_fused else Conv2d
    conv_under_test = module_cls(**conv_kwargs)

    # Run the module once with its default-initialized parameters; this
    # checks that the constructor leaves it in a usable state.
    conv_under_test(qX)

    conv_under_test.set_weight(qw)
    conv_under_test.bias = qb
    conv_under_test.scale = scale
    conv_under_test.zero_point = zero_point

    # Test members
    for member in ('_packed_weight', 'scale', 'zero_point'):
        self.assertTrue(hasattr(conv_under_test, member))

    # Test properties
    self.assertEqual(qw, conv_under_test.weight())
    self.assertEqual(qb, conv_under_test.bias)
    self.assertEqual(scale, conv_under_test.scale)
    self.assertEqual(zero_point, conv_under_test.zero_point)

    # Test forward
    result_under_test = conv_under_test(qX)
    result_reference = qF.conv2d(qX, qw, bias=qb,
                                 scale=scale, zero_point=zero_point,
                                 stride=1, padding=0,
                                 dilation=1, groups=groups,
                                 dtype=torch.quint8)
    if use_fused:
        # Apply ReLU to the reference by hand: comparisons such as
        # `result_reference < zero_point` do not work for quantized tensors
        # yet, so every element is inspected individually.
        n_batch, n_chan, n_rows, n_cols = result_reference.size()
        for n in range(n_batch):
            for c in range(n_chan):
                for row in range(n_rows):
                    for col in range(n_cols):
                        below = (result_reference[n][c][row][col].int_repr()
                                 < zero_point)
                        if below:
                            # Assigning 0. gets converted to zero_point.
                            result_reference[n][c][row][col] = 0.

    self.assertEqual(result_reference, result_under_test,
                     message="Tensors are not equal.")

    # Test serialization of the quantized conv module using state_dict
    model_dict = conv_under_test.state_dict()
    self.assertEqual(model_dict['weight'], qw)
    if use_bias:
        self.assertEqual(model_dict['bias'], qb)
    with tempfile.NamedTemporaryFile() as f:
        torch.save(model_dict, f)
        f.seek(0)
        loaded_dict = torch.load(f)
    for key in model_dict:
        self.assertEqual(loaded_dict[key], model_dict[key])

    loaded_conv_under_test = module_cls(**conv_kwargs)
    loaded_conv_under_test.load_state_dict(loaded_dict)
    self.assertEqual(loaded_conv_under_test.weight(),
                     conv_under_test.weight())
    if use_bias:
        self.assertEqual(loaded_conv_under_test.bias, conv_under_test.bias)
    self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
    self.assertEqual(loaded_conv_under_test.zero_point,
                     conv_under_test.zero_point)
    self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
    for member in ('_packed_weight', 'weight'):
        self.assertTrue(hasattr(conv_under_test, member))
        self.assertTrue(hasattr(loaded_conv_under_test, member))
    self.assertEqual(loaded_conv_under_test.weight(),
                     conv_under_test.weight())
    self.assertEqual(loaded_conv_under_test.weight(), qw)
    loaded_result = loaded_conv_under_test(qX)
    self.assertEqual(loaded_result, result_reference)

    # Test serialization of the whole module object
    with tempfile.NamedTemporaryFile() as f:
        torch.save(conv_under_test, f)
        f.seek(0)
        loaded_conv = torch.load(f)

    self.assertEqual(conv_under_test.bias, loaded_conv.bias)
    self.assertEqual(conv_under_test.scale, loaded_conv.scale)
    self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)

    # JIT testing
    self.checkScriptable(conv_under_test, [(qX, result_reference)],
                         check_save_load=True)

    # Test from_float
    float_conv = torch.nn.Conv2d(**conv_kwargs).float()
    float_conv.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(float_conv)
    float_conv(x_float.float())
    quantized_float_conv = torch.nn.Sequential(float_conv)
    torch.quantization.convert(quantized_float_conv)

    # Smoke test to make sure the module actually runs
    quantized_float_conv(qX)

    # Check that bias is quantized based on output scale
    if use_bias:
        converted = quantized_float_conv[0]
        qbias = torch.quantize_linear(float_conv.bias,
                                      converted.scale / 2**16,
                                      0, torch.qint32)
        self.assertEqual(converted.bias.dequantize(), qbias.dequantize())

    # Smoke test extra_repr
    str(quantized_float_conv)
def test_conv_api(self, use_bias, use_fused):
    """Exercise the quantized conv module end to end.

    Covers construction, accessor methods, forward correctness against
    ``qF.conv2d``, ``state_dict`` round-tripping through an in-memory
    buffer, the expected failure of whole-module ``torch.save``,
    scriptability, and conversion from a float module.

    Args:
        use_bias: whether the module (and the reference call) has a bias.
        use_fused: test ``ConvReLU2d`` instead of ``Conv2d``.
    """
    batch, in_ch, height, width = 10, 10, 10, 3
    out_ch, groups, k_h, k_w = 16, 1, 3, 3
    scale, zero_point = 1.0 / 255, 128

    x_float = torch.randn(batch, in_ch, height, width, dtype=torch.float32)
    qX = torch.quantize_per_tensor(x_float, scale=scale, zero_point=128,
                                   dtype=torch.quint8)

    w_float = torch.randn(out_ch, in_ch // groups, k_h, k_w,
                          dtype=torch.float32)
    qw = torch.quantize_per_tensor(w_float, scale=scale, zero_point=0,
                                   dtype=torch.qint8)

    # The bias stays a float tensor in this version of the API.
    if use_bias:
        bias = torch.randn(out_ch, dtype=torch.float32)
    else:
        bias = None

    # The module under test, its reloaded copy and the float module are all
    # built from the same constructor arguments.
    conv_kwargs = dict(in_channels=in_ch,
                       out_channels=out_ch,
                       kernel_size=(k_h, k_w),
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=groups,
                       bias=use_bias,
                       padding_mode='zeros')
    module_cls = ConvReLU2d if use_fused else Conv2d
    conv_under_test = module_cls(**conv_kwargs)

    # Install the parameters and run once before the output quantization
    # parameters are overridden; this checks the constructor's defaults.
    conv_under_test.set_weight_bias(qw, bias)
    conv_under_test(qX)

    conv_under_test.scale = scale
    conv_under_test.zero_point = zero_point

    # Test members
    for member in ('_packed_params', 'scale', 'zero_point'):
        self.assertTrue(hasattr(conv_under_test, member))

    # Test properties
    self.assertEqual(qw, conv_under_test.weight())
    self.assertEqual(bias, conv_under_test.bias())
    self.assertEqual(scale, conv_under_test.scale)
    self.assertEqual(zero_point, conv_under_test.zero_point)

    # Test forward
    result_under_test = conv_under_test(qX)
    result_reference = qF.conv2d(qX, qw, bias=bias,
                                 scale=scale, zero_point=zero_point,
                                 stride=1, padding=0,
                                 dilation=1, groups=groups,
                                 dtype=torch.quint8)
    if use_fused:
        # Apply ReLU to the reference by hand: comparisons such as
        # `result_reference < zero_point` do not work for quantized tensors
        # yet, so every element is inspected individually.
        n_batch, n_chan, n_rows, n_cols = result_reference.size()
        for n in range(n_batch):
            for c in range(n_chan):
                for row in range(n_rows):
                    for col in range(n_cols):
                        below = (result_reference[n][c][row][col].int_repr()
                                 < zero_point)
                        if below:
                            # Assigning 0. gets converted to zero_point.
                            result_reference[n][c][row][col] = 0.

    self.assertEqual(result_reference, result_under_test,
                     message="Tensors are not equal.")

    # Test serialization of the quantized conv module using state_dict
    model_dict = conv_under_test.state_dict()
    self.assertEqual(model_dict['weight'], qw)
    if use_bias:
        self.assertEqual(model_dict['bias'], bias)
    dict_buffer = io.BytesIO()
    torch.save(model_dict, dict_buffer)
    dict_buffer.seek(0)
    loaded_dict = torch.load(dict_buffer)
    for key in model_dict:
        self.assertEqual(loaded_dict[key], model_dict[key])

    loaded_conv_under_test = module_cls(**conv_kwargs)
    loaded_conv_under_test.load_state_dict(loaded_dict)
    self.assertEqual(loaded_conv_under_test._weight_bias(),
                     conv_under_test._weight_bias())
    if use_bias:
        self.assertEqual(loaded_conv_under_test.bias(),
                         conv_under_test.bias())
    self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
    self.assertEqual(loaded_conv_under_test.zero_point,
                     conv_under_test.zero_point)
    self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
    for member in ('_packed_params', '_weight_bias'):
        self.assertTrue(hasattr(conv_under_test, member))
        self.assertTrue(hasattr(loaded_conv_under_test, member))
    self.assertEqual(loaded_conv_under_test._weight_bias(),
                     conv_under_test._weight_bias())
    self.assertEqual(loaded_conv_under_test.weight(), qw)
    loaded_result = loaded_conv_under_test(qX)
    self.assertEqual(loaded_result, result_reference)

    # The below check is meant to ensure that `torch.save` and `torch.load`
    # serialization works, however it is currently broken by the following:
    # https://github.com/pytorch/pytorch/issues/24045
    #
    # Instead, we currently check that the proper exception is thrown on
    # save.
    # <start code>
    # b = io.BytesIO()
    # torch.save(conv_under_test, b)
    # b.seek(0)
    # loaded_conv = torch.load(b)
    #
    # self.assertEqual(conv_under_test.bias(), loaded_conv.bias())
    # self.assertEqual(conv_under_test.scale, loaded_conv.scale)
    # self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)
    # <end code>
    with self.assertRaisesRegex(
            RuntimeError, r'torch.save\(\) is not currently supported'):
        module_buffer = io.BytesIO()
        torch.save(conv_under_test, module_buffer)

    # JIT testing
    self.checkScriptable(conv_under_test, [(qX, result_reference)],
                         check_save_load=True)

    # Test from_float
    float_conv = torch.nn.Conv2d(**conv_kwargs).float()
    float_conv.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(float_conv, inplace=True)
    float_conv(x_float.float())
    quantized_float_conv = torch.nn.Sequential(float_conv)
    torch.quantization.convert(quantized_float_conv, inplace=True)

    # Smoke test to make sure the module actually runs
    quantized_float_conv(qX)
    if use_bias:
        self.assertEqual(quantized_float_conv[0].bias(), float_conv.bias)

    # Smoke test extra_repr
    str(quantized_float_conv)
def test_conv_api(self, Q, padH, padW, sH, sW, dH, dW, prepacked):
    """Tests the correctness of the conv functional.

    The correctness is defined by the behavior being similar to the
    `quantized._ops` implementation: either both raise a ``RuntimeError``
    with the same leading message, or both produce identical integer
    representations.

    Args:
        Q: tuple ``(X, (scale, zero_point), (qmin, qmax), torch_type)``
            where ``X`` is ``(inputs, filters, bias, groups)`` with the
            first three given as numpy arrays.
        padH, padW: padding along height / width.
        sH, sW: stride along height / width.
        dH, dW: dilation along height / width.
        prepacked: whether ``qF.conv2d`` receives prepacked filters on the
            success path.
    """
    # Local import so this block does not depend on the file's import list.
    import re

    # Random inputs
    X, (scale, zero_point), (_qmin, _qmax), torch_type = Q
    (inputs, filters, bias, groups) = X

    iH, iW = inputs.shape[2:]
    kH, kW = filters.shape[2:]
    # Discard generated examples whose padding exceeds half the kernel or
    # whose output would be empty.
    assume(kH // 2 >= padH)
    assume(kW // 2 >= padW)
    oH = _conv_output_shape(iH, kH, padH, sH, dH)
    assume(oH > 0)
    oW = _conv_output_shape(iW, kW, padW, sW, dW)
    assume(oW > 0)

    inputs = torch.from_numpy(inputs).to(torch.float)
    filters = torch.from_numpy(filters).to(torch.float)
    bias = torch.from_numpy(bias).to(torch.float)

    stride = (sH, sW)
    i_padding = (padH, padW)
    dilation = (dH, dW)

    # Quantized inputs
    i_NHWC = inputs.permute([0, 2, 3, 1]).contiguous()
    w_RSCK = filters.permute([0, 2, 3, 1]).contiguous()
    q_inputs = torch.quantize_linear(i_NHWC, scale, zero_point, torch.quint8)
    q_filters = torch.quantize_linear(w_RSCK, scale, zero_point, torch.qint8)
    q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(q_filters,
                                                            groups)
    q_bias = torch.quantize_linear(bias, scale, zero_point, torch.qint32)

    # Reference op
    ref_op = torch.ops.quantized.fbgemm_conv2d

    # Results check
    try:
        ref_result = ref_op(q_inputs, q_filters_ref, q_bias, stride,
                            i_padding, dilation,
                            groups, scale, zero_point)
    except RuntimeError as e:
        # Keep only the leading part of the first line of the message.
        e_msg = str(e).split("\n")[0].split("(")[0].strip()
        # The message is matched as a regular expression, so escape it:
        # otherwise characters such as '.', '[' or '+' would be treated as
        # regex syntax instead of literal text.
        np.testing.assert_raises_regex(
            type(e), re.escape(e_msg), qF.conv2d,
            q_inputs, q_filters_ref, bias=q_bias,
            scale=scale, zero_point=zero_point,
            stride=stride, padding=i_padding, dilation=dilation,
            groups=groups, prepacked=True, dtype=torch_type)
    else:
        if prepacked:
            q_filters = torch.ops.quantized.fbgemm_conv_prepack(q_filters,
                                                                groups)
        q_result = qF.conv2d(q_inputs, q_filters, bias=q_bias,
                             scale=scale, zero_point=zero_point,
                             stride=stride, padding=i_padding,
                             dilation=dilation, groups=groups,
                             prepacked=prepacked, dtype=torch_type)

        np.testing.assert_equal(ref_result.int_repr().numpy(),
                                q_result.int_repr().numpy())
def test_conv_api(self, use_bias, use_fused):
    """Compare the quantized conv module with the functional conv.

    The module is given hand-built quantized weight and bias tensors, its
    members and properties are inspected, and its forward output is checked
    against ``qF.conv2d``, with ReLU applied to the reference by hand in
    the fused case.

    Args:
        use_bias: whether the module (and the reference call) has a bias.
        use_fused: test ``ConvReLU2d`` instead of ``Conv2d``.
    """
    batch, in_ch, height, width = 10, 10, 10, 3
    out_ch, groups, k_h, k_w = 16, 1, 3, 3
    scale, zero_point = 1.0 / 255, 128

    # Activation: random floats, permuted with [0, 2, 3, 1], then quantized.
    x_float = torch.randn(batch, in_ch, height, width, dtype=torch.float32)
    x_float = x_float.permute([0, 2, 3, 1]).contiguous()
    qX = torch.quantize_linear(x_float, scale=scale, zero_point=128,
                               dtype=torch.quint8)

    w_float = torch.randn(out_ch, in_ch // groups, k_h, k_w,
                          dtype=torch.float32)
    qw = torch.quantize_linear(w_float, scale=scale, zero_point=0,
                               dtype=torch.qint8)

    if use_bias:
        b_float = torch.randn(out_ch, dtype=torch.float32)
        qb = torch.quantize_linear(b_float, scale=1.0 / 1024, zero_point=0,
                                   dtype=torch.qint32)
    else:
        qb = None

    conv_kwargs = dict(in_channels=in_ch,
                       out_channels=out_ch,
                       kernel_size=(k_h, k_w),
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=groups,
                       bias=use_bias,
                       padding_mode='zeros')
    module_cls = ConvReLU2d if use_fused else Conv2d
    conv_under_test = module_cls(**conv_kwargs)

    conv_under_test.weight = qw
    conv_under_test.bias = qb
    # Output quantization parameters are assigned as one-element tensors.
    conv_under_test.scale = torch.tensor([scale], dtype=torch.double)
    conv_under_test.zero_point = torch.tensor([zero_point], dtype=torch.long)

    # Test members
    for member in ('_packed_weight', 'scale', 'zero_point'):
        self.assertTrue(hasattr(conv_under_test, member))

    # Test properties
    self.assertEqual(qw, conv_under_test.weight)
    self.assertEqual(qb, conv_under_test.bias)
    self.assertEqual(scale, conv_under_test.scale)
    self.assertEqual(zero_point, conv_under_test.zero_point)

    # Test forward
    result_under_test = conv_under_test(qX)
    result_reference = qF.conv2d(qX, qw, bias=qb,
                                 scale=scale, zero_point=zero_point,
                                 stride=1, padding=0,
                                 dilation=1, groups=groups,
                                 dtype=torch.quint8)
    if use_fused:
        # Apply ReLU to the reference by hand: comparisons such as
        # `result_reference < zero_point` do not work for quantized tensors
        # yet, so every element is inspected individually.
        n_batch, n_chan, n_rows, n_cols = result_reference.size()
        for n in range(n_batch):
            for c in range(n_chan):
                for row in range(n_rows):
                    for col in range(n_cols):
                        below = (result_reference[n][c][row][col].int_repr()
                                 < zero_point)
                        if below:
                            # Assigning 0. gets converted to zero_point.
                            result_reference[n][c][row][col] = 0.

    self.assertEqual(result_reference, result_under_test,
                     message="Tensors are not equal.")
def test_conv_api(self, X, padH, padW, sH, sW, dH, dW):
    """Tests the correctness of the conv functional.

    The correctness is defined by the behavior being similar to the
    `quantized._ops` implementation: either both raise a ``RuntimeError``
    with the same leading message, or both produce equal results.

    Args:
        X: tuple ``(inputs, filters, bias, groups)`` where each of the
            first three is ``(numpy_array, (scale, zero_point, qtype))``.
        padH, padW: padding along height / width.
        sH, sW: stride along height / width.
        dH, dW: dilation along height / width.
    """
    # Local import so this block does not depend on the file's import list.
    import re

    # Random inputs
    (inputs, filters, bias, groups) = X
    inputs, (inputs_scale, inputs_zero_point, inputs_qtype) = inputs
    filters, (filters_scale, filters_zero_point, filters_qtype) = filters
    bias, (bias_scale, bias_zero_point, bias_qtype) = bias

    # The output reuses the input's quantization parameters and dtype.
    scale, zero_point = inputs_scale, inputs_zero_point
    torch_type = inputs_qtype

    iH, iW = inputs.shape[2:]
    kH, kW = filters.shape[2:]
    # Discard generated examples whose padding exceeds half the kernel or
    # whose output would be empty.
    assume(kH // 2 >= padH)
    assume(kW // 2 >= padW)
    oH = _conv_output_shape(iH, kH, padH, sH, dH)
    assume(oH > 0)
    oW = _conv_output_shape(iW, kW, padW, sW, dW)
    assume(oW > 0)

    inputs = torch.from_numpy(inputs).to(torch.float)
    filters = torch.from_numpy(filters).to(torch.float)
    bias = torch.from_numpy(bias).to(torch.float)

    stride = (sH, sW)
    i_padding = (padH, padW)
    dilation = (dH, dW)

    # Quantized inputs
    q_inputs = torch.quantize_linear(inputs, inputs_scale,
                                     inputs_zero_point, inputs_qtype)
    q_filters = torch.quantize_linear(filters, filters_scale,
                                      filters_zero_point, filters_qtype)
    # The reference op is fed tensors permuted with [0, 2, 3, 1]; its
    # result is permuted back with [0, 3, 1, 2] before comparison.
    q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(
        q_filters.permute([0, 2, 3, 1]), groups)
    q_bias = torch.quantize_linear(bias, bias_scale, bias_zero_point,
                                   bias_qtype)

    # Reference op
    ref_op = torch.ops.quantized.fbgemm_conv2d

    # Results check
    try:
        ref_result = ref_op(q_inputs.permute([0, 2, 3, 1]), q_filters_ref,
                            q_bias, stride, i_padding, dilation, groups,
                            scale, zero_point).permute([0, 3, 1, 2])
    except RuntimeError as e:
        # Keep only the leading part of the first line of the message.
        e_msg = str(e).split("\n")[0].split("(")[0].strip()
        # The message is matched as a regular expression, so escape it:
        # otherwise characters such as '.', '[' or '+' would be treated as
        # regex syntax instead of literal text.
        np.testing.assert_raises_regex(
            type(e), re.escape(e_msg), qF.conv2d,
            q_inputs, q_filters, bias=q_bias,
            scale=scale, zero_point=zero_point,
            stride=stride, padding=i_padding, dilation=dilation,
            groups=groups, dtype=torch_type)
    else:
        q_result = qF.conv2d(q_inputs, q_filters, bias=q_bias,
                             scale=scale, zero_point=zero_point,
                             stride=stride, padding=i_padding,
                             dilation=dilation, groups=groups,
                             dtype=torch_type)

        self.assertEqual(ref_result, q_result)