def test_quantize_per_channel_sub_byte(self):
    """ Tests the per channel quantization scheme for 4-bit qtensors.
    The scale and zero point for this have to be in floating point. """
    r = torch.rand(3, 2, dtype=torch.float) * 4
    scales = torch.tensor([0.2, 0.3, 0.1], dtype=torch.float)
    zero_points = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
    qr = torch.quantize_per_channel(r, scales, zero_points, 0, torch.quint4x2)
    dequant_tensor = qr.dequantize()

    def _get_qranges(bit_width):
        if bit_width == 4:
            return 0, 15

    def _quantize_per_channel_sub_byte_ref(data, scales, zero_points, axis, bit_width):
        dims = data.size()
        data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
        qtensor_size = math.ceil(data.numel() / 2)
        res = torch.empty(qtensor_size, dtype=torch.uint8)
        elem_per_byte = 8 // bit_width
        quant_min, quant_max = _get_qranges(bit_width)
        for i in range(data.size()[0]):
            for j in range(data.size()[1]):
                for k in range(data.size()[2]):
                    inv_scale = 1.0 / scales[j]
                    index = i * data.size()[1] * data.size()[2] + j * data.size()[2] + k
                    qvalue = np.clip(
                        np.round(data[i][j][k] * inv_scale + zero_points[j]),
                        quant_min, quant_max).to(dtype=torch.int)
                    res_idx = int(index / elem_per_byte)
                    if (index % elem_per_byte == 0):
                        res[res_idx] = qvalue
                    else:
                        res[res_idx] |= (qvalue << ((index % elem_per_byte) * bit_width))
        return res

    ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 0, 4)
    self.assertTrue(np.allclose(qr.int_repr(), ref_res))
    self.assertTrue(np.allclose(r.numpy(), dequant_tensor.numpy(),
                                atol=1 / np.min(scales.numpy())))

    # Check 4D tensor with non-zero axis.
    r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
    scales = torch.tensor([0.2, 0.03], dtype=torch.float)
    zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
    qr = torch.quantize_per_channel(r, scales, zero_points, axis=1, dtype=torch.quint4x2)
    ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 1, 4)
    self.assertTrue(np.allclose(qr.int_repr(), ref_res))
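# A minimal standalone sketch (values assumed, not part of the test above) of the
# packing that the reference implementation above computes for torch.quint4x2:
# two 4-bit codes share one uint8 of storage. With scale 0.1 and a float zero
# point of 0.0, the values 0.4 and 0.8 map to the codes 4 and 8, which pack into
# a single byte as 4 | (8 << 4) == 132.
r2 = torch.tensor([[0.4, 0.8]], dtype=torch.float)
q2 = torch.quantize_per_channel(
    r2, torch.tensor([0.1]), torch.tensor([0.0]), 0, torch.quint4x2)
# q2.int_repr() exposes the packed uint8 storage that the test above compares
# against its reference result.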
def test_qtensor_unsqueeze(self):
    x = torch.randn((1, 3, 4))
    qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
    qy = qx.unsqueeze(2)
    self.assertEqual(qy.size(), (1, 3, 1, 4))
    qy = qy.squeeze(2)
    self.assertEqual(qy.size(), qx.size())

    # Per channel qtensor
    scales = torch.tensor([1.0])
    zero_points = torch.tensor([0])
    qx = torch.quantize_per_channel(x, scales=scales, zero_points=zero_points,
                                    dtype=torch.quint8, axis=0)
    qy = qx.unsqueeze(0)
    self.assertEqual(qy.size(), (1, 1, 3, 4))
    self.assertEqual(qy.q_per_channel_axis(), 1)

    qz = qy.squeeze(0)
    self.assertEqual(qz.size(), x.size())
    self.assertEqual(qz.q_per_channel_axis(), 0)
    with self.assertRaisesRegex(RuntimeError,
                                "Squeeze is only possible on non-axis dimension for Per-Channel"):
        qz = qy.squeeze(1)

    # squeeze without dim specified
    x = torch.randn((3, 1, 2, 1, 4))
    scales = torch.tensor([1.0, 1.0])
    zero_points = torch.tensor([0, 0])
    qx = torch.quantize_per_channel(x, scales=scales, zero_points=zero_points,
                                    dtype=torch.quint8, axis=2)
    qz = qx.squeeze()
    self.assertEqual(qz.size(), (3, 2, 4))
    self.assertEqual(qz.q_per_channel_axis(), 1)
    with self.assertRaisesRegex(RuntimeError,
                                "Squeeze is only possible on non-axis dimension for Per-Channel"):
        qz = qy.squeeze()
def _quantize_weight(weight: torch.Tensor, weight_qscheme: torch.qscheme,
                     weight_dtype: torch.dtype, weight_scale: torch.Tensor,
                     weight_zero_point: torch.Tensor, weight_axis: torch.Tensor):
    if weight_dtype == torch.float16:
        weight = weight.to(weight_dtype)
        return weight

    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
            return weight
    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
        if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
            weight = torch.quantize_per_channel(
                weight, weight_scale, weight_zero_point,
                weight_axis.item(), weight_dtype)  # type: ignore[arg-type]
            return weight
    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
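# Hedged usage sketch (all values assumed, not part of the helper above):
# per-channel affine qint8 quantization along dim 0, with the channel axis
# passed as a tensor because the helper calls .item() on it.
w = torch.randn(4, 8)
qw = _quantize_weight(w, torch.per_channel_affine, torch.qint8,
                      torch.full((4,), 0.05), torch.zeros(4, dtype=torch.int64),
                      torch.tensor(0))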
def test_qtensor_quantize_per_channel(self):
    r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
    scales = torch.tensor([0.2, 0.03], dtype=torch.double)
    zero_points = torch.tensor([5, 10], dtype=torch.long)
    axis = 1

    def quantize_c(data, scales, zero_points):
        res = torch.empty((3, 2))
        quant_min, quant_max = 0, 255
        for i in range(3):
            for j in range(2):
                res[i][j] = np.clip(
                    np.round(data[i][j] / scales[j]) + zero_points[j],
                    quant_min, quant_max)
        return res

    qr = torch.quantize_per_channel(r, scales, zero_points, axis, torch.quint8)
    rqr = qr.dequantize()
    self.assertTrue(np.allclose(qr.int_repr(), quantize_c(r, scales, zero_points)))
    self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy())))
def _test_quantize_per_channel(self, r, scales, zero_points, axis, float_params):

    def _quantize_per_channel_ref_nd(data, scales, zero_points, float_params):
        dims = data.size()
        data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
        res = torch.empty_like(data)
        quant_min, quant_max = 0, 255
        for i in range(res.size()[0]):
            for j in range(res.size()[1]):
                for k in range(res.size()[2]):
                    if float_params:
                        inv_scale = 1.0 / scales[j]
                        res[i][j][k] = np.clip(
                            np.round(data[i][j][k] * inv_scale + zero_points[j]),
                            quant_min, quant_max)
                    else:
                        res[i][j][k] = np.clip(
                            np.round(data[i][j][k] / scales[j]) + zero_points[j],
                            quant_min, quant_max)
        res = res.view(*dims)
        return res

    contig_format = torch.channels_last if r.ndim == 4 else torch.channels_last_3d
    for memory_format in [torch.contiguous_format, contig_format]:
        ref_res = _quantize_per_channel_ref_nd(r, scales, zero_points, float_params)
        r_contig = r.contiguous(memory_format=memory_format)
        qr = torch.quantize_per_channel(r_contig, scales, zero_points, axis, torch.quint8)
        rqr = qr.dequantize()
        self.assertTrue(np.allclose(qr.int_repr(), ref_res))
        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy())))
def _test_pickle_checkpoint_qtensor(self, device):
    with TemporaryFileName() as fname:

        class M(torch.jit.ScriptModule):
            __constants__ = ['fname']

            def __init__(self):
                super(M, self).__init__()
                self.fname = fname

            @torch.jit.script_method
            def forward(self, x, y):
                torch.save((x, y), self.fname)
                return y

        q = torch.quantize_per_tensor(torch.rand(2, 3, dtype=torch.float),
                                      scale=0.1, zero_point=10,
                                      dtype=torch.quint8).to(device)
        qc = torch.quantize_per_channel(
            torch.rand(2, 3, dtype=torch.float),
            scales=torch.tensor([0.1, 0.5, 0.01]),
            zero_points=torch.tensor([10, 0, 20]),
            axis=1, dtype=torch.quint8).to(device)
        m = M()
        m(q, qc)
        with open(fname, "rb") as handle:
            loaded_q, loaded_qc = torch.load(fname)
            self.assertEqual(loaded_q, q)
            self.assertEqual(loaded_qc, qc)
def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig):
    num_lengths = np.random.randint(1, 6)
    lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
    num_indices = np.sum(lengths)
    indices = torch.from_numpy(np.random.randint(
        low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
    weights = torch.from_numpy(
        (np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

    obs = default_float_qparams_observer()
    obs(weights)
    qparams = obs.calculate_qparams()

    # Quantize the weights to 8bits
    qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1],
                                         axis=0, dtype=torch.quint8)
    qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    qemb.set_weight(qweight)
    qemb(indices)

    # Ensure the module has the correct weights
    self.assertEqual(qweight, qemb.weight())
    w_packed = qemb._packed_params._packed_weight
    module_out = qemb(indices)

    # Call the qembedding operator directly
    ref = torch.ops.quantized.embedding_byte(w_packed, indices, sparse=False)
    self.assertEqual(module_out, ref)

    self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices,
                                     None, set_qconfig=False, is_emb_bag=False)
def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
    r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8
    """
    num_lengths = np.random.randint(1, 6)
    lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
    num_indices = np.sum(lengths)
    indices = torch.from_numpy(np.random.randint(
        low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

    offsets = lengths_to_offsets(lengths)
    # include the last offset
    offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
    weights = torch.from_numpy(
        (np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

    obs = default_float_qparams_observer()
    obs(weights)
    # Get the scale and zero point for the weight tensor
    qparams = obs.calculate_qparams()

    # Quantize the weights to 8bits
    qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1],
                                         axis=0, dtype=torch.quint8)
    qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                            include_last_offset=True, mode='sum', _weight=qweight)
    qemb(indices, offsets)

    # Ensure the module has the correct weights
    self.assertEqual(qweight, qemb.weight())

    w_packed = qemb._packed_params._packed_weight
    module_out = qemb(indices, offsets)

    # Call the qembedding_bag operator directly
    ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
                                                 per_sample_weights=None,
                                                 include_last_offset=True)
    self.assertEqual(module_out, ref)

    self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices,
                                     offsets, set_qconfig, is_emb_bag=True)
def test_conv_api(self, use_bias, per_channel):
    """Tests the correctness of the conv module.
    The correctness is defined against the functional implementation.
    """
    N, iC, H, W = 10, 10, 10, 3
    oC, g, kH, kW = 16, 1, 3, 3
    scale, zero_point = 1.0 / 255, 128
    stride = (1, 1)
    i_padding = (0, 0)
    dilation = (1, 1)

    X = torch.randn(N, iC, H, W, dtype=torch.float32)
    qX = torch.quantize_per_tensor(X, scale=scale, zero_point=128, dtype=torch.quint8)

    w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

    if per_channel:
        scale_tensor = torch.ones(oC, dtype=torch.double)
        zero_point_tensor = torch.zeros(oC, dtype=torch.long)
        for i in range(len(scale_tensor)):
            scale_tensor[i] = (i + 1.0) / 255.0
        qw = torch.quantize_per_channel(w, scales=scale_tensor,
                                        zero_points=zero_point_tensor,
                                        axis=0, dtype=torch.qint8)
    else:
        qw = torch.quantize_per_tensor(w, scale=scale, zero_point=0, dtype=torch.qint8)

    b = torch.randn(oC, dtype=torch.float32) if use_bias else None

    q_filters_ref = torch.ops.quantized.conv_prepack(qw, b, stride, i_padding, dilation, g)
    ref_result = torch.ops.quantized.conv2d(qX, q_filters_ref, stride, i_padding,
                                            dilation, g, scale, zero_point)

    q_result = torch.nn.quantized.functional.conv2d(
        qX, qw, bias=b, scale=scale, zero_point=zero_point,
        stride=stride, padding=i_padding, dilation=dilation, groups=g,
        dtype=torch.quint8)

    self.assertEqual(ref_result, q_result)
def test_qparams_conversion(tensor, num_bits, distiller_mode, torch_dtype, per_channel, reduce_range):
    if reduce_range:
        if num_bits != 8:
            return True
        if quantization.is_linear_quant_mode_symmetric(distiller_mode) and torch_dtype == torch.quint8:
            return True

    # Calculate quantization parameters with Distiller for the number of bits BEFORE reduce_range
    signed = distiller_mode != quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED
    distiller_scale, distiller_zp = _get_quant_params_from_tensor(
        tensor, num_bits, distiller_mode, per_channel=per_channel)

    # Convert parameters to PyTorch
    converted_scale, converted_zp = quantization.distiller_qparams_to_pytorch(
        distiller_scale, distiller_zp, num_bits, distiller_mode, torch_dtype, reduce_range)

    # Quantize tensor with Distiller
    # If reduce_range is set, then we actually quantize with num_bits-1
    if reduce_range:
        num_bits -= 1
        distiller_scale, distiller_zp = _get_quant_params_from_tensor(
            tensor, num_bits, distiller_mode, per_channel=per_channel)
    restrict = distiller_mode == quantization.LinearQuantMode.SYMMETRIC_RESTRICTED
    clamp_min, clamp_max = quantization.get_quantized_range(
        num_bits, signed=signed, signed_restrict_qrange=restrict)
    distiller_q_t = quantization.linear_quantize_clamp(
        tensor, distiller_scale, distiller_zp, clamp_min, clamp_max)

    # Quantize with PyTorch
    if per_channel:
        pytorch_q_t = torch.quantize_per_channel(tensor, converted_scale, converted_zp, 0, torch_dtype)
    else:
        pytorch_q_t = torch.quantize_per_tensor(tensor, converted_scale, converted_zp, torch_dtype)

    # Dequantize
    distiller_q_dq_t = quantization.linear_dequantize(distiller_q_t, distiller_scale, distiller_zp)
    pytorch_q_dq_t = pytorch_q_t.dequantize()

    # Compare - allow a difference of up to one quantized "bin" between the tensors
    if per_channel:
        for idx, scale in enumerate(converted_scale):
            torch.testing.assert_allclose(distiller_q_dq_t[idx], pytorch_q_dq_t[idx],
                                          atol=scale, rtol=1e-05)
    else:
        torch.testing.assert_allclose(pytorch_q_dq_t, distiller_q_dq_t,
                                      atol=converted_scale, rtol=1e-05)
def _quantize_weight(float_wt, observer):
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.qint8)
    elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
        wt_axis = observer.ch_axis
        qweight = torch.quantize_per_channel(
            float_wt, wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, torch.qint8)
    elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
        qweight = torch.quantize_per_channel(
            float_wt, wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, torch.quint8)
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    return qweight
def __init__(self, per_channel):
    super(SimpleQTensor, self).__init__()
    x = torch.rand(5, 5).float()
    if not per_channel:
        x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8)
    else:
        s = torch.rand(5, dtype=torch.float64) + 0.1
        zp = torch.randint(5, 15, (5,))
        x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
    self.register_buffer('x', x_q)
def _make_conv_test_input(
    batch_size, in_channels_per_group, input_feature_map_size,
    out_channels_per_group, groups, kernel_size, X_scale, X_zero_point,
    W_scale, W_zero_point, use_bias, use_channelwise,
):
    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups

    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min, X_value_max,
        (batch_size, in_channels,) + input_feature_map_size)
    X = X_scale * (X_init - X_zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)

    W_scale = W_scale * out_channels
    W_zero_point = W_zero_point * out_channels
    # Resize the W_scale and W_zero_point lists to out_channels elements
    W_scale = W_scale[:out_channels]
    W_zero_point = W_zero_point[:out_channels]
    # For testing, we use small values for weights and for activations so that
    # no overflow occurs in the vpmaddubsw instruction. If the overflow occurs
    # in the qconv implementation but not in the reference, we can't exactly
    # match the results with the reference. Please see the comment in the qconv
    # implementation file aten/src/ATen/native/quantized/cpu/qconv.cpp for more
    # details.
    (W_value_min, W_value_max) = (-5, 5)
    # The operator expects them in the format
    # (out_channels, in_channels/groups,) + kernel_size
    W_init = torch.randint(
        W_value_min, W_value_max,
        (out_channels, in_channels_per_group,) + kernel_size)
    b_init = torch.randint(0, 10, (out_channels,))

    if use_channelwise:
        W_shape = (-1, 1) + (1,) * len(kernel_size)
        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
        W = W_scales_tensor.reshape(*W_shape) * (
            W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
        b = X_scale * W_scales_tensor * b_init.float()
        W_q = torch.quantize_per_channel(
            W, W_scales_tensor, W_zero_points_tensor.long(), 0, dtype=torch.qint8)
    else:
        W = W_scale[0] * (W_init - W_zero_point[0]).float()
        b = X_scale * W_scale[0] * b_init.float()
        W_q = torch.quantize_per_tensor(
            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8)

    return (X, X_q, W, W_q, b if use_bias else None)
def test_qtensor_per_channel_load_save(self):
    r = torch.rand(20, 10, dtype=torch.float) * 4 - 2
    scales = torch.rand(10) * 0.02 + 0.01
    zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long)
    # quint32 is not supported yet
    for dtype in [torch.quint8, torch.qint8]:
        qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
        with tempfile.NamedTemporaryFile() as f:
            # Serializing and Deserializing Tensor
            torch.save(qr, f)
            f.seek(0)
            qr2 = torch.load(f)
            self.assertEqual(qr, qr2)
def _quantize_weight(float_wt, observer):
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch.quantize_per_tensor(float_wt, float(wt_scale), int(wt_zp), torch.qint8)
    else:
        qweight = torch.quantize_per_channel(float_wt, wt_scale.to(torch.double),
                                             wt_zp.to(torch.int64), 0, torch.qint8)
    return qweight
def _test_numerical_consistency(self, test_type):
    r"""Comparing numerical consistency between quantize/dequantize op and the fake quantize op across devices and dtypes
    """
    torch.random.manual_seed(NP_RANDOM_SEED)
    torch_types = [torch.qint8, torch.quint8]
    float_types = [torch.float, torch.float16, torch.float64]
    zero_types = [torch.long]
    devices = [torch.device('cpu'), torch.device('cuda')] if torch.cuda.is_available() else [torch.device('cpu')]
    axis = 1
    for i in range(20):
        for torch_type, float_type, device, zero_type in itertools.product(
                torch_types, float_types, devices, zero_types):
            X = torch.randn(3, 3, device=device).to(float_type)
            scales = (10 * torch.randn(3, device=device)).abs()
            scale = scales.mean().to(float).item()
            zeros = (10 * torch.randn(3, device=device)).abs().to(dtype=zero_type)
            zero = zeros.max().view(1).item()
            quant_min = torch.iinfo(torch_type).min
            quant_max = torch.iinfo(torch_type).max

            test_was_run = False
            if test_type == "per_tensor":
                test_was_run = True
                Y = torch.dequantize(
                    torch.quantize_per_tensor(
                        X.to('cpu').to(torch.float), scale, zero,
                        torch_type)).to(device).to(float_type)
                Y_prime = torch.fake_quantize_per_tensor_affine(
                    X, scale, zero, quant_min, quant_max)
                self.assertEqual(
                    Y, Y_prime,
                    "Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor")

            if test_type == "per_channel":
                test_was_run = True
                Y = torch.dequantize(
                    torch.quantize_per_channel(
                        X.to('cpu').to(torch.float), scales.to('cpu'),
                        zeros.to('cpu'), axis, torch_type)).to(device).to(float_type)
                Y_prime = torch.fake_quantize_per_channel_affine(
                    X, scales, zeros, axis, quant_min, quant_max)
                self.assertEqual(
                    Y, Y_prime,
                    "Difference found between dequant+quant_per_channel and fake_quantize_per_channel")
            self.assertTrue(test_was_run)
def tensor_creation_ops(self):
    i = torch.tensor([[0, 1, 1], [2, 0, 2]])
    v = torch.tensor([3, 4, 5], dtype=torch.float32)
    real = torch.tensor([1, 2], dtype=torch.float32)
    imag = torch.tensor([3, 4], dtype=torch.float32)
    inp = torch.tensor([-1.5, 0.0, 2.0])
    values = torch.tensor([0.5])
    quantized = torch.quantize_per_channel(
        torch.tensor([[-1.0, 0.0], [1.0, 2.0]]),
        torch.tensor([0.1, 0.01]),
        torch.tensor([10, 0]),
        0,
        torch.quint8,
    )
    return (
        torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]),
        # torch.sparse_coo_tensor(i, v, [2, 3]),  # not work for iOS
        torch.as_tensor([1, 2, 3]),
        torch.as_strided(torch.randn(3, 3), (2, 2), (1, 2)),
        torch.zeros(2, 3),
        torch.zeros((2, 3)),
        torch.zeros([2, 3], out=i),
        torch.zeros(5),
        torch.zeros_like(torch.empty(2, 3)),
        torch.ones(2, 3),
        torch.ones((2, 3)),
        torch.ones([2, 3]),
        torch.ones(5),
        torch.ones_like(torch.empty(2, 3)),
        torch.arange(5),
        torch.arange(1, 4),
        torch.arange(1, 2.5, 0.5),
        torch.range(1, 4),
        torch.range(1, 4, 0.5),
        torch.linspace(3.0, 3.0, steps=1),
        torch.logspace(start=2, end=2, steps=1, base=2.0),
        torch.eye(3),
        torch.empty(2, 3),
        torch.empty_like(torch.empty(2, 3), dtype=torch.int64),
        torch.empty_strided((2, 3), (1, 2)),
        torch.full((2, 3), 3.141592),
        torch.full_like(torch.full((2, 3), 3.141592), 2.71828),
        torch.quantize_per_tensor(
            torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8
        ),
        torch.dequantize(quantized),
        torch.complex(real, imag),
        torch.polar(real, imag),
        torch.heaviside(inp, values),
    )
def test_numerical_consistency_per_channel(self, device, X):
    r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
    """
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, axis, torch_type) = X
    quant_min = torch.iinfo(torch_type).min
    quant_max = torch.iinfo(torch_type).max

    X = to_tensor(X, device)
    scale = to_tensor(scale, device)
    zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
    # quantize_linear and dequantize are only implemented in CPU
    Y = torch.dequantize(torch.quantize_per_channel(X.cpu(), scale.cpu(), zero_point.cpu(),
                                                    axis, torch_type))
    Y_prime = torch.fake_quantize_per_channel_affine(
        X, scale, zero_point, axis, quant_min, quant_max)
    np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
def _quantize_weight(weight: torch.Tensor, weight_qscheme: torch.qscheme,
                     weight_dtype: torch.dtype, weight_scale: torch.Tensor,
                     weight_zero_point: torch.Tensor, weight_axis: torch.Tensor):
    if weight_qscheme == torch.per_tensor_affine:
        weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
        weight = torch.quantize_per_channel(
            weight, weight_scale, weight_zero_point,
            weight_axis.item(), weight_dtype)  # type: ignore[arg-type]
    else:
        raise Exception(f"Unsupported qscheme: {weight_qscheme}")
    return weight
def _quantize_weight(float_wt, observer):
    if observer is None:
        # allow dummy observer that leads to as-is quantization
        return float_wt
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), observer.dtype)
    elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
        wt_axis = observer.ch_axis
        qweight = torch.quantize_per_channel(
            float_wt, wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, observer.dtype)
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    qweight = qweight.dequantize()
    return qweight
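# Hedged usage sketch (names and shapes assumed): the observer-based helpers above
# expect an observer that has already seen the weight, so calculate_qparams() can
# return per-channel scales and zero points.
obs = torch.quantization.PerChannelMinMaxObserver(
    ch_axis=0, dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
w = torch.randn(8, 4)
obs(w)                         # record per-channel min/max
qw = _quantize_weight(w, obs)  # quantize (and, in this variant, dequantize) with those qparams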
def test_quantize_per_channel_float_qparams(self):
    r = torch.rand(3, 2, dtype=torch.float) * 4
    scales = torch.tensor([0.2, 0.03], dtype=torch.float)
    zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
    axis = 1

    # Reference quantize function with FP zero_point.
    def quantize_ref(data, scales, zero_points):
        res = torch.empty((3, 2))
        quant_min, quant_max = 0, 255
        for i in range(3):
            for j in range(2):
                inv_scale = 1.0 / scales[j]
                res[i][j] = np.clip(
                    np.round(data[i][j] * inv_scale + zero_points[j]),
                    quant_min, quant_max)
        return res

    qr = torch.quantize_per_channel(r, scales, zero_points, axis, torch.quint8)
    dequant_tensor = qr.dequantize()
    ref = quantize_ref(r, scales, zero_points)
    self.assertTrue(np.allclose(qr.int_repr(), ref))
    self.assertTrue(np.allclose(r.numpy(), dequant_tensor.numpy(), atol=1))

    # Check 4D tensor with 2 different memory formats.
    r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
    scales = torch.tensor([0.2, 0.03], dtype=torch.float)
    zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
    self._test_quantize_per_channel(r, scales, zero_points, 1, True)

    scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
    zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
    self._test_quantize_per_channel(r, scales, zero_points, 0, True)

    # Check 5D tensor.
    r = torch.rand(3, 2, 4, 5, 7, dtype=torch.float) * 4 - 2
    scales = torch.tensor([0.2, 0.03], dtype=torch.float)
    zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
    self._test_quantize_per_channel(r, scales, zero_points, 1, True)

    scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
    zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
    self._test_quantize_per_channel(r, scales, zero_points, 0, True)
def test_qtensor_per_channel_permute(self):
    r = torch.rand(20, 10, 2, 2, dtype=torch.float) * 4 - 2
    scales = torch.rand(10) * 0.02 + 0.01
    zero_points = torch.round(torch.rand(10) * 2 - 1).to(torch.long)
    qr = torch.quantize_per_channel(r, scales, zero_points, 1, torch.qint8)

    # we can't reorder the axis
    with self.assertRaises(RuntimeError):
        qr.transpose(0, 1)

    # but we can change memory format
    qlast = qr.contiguous(memory_format=torch.channels_last)
    self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride()))))
    self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride()))))
    self.assertEqual(qr.int_repr(), qlast.int_repr())
    self.assertEqual(scales, qlast.q_per_channel_scales())
    self.assertEqual(zero_points, qlast.q_per_channel_zero_points())
    self.assertEqual(1, qlast.q_per_channel_axis())
    self.assertEqual(qlast.dequantize(), qr.dequantize())
def minmax_symmetric_quantize(weight, min_vals, max_vals):
    """
    Mimic pytorch's _ObserverBase.per_channel_symmetric quantization
    """
    qmax = 127
    qmin = -128
    zero_points = torch.zeros(min_vals.size(), dtype=torch.int64)
    if torch.equal(max_vals, min_vals):
        scales = torch.ones(min_vals.size(), dtype=torch.float)
    else:
        max_vals = torch.max(-min_vals, max_vals)
        scales = max_vals / ((qmax - qmin) / 2)
        scales = torch.max(scales, torch.tensor([1e-8], device=scales.device, dtype=scales.dtype))
    return torch.quantize_per_channel(weight.data.cpu(), scales.cpu(), zero_points,
                                      axis=0, dtype=torch.qint8)
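# Hedged usage sketch (shapes assumed): compute per-output-channel min/max over a
# 2-D weight, then apply the symmetric per-channel quantization defined above.
w = torch.randn(16, 32)
min_vals = w.min(dim=1).values
max_vals = w.max(dim=1).values
qw = minmax_symmetric_quantize(w, min_vals, max_vals)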
def _quantize_and_dequantize_weight(weight: torch.Tensor, weight_qscheme: torch.qscheme,
                                    weight_dtype: torch.dtype, weight_scale: torch.Tensor,
                                    weight_zero_point: torch.Tensor, weight_axis: torch.Tensor):
    """ Quantize and then dequantize the weight based on
    the quantization parameters
    """
    if weight_qscheme == torch.per_tensor_affine:
        weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
        weight_dequant = weight.dequantize()
    elif weight_qscheme == torch.per_channel_affine:
        weight = torch.quantize_per_channel(
            weight, weight_scale, weight_zero_point,
            weight_axis.item(), weight_dtype)  # type: ignore[arg-type]
        weight_dequant = weight.dequantize()
    else:
        weight_dequant = weight
    return weight_dequant
def regular_serialization():
    test_cases = {}
    for dtype, device in itertools.product(all_dtypes, all_devices):
        base_name = f'regular_serialization_{dtype_name(dtype)}_{device}'
        test_cases[f'{base_name}_0'] = [
            make_tensor((3, 5), device=device, dtype=dtype, low=-9, high=9)
        ]

        a = make_tensor((15, 5, 5), device=device, dtype=dtype, low=-9, high=9)
        test_cases[f'{base_name}_1'] = [
            get_storage(a),
            a.view((5, 3, 25)),
            a,
            a[1:],
        ]

        if dtype.is_floating_point or dtype.is_complex:
            m = torch.nn.Linear(50, 10, dtype=dtype, device=device)
            test_cases[f'{base_name}_module_0'] = [m]

        # Quantization
        if dtype == torch.float and device == 'cpu':
            for qdtype in [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]:
                a = make_tensor((10, 3, 8, 2, 4), device=device, dtype=dtype, low=-9, high=9)
                q = torch.quantize_per_tensor(a, 1.0, 2, qdtype)
                test_cases[f'{base_name}_quant_0_{dtype_name(qdtype)}'] = [q]
                test_cases[f'{base_name}_quant_1_{dtype_name(qdtype)}'] = [a, q]

                # TODO: For some reason, qint32 throws an illegal instruction
                # error, for both master and local branch. Either I'm doing
                # something wrong or it's an actual problem. Either way,
                # I should file an issue
                if qdtype == torch.qint32:
                    continue

                a = make_tensor((10, 3, 8, 2, 4), device=device, dtype=dtype, low=-9, high=9)
                scales = make_tensor((8,), device=device, dtype=dtype, low=-9, high=9)
                zero_points = make_tensor((8,), device=device, dtype=dtype, low=-9, high=9)
                q = torch.quantize_per_channel(a, scales, zero_points, 2, qdtype)
                test_cases[f'{base_name}_quant_channel_0_{dtype_name(qdtype)}'] = [q]
                test_cases[f'{base_name}_quant_channel_1_{dtype_name(qdtype)}'] = [a, q]

    # TODO: test sparse COO
    # TODO: test packaging
    return test_cases
def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                    use_fused, per_channel, qengine):
    """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu"""
    if qengine not in torch.backends.quantized.supported_engines:
        return
    if qengine == 'qnnpack':
        if IS_PPC or TEST_WITH_UBSAN:
            return
        per_channel = False

    with override_quantized_engine(qengine):
        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        if use_fused:
            qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            qlinear = nnq.Linear(in_features, out_features)

        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q, atol=1e-5)

        W_pack = qlinear._packed_params._packed_params

        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)
        # Check if the module implementation matches calling the
        # ops directly
        if use_fused:
            Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
            self.assertTrue('QuantizedLinearReLU' in str(qlinear))
        else:
            Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)
            self.assertTrue('QuantizedLinear' in str(qlinear))
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['_packed_params.weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['_packed_params.bias'], B)
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        if use_fused:
            loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            loaded_qlinear = nnq.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                         linear_unpack(loaded_qlinear._packed_params._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(),
                         torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on save.
        # <start code>
        # b = io.BytesIO()
        # torch.save(qlinear, b)
        # b.seek(0)
        # loaded = torch.load(b)
        # self.assertEqual(qlinear.weight(), loaded.weight())
        # self.assertEqual(qlinear.scale, loaded.scale)
        # self.assertEqual(qlinear.zero_point, loaded.zero_point)
        # <end code>
        with self.assertRaisesRegex(RuntimeError, r'torch.save\(\) is not currently supported'):
            b = io.BytesIO()
            torch.save(qlinear, b)

        # Test JIT
        self.checkScriptable(qlinear, list(zip([X_q], [Z_ref])), check_save_load=True)

        # Test from_float.
        float_linear = torch.nn.Linear(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear, inplace=True)
        float_linear(X.float())
        # Sequential allows swapping using "convert".
        quantized_float_linear = torch.nn.Sequential(float_linear)
        quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                    use_fused, per_channel):
    """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu"""
    if torch.backends.quantized.engine == 'qnnpack':
        per_channel = False
    W = torch.rand(out_features, in_features).float()
    if per_channel:
        scale_tensor = torch.ones(out_features, dtype=torch.double)
        zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
        for i in range(len(scale_tensor)):
            scale_tensor[i] = (i + 1.0) / 255.0
        W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                         zero_points=zero_point_tensor,
                                         axis=0, dtype=torch.qint8)
    else:
        W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

    X = torch.rand(batch_size, in_features).float()
    X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
    B = torch.rand(out_features).float() if use_bias else None
    scale = 0.5
    zero_point = 3
    if use_fused:
        qlinear = nnq_fused.LinearReLU(in_features, out_features)
    else:
        qlinear = nnq.Linear(in_features, out_features)

    # Run module with default-initialized parameters.
    # This tests that the constructor is correct.
    qlinear(X_q)

    qlinear.set_weight_bias(W_q, B)
    # Simple round-trip test to ensure weight()/set_weight() API
    self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)

    W_pack = qlinear._packed_params._packed_params
    qlinear.scale = float(scale)
    qlinear.zero_point = int(zero_point)
    Z_q = qlinear(X_q)
    # Check if the module implementation matches calling the
    # ops directly
    if use_fused:
        Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
        self.assertTrue('QuantizedLinearReLU' in str(qlinear))
    else:
        Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)
        self.assertTrue('QuantizedLinear' in str(qlinear))
    self.assertEqual(Z_ref, Z_q)

    # Test serialization of quantized Linear Module using state_dict
    model_dict = qlinear.state_dict()
    b = io.BytesIO()
    torch.save(model_dict, b)
    b.seek(0)
    loaded_dict = torch.load(b)
    for key in model_dict:
        if isinstance(model_dict[key], torch._C.ScriptObject):
            assert isinstance(loaded_dict[key], torch._C.ScriptObject)
            w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
            w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
            self.assertEqual(w_model, w_loaded)
            self.assertEqual(b_model, b_loaded)
        else:
            self.assertEqual(model_dict[key], loaded_dict[key])
    if use_fused:
        loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
    else:
        loaded_qlinear = nnq.Linear(in_features, out_features)
    loaded_qlinear.load_state_dict(loaded_dict)

    linear_unpack = torch.ops.quantized.linear_unpack
    self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                     linear_unpack(loaded_qlinear._packed_params._packed_params))
    if use_bias:
        self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
    self.assertEqual(qlinear.scale, loaded_qlinear.scale)
    self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
    self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
    self.assertTrue(hasattr(qlinear, '_packed_params'))
    self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
    self.assertTrue(hasattr(qlinear, '_weight_bias'))
    self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
    self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
    self.assertEqual(qlinear._weight_bias(),
                     torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
    Z_q2 = loaded_qlinear(X_q)
    self.assertEqual(Z_q, Z_q2)

    b = io.BytesIO()
    torch.save(qlinear, b)
    b.seek(0)
    loaded = torch.load(b)
    self.assertEqual(qlinear.weight(), loaded.weight())
    self.assertEqual(qlinear.scale, loaded.scale)
    self.assertEqual(qlinear.zero_point, loaded.zero_point)

    # Test JIT
    self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

    # Test from_float.
    float_linear = torch.nn.Linear(in_features, out_features).float()
    float_linear.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(float_linear, inplace=True)
    float_linear(X.float())
    # Sequential allows swapping using "convert".
    quantized_float_linear = torch.nn.Sequential(float_linear)
    quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True)

    # Smoke test to make sure the module actually runs
    quantized_float_linear(X_q)

    # Smoke test extra_repr
    self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias,
                          use_fused, per_channel):
    if torch.backends.quantized.engine == 'qnnpack':
        per_channel = False

    # use_fused -> quantized class
    class_map = {
        True: nniq.LinearReLU,
        False: nnq.Linear,
    }

    W = torch.rand(out_features, in_features).float()
    if per_channel:
        scale_tensor = torch.ones(out_features, dtype=torch.double)
        zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
        for i in range(len(scale_tensor)):
            scale_tensor[i] = (i + 1.0) / 255.0
        W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                         zero_points=zero_point_tensor,
                                         axis=0, dtype=torch.qint8)
    else:
        W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

    X = torch.rand(batch_size, in_features).float()
    X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
    B = torch.rand(out_features).float() if use_bias else None
    scale = 0.5
    zero_point = 3
    qlinear = class_map[use_fused](in_features, out_features)

    qlinear_copy = copy.deepcopy(qlinear)
    self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
    # Run module with default-initialized parameters.
    # This tests that the constructor is correct.
    qlinear(X_q)

    qlinear.set_weight_bias(W_q, B)
    # Simple round-trip test to ensure weight()/set_weight() API
    self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)

    # testing packed param implementation
    qlinear.scale = float(scale)
    qlinear.zero_point = int(zero_point)
    Z_q = qlinear(X_q)

    # Check if the module implementation matches calling the
    # ops directly
    W_pack = qlinear._packed_params._packed_params
    if use_fused:
        Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
    else:
        Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)

    self.assertEqual(Z_ref, Z_q)
    self.assertTrue(
        ("QuantizedLinearReLU" if use_fused else "QuantizedLinear") in str(qlinear))

    # Test serialization of quantized Linear Module using state_dict
    model_dict = qlinear.state_dict()
    b = io.BytesIO()
    torch.save(model_dict, b)
    b.seek(0)
    loaded_dict = torch.load(b)
    for key in model_dict:
        if isinstance(model_dict[key], torch._C.ScriptObject):
            assert isinstance(loaded_dict[key], torch._C.ScriptObject)
            w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
            w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
            self.assertEqual(w_model, w_loaded)
            self.assertEqual(b_model, b_loaded)
        else:
            self.assertEqual(model_dict[key], loaded_dict[key])

    loaded_qlinear = class_map[use_fused](in_features, out_features)
    loaded_qlinear.load_state_dict(loaded_dict)
    linear_unpack = torch.ops.quantized.linear_unpack
    self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                     linear_unpack(loaded_qlinear._packed_params._packed_params))
    self.assertEqual(qlinear.scale, loaded_qlinear.scale)
    self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
    # scripting will add __overloads__ to __dict__, which is why we script a copy
    # to be able to do the check in the next line
    self.checkScriptable(copy.deepcopy(loaded_qlinear), [[X_q]], check_save_load=True)
    self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
    self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
    self.assertEqual(qlinear._weight_bias(),
                     torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
    Z_q2 = loaded_qlinear(X_q)
    self.assertEqual(Z_q, Z_q2)

    # Test serialization
    b = io.BytesIO()
    torch.save(qlinear, b)
    b.seek(0)
    loaded = torch.load(b)
    self.assertEqual(qlinear.weight(), loaded.weight())
    self.assertEqual(qlinear.scale, loaded.scale)
    self.assertEqual(qlinear.zero_point, loaded.zero_point)

    # Test copy and deepcopy
    copied_linear = copy.copy(qlinear)
    self.assertEqual(copied_linear.bias(), qlinear.bias())
    self.assertEqual(copied_linear.scale, qlinear.scale)
    self.assertEqual(copied_linear.zero_point, qlinear.zero_point)
    Y_copied = copied_linear(X_q)
    np.testing.assert_array_almost_equal(
        Z_q.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

    deepcopied_linear = copy.deepcopy(qlinear)
    self.assertEqual(deepcopied_linear.bias(), qlinear.bias())
    self.assertEqual(deepcopied_linear.scale, qlinear.scale)
    self.assertEqual(deepcopied_linear.zero_point, qlinear.zero_point)
    Y_deepcopied = deepcopied_linear(X_q)
    np.testing.assert_array_almost_equal(
        Z_q.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

    # Test JIT
    self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

    # Make sure `from_float` works for all linear variants
    modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]

    for mut in modules_under_test:
        # Test from_float.
        float_linear = mut(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear, inplace=True)
        float_linear(X.float())
        # Sequential allows swapping using "convert".
        quantized_float_linear = torch.nn.Sequential(float_linear)
        quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
def test_serialize_graph(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.e = torch.rand(4)
            self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

        def forward(self, a, b, c):
            add_1 = a + b
            conv1 = self.conv(c)
            linear = self.linear(add_1 + conv1)
            add_2 = linear + self.e
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    c = torch.rand(3, 3, 2, 2)
    graph_manipulation.get_size_of_all_nodes(traced, [a, b, c])

    partitioner = Partitioner()
    devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
    partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules

    # Fix for now to add type/shape to output
    for node in traced.graph.nodes:
        if node.op == "output":
            node.meta['tensor_meta'] = extract_tensor_metadata(a)
    for mod in module_with_submodules.modules():
        if isinstance(mod, GraphModule):
            for node in mod.graph.nodes:
                node.meta['tensor_meta'] = extract_tensor_metadata(a)
    for node in module_with_submodules.graph.nodes:
        node.meta['tensor_meta'] = extract_tensor_metadata(a)

    weights1 = {}
    weights2 = {}
    serialized_graph1 = graph_manipulation.serialize_module(traced, weights1)
    serialized_graph2 = graph_manipulation.serialize_module(module_with_submodules, weights2)
    assert len(weights1) == 4
    assert len(weights2) == 4
    assert len(serialized_graph1["nodes"]) == 10
    assert len(serialized_graph1["weights"]) == 4
    assert len(serialized_graph1["modules"]) == 0
    assert len(serialized_graph2["nodes"]) == 6
    assert len(serialized_graph2["weights"]) == 4
    assert len(serialized_graph2["modules"]) == 1
    assert serialized_graph1["weights"]["linear.weight"]["shape"] == "[4, 4]"
    assert (serialized_graph1["weights"]["linear.weight"]["dtype"] == "torch.float32")
    assert (serialized_graph1["weights"]["linear.weight"]["is_quantized"] is False)
    assert serialized_graph1["nodes"][0]["shape"] == "[4]"
    assert serialized_graph1["nodes"][0]["dtype"] == "torch.float32"
    assert serialized_graph1["nodes"][0]["target"] == "a"
    assert serialized_graph1["nodes"][0]["op_code"] == "placeholder"
    assert serialized_graph1["nodes"][0]["name"] == "a"
    assert serialized_graph1["nodes"][6]["args"][0]["name"] == "add_1"
    assert serialized_graph1["nodes"][6]["args"][0]["is_node"] is True

    # Test quantization info serialization.
    x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
    q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
    q_tensor_channel = torch.quantize_per_channel(
        x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8)
    result = graph_manipulation.serialize_tensor_quantization(q_tensor)
    result2 = graph_manipulation.serialize_tensor_quantization(q_tensor_channel)
    assert result["qscheme"] == "torch.per_tensor_affine"
    assert result["q_scale"] == 1.0
    assert result2["qscheme"] == "torch.per_channel_affine"
    assert len(result2["q_per_channel_scales"]) == 2
dtype=torch.int64))  # E: {Tensor}
reveal_type(torch.empty_strided((2, 3), (1, 2)))  # E: {Tensor}

# torch.full/full_like
reveal_type(torch.full((2, 3), 3.141592))  # E: {Tensor}
reveal_type(torch.full_like(torch.full((2, 3), 3.141592), 2.71828))  # E: {Tensor}

# torch.quantize_per_tensor
reveal_type(torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8))  # E: {Tensor}

# torch.quantize_per_channel
x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
quant = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8)
reveal_type(x)  # E: {Tensor}

# torch.dequantize
reveal_type(torch.dequantize(x))  # E: {Tensor}

# torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
reveal_type(torch.complex(real, imag))  # E: {Tensor}

# torch.polar
abs = torch.tensor([1, 2], dtype=torch.float64)
pi = torch.acos(torch.zeros(1)).item() * 2
angle = torch.tensor([pi / 2, 5 * pi / 4], dtype=torch.float64)
reveal_type(torch.polar(abs, angle))  # E: {Tensor}