def forward(self, input): """Do an inference on Identity.""" input = input.cpu() input = torch.quantize_per_tensor(input, 1.0, 0, self._quant_type[self.quant_bit]) output = super().forward(input) if vega.is_npu_device(): output = torch.dequantize(output).npu() elif vega.is_gpu_device(): output = torch.dequantize(output).cuda() else: output = torch.dequantize(output) return output
def _test_numerical_consistency(self, test_type):
    r"""Comparing numerical consistency between quantize/dequantize op and the fake quantize op across devices and dtypes
    """
    torch.random.manual_seed(NP_RANDOM_SEED)
    torch_types = [torch.qint8, torch.quint8]
    float_types = [torch.float, torch.float16, torch.float64]
    zero_types = [torch.long]
    devices = [torch.device('cpu'), torch.device('cuda')] if torch.cuda.is_available() else [torch.device('cpu')]
    axis = 1
    for i in range(20):
        for torch_type, float_type, device, zero_type in itertools.product(
                torch_types, float_types, devices, zero_types):
            X = torch.randn(3, 3, device=device).to(float_type)
            scales = (10 * torch.randn(3, device=device)).abs()
            scale = scales.mean().to(float).item()
            zeros = (10 * torch.randn(3, device=device)).abs().to(dtype=zero_type)
            zero = zeros.max().view(1).item()
            quant_min = torch.iinfo(torch_type).min
            quant_max = torch.iinfo(torch_type).max

            test_was_run = False
            if test_type == "per_tensor":
                test_was_run = True
                Y = torch.dequantize(
                    torch.quantize_per_tensor(
                        X.to('cpu').to(torch.float), scale, zero,
                        torch_type)).to(device).to(float_type)
                Y_prime = torch.fake_quantize_per_tensor_affine(
                    X, scale, zero, quant_min, quant_max)
                self.assertEqual(
                    Y, Y_prime,
                    "Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor")

            if test_type == "per_channel":
                test_was_run = True
                Y = torch.dequantize(
                    torch.quantize_per_channel(
                        X.to('cpu').to(torch.float), scales.to('cpu'),
                        zeros.to('cpu'), axis,
                        torch_type)).to(device).to(float_type)
                Y_prime = torch.fake_quantize_per_channel_affine(
                    X, scales, zeros, axis, quant_min, quant_max)
                self.assertEqual(
                    Y, Y_prime,
                    "Difference found between dequant+quant_per_channel and fake_quantize_per_channel")
            self.assertTrue(test_was_run)
def test_ptq_quantize_first(self):
    """The expectation is that the post_training_sparse_quantize function
    1. Takes in a model
    2. Quantizes the embeddings
    3. Sparsifies the embeddings

    This unit test checks that
    1. Embeddings and EmbeddingBags are sparsified to the right sparsity levels
    2. Embeddings and EmbeddingBags are quantized
    3. Linear modules are not quantized
    """
    model = Model()
    sparse_config = {'sparsity_level': 0.8, 'sparse_block_shape': (1, 1)}
    post_training_sparse_quantize(model, DataNormSparsifier, sparsify_first=False, **sparse_config)

    assert type(model.emb1) == torch.nn.quantized.modules.embedding_ops.Embedding
    assert type(model.embbag1) == torch.nn.quantized.modules.embedding_ops.EmbeddingBag
    assert type(model.emb_seq[0]) == torch.nn.quantized.modules.embedding_ops.Embedding
    assert type(model.emb_seq[1]) == torch.nn.quantized.modules.embedding_ops.EmbeddingBag
    assert type(model.linear1) == nn.Linear  # not quantized
    assert type(model.linear2) == nn.Linear  # not quantized

    dequant_emb1 = torch.dequantize(model.emb1.weight())
    dequant_embbag1 = torch.dequantize(model.embbag1.weight())
    dequant_emb_seq_0 = torch.dequantize(model.emb_seq[0].weight())
    dequant_emb_seq_1 = torch.dequantize(model.emb_seq[1].weight())

    # higher threshold as quantization occurs before sparsity
    threshold = 1  # zero points seem to have higher magnitude with sparsity occurring after

    sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean()
    sl_embbag1 = (torch.abs(dequant_embbag1) < threshold).float().mean()
    sl_emb_seq_0 = (torch.abs(dequant_emb_seq_0) < threshold).float().mean()
    sl_emb_seq_1 = (torch.abs(dequant_emb_seq_1) < threshold).float().mean()

    assert abs(sl_emb1 - 0.80) <= 0.05  # +- 5% leeway
    assert abs(sl_embbag1 - 0.80) <= 0.05  # +- 5% leeway
    assert abs(sl_emb_seq_0 - 0.80) <= 0.05  # +- 5% leeway
    assert abs(sl_emb_seq_1 - 0.80) <= 0.05  # +- 5% leeway
def test_numerical_consistency_cuda(self):
    '''Comparing numerical consistency between CPU quantize/dequantize op and the CUDA fake quantize op'''
    np.random.seed(NP_RANDOM_SEED)
    fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward

    scale = 3
    zero_point = 2
    num_bits = 8
    X = np.random.rand(20, 20) * 125
    X_torch = torch.from_numpy(X).float()
    Y = torch.dequantize(
        torch.quantize_linear(X_torch, scale, zero_point, torch.qint8))
    Y_prime = fake_quantize_per_tensor_affine_forward(
        X=X_torch.cuda(),
        scale=scale,
        zero_point=zero_point,
        num_bits=num_bits,
        quant_delay=0,
        iter=0)
    tolerance = 1e-6
    np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
def forward(self, x):
    quant_weight = torch.quantize_per_tensor(
        self.weight, 0.5, 3, torch.quint8
    )
    dequant_weight = torch.dequantize(quant_weight)
    output = torch.nn.functional.linear(x, dequant_weight, self.bias)
    return self.relu(output)
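# A minimal sketch of the weight "fake quantization" pattern used in the forward
# above: quantize_per_tensor followed immediately by dequantize keeps the tensor
# in float while baking in the rounding/clamping error of quint8 at the given
# scale and zero point. The tensor shape below is an assumption for illustration;
# the scale (0.5) and zero point (3) are the ones used in the forward.
import torch

w = torch.randn(2, 4)
w_q = torch.quantize_per_tensor(w, 0.5, 3, torch.quint8)
w_dq = torch.dequantize(w_q)      # float tensor carrying the quantization error
max_err = (w - w_dq).abs().max()  # ~scale/2 for in-range values; clamped values differ more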
def forward(self, input): """Do an inference on Identity.""" input = input.cpu() input = torch.quantize_per_tensor(input, 1.0, 0, self._quant_type[self.quant_bit]) output = super().forward(input) output = torch.dequantize(output).cuda() return output
def forward(self, x):
    x = self.quant(x)
    x = torch.ops.quantized.layer_norm(
        x,
        self.normalized_shape,
        weight=self.weight,
        bias=self.bias,
        eps=1e-05,
        output_scale=self.scale,
        output_zero_point=self.zero_point,
    )
    return torch.dequantize(x)
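# A standalone sketch of the quantized layer_norm op called above, assuming a
# (2, 4) input and arbitrary qparams: the op takes a quantized input plus float
# affine weight/bias and produces a quantized output at the requested
# output_scale / output_zero_point, which the module then dequantizes.
import torch

x = torch.randn(2, 4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
qy = torch.ops.quantized.layer_norm(
    qx, [4], weight=torch.ones(4), bias=torch.zeros(4),
    eps=1e-05, output_scale=0.1, output_zero_point=0)
y = torch.dequantize(qy)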
def process(self, data: Any) -> Any:
    """Implements a Processor for applying dequantization to MistNet PyTorch features."""
    feature_dataset = []

    for logit, target in data:
        feature_dataset.append((torch.dequantize(logit), target))

    logging.info("[Server #%d] Dequantized features.", self.server_id)

    return feature_dataset
def test_ptq_sparsify_first(self):
    """The expectation is that the post_training_sparse_quantize function
    1. Takes in a model
    2. Sparsifies the embeddings
    3. Quantizes the embeddings

    This unit test checks that
    1. Embeddings and EmbeddingBags are sparsified to the right sparsity levels
    2. Embeddings and EmbeddingBags are quantized
    3. Linear modules are not quantized
    """
    model = Model()
    sparse_config = {'sparsity_level': 0.80, 'sparse_block_shape': (1, 1)}
    select_embeddings = [model.embbag1, model.emb1]
    post_training_sparse_quantize(model,
                                  data_sparsifier_class=DataNormSparsifier,
                                  sparsify_first=True,
                                  select_embeddings=select_embeddings,
                                  **sparse_config)

    assert type(model.emb1) == torch.nn.quantized.modules.embedding_ops.Embedding
    assert type(model.embbag1) == torch.nn.quantized.modules.embedding_ops.EmbeddingBag
    assert type(model.emb_seq[0]) == nn.Embedding
    assert type(model.emb_seq[1]) == nn.EmbeddingBag
    assert type(model.linear1) == nn.Linear
    assert type(model.linear2) == nn.Linear

    dequant_emb1 = torch.dequantize(model.emb1.weight())
    dequant_embbag1 = torch.dequantize(model.embbag1.weight())

    threshold = 1e-2

    sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean()
    sl_embbag1 = (torch.abs(dequant_embbag1) < threshold).float().mean()

    assert abs(sl_emb1 - 0.80) <= 0.05  # +- 5% leeway
    assert abs(sl_embbag1 - 0.80) <= 0.05  # +- 5% leeway
def forward(self, *x):
    o = []
    for i, t in enumerate(x):
        if i == 0:
            t = torch.quantize_per_tensor(t, 1, 0, torch.qint32)
            t = torch.dequantize(t)
        o.append(t)
    x = self.module(*o)
    if isinstance(x, torch.Tensor):
        x = torch.quantize_per_tensor(x, 2, 0, torch.qint32)
        x = torch.dequantize(x)
    else:
        o = []
        for i, t in enumerate(x):
            if i == 0:
                t = torch.quantize_per_tensor(t, 1, 0, torch.qint32)
                t = torch.dequantize(t)
            o.append(t)
        x = tuple(o)
    return x
def test_numerical_consistency_per_tensor(self, device, X):
    r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
    """
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, torch_type) = X
    quant_min = torch.iinfo(torch_type).min
    quant_max = torch.iinfo(torch_type).max

    X = to_tensor(X, device)
    # quantize_per_tensor and dequantize are only implemented in CPU
    Y = torch.dequantize(torch.quantize_per_tensor(X.cpu(), scale, zero_point, torch_type))
    Y_prime = torch.fake_quantize_per_tensor_affine(
        X, scale, zero_point, quant_min, quant_max)
    np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
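# A standalone sketch of the equivalence the test above verifies: a real
# quantize_per_tensor/dequantize round trip should agree with the simulated
# fake_quantize_per_tensor_affine op. The scale, zero point, and dtype here are
# arbitrary illustrative choices, not values from the test harness.
import torch

x = torch.randn(20, 20)
scale, zero_point = 0.05, 10
quant_min = torch.iinfo(torch.quint8).min  # 0
quant_max = torch.iinfo(torch.quint8).max  # 255
y = torch.dequantize(torch.quantize_per_tensor(x, scale, zero_point, torch.quint8))
y_prime = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, quant_min, quant_max)
assert torch.allclose(y, y_prime)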
def tensor_creation_ops(self):
    i = torch.tensor([[0, 1, 1], [2, 0, 2]])
    v = torch.tensor([3, 4, 5], dtype=torch.float32)
    real = torch.tensor([1, 2], dtype=torch.float32)
    imag = torch.tensor([3, 4], dtype=torch.float32)
    inp = torch.tensor([-1.5, 0.0, 2.0])
    values = torch.tensor([0.5])
    quantized = torch.quantize_per_channel(
        torch.tensor([[-1.0, 0.0], [1.0, 2.0]]),
        torch.tensor([0.1, 0.01]),
        torch.tensor([10, 0]),
        0,
        torch.quint8,
    )
    return (
        torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]),
        # torch.sparse_coo_tensor(i, v, [2, 3]),  # not work for iOS
        torch.as_tensor([1, 2, 3]),
        torch.as_strided(torch.randn(3, 3), (2, 2), (1, 2)),
        torch.zeros(2, 3),
        torch.zeros((2, 3)),
        torch.zeros([2, 3], out=i),
        torch.zeros(5),
        torch.zeros_like(torch.empty(2, 3)),
        torch.ones(2, 3),
        torch.ones((2, 3)),
        torch.ones([2, 3]),
        torch.ones(5),
        torch.ones_like(torch.empty(2, 3)),
        torch.arange(5),
        torch.arange(1, 4),
        torch.arange(1, 2.5, 0.5),
        torch.range(1, 4),
        torch.range(1, 4, 0.5),
        torch.linspace(3.0, 3.0, steps=1),
        torch.logspace(start=2, end=2, steps=1, base=2.0),
        torch.eye(3),
        torch.empty(2, 3),
        torch.empty_like(torch.empty(2, 3), dtype=torch.int64),
        torch.empty_strided((2, 3), (1, 2)),
        torch.full((2, 3), 3.141592),
        torch.full_like(torch.full((2, 3), 3.141592), 2.71828),
        torch.quantize_per_tensor(
            torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8
        ),
        torch.dequantize(quantized),
        torch.complex(real, imag),
        torch.polar(real, imag),
        torch.heaviside(inp, values),
    )
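# A short sketch of the per-channel round trip built above: with axis 0, each
# row is quantized with its own scale and zero point, and dequantize undoes both
# per row. The values are the same ones passed to quantize_per_channel in
# tensor_creation_ops.
import torch

x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
q = torch.quantize_per_channel(
    x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8)
x_rec = torch.dequantize(q)
# Row 0 uses scale 0.1 / zero point 10, row 1 uses scale 0.01 / zero point 0;
# every entry here falls on its channel's grid, so x_rec matches x up to float rounding.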
def test_tensor_dump_and_set(self):
    model = copy.deepcopy(self.lpot_model)
    model.model.eval().fuse_model()
    quantizer = Quantization('ptq_yaml.yaml')
    dataset = quantizer.dataset('dummy', (100, 3, 256, 256), label=True)
    dataloader = common.DataLoader(dataset)
    dataloader = common._generate_common_dataloader(dataloader, 'pytorch')
    quantizer.eval_dataloader = dataloader
    quantizer.calib_dataloader = dataloader
    quantizer.model = common.Model(model.model)
    q_model = quantizer()
    quantizer.strategy.adaptor.inspect_tensor(
        model,
        dataloader,
        op_list=['conv1.0', 'layer1.0.conv1.0'],
        iteration_list=[1, 2],
        inspect_type='all',
        save_to_disk=True)
    load_array = lambda *a, **k: np.load(*a, allow_pickle=True, **k)
    a = load_array('saved/dump_tensor/activation_iter1.npz')
    w = load_array('saved/dump_tensor/weight.npz')
    version = get_torch_version()
    if version >= '1.8':
        self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
                        a['conv1.0'].item()['conv1.0.output0'].shape[1])
    else:
        self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
                        a['conv1.0'].item()['conv1.1.output0'].shape[1])
    data = np.random.random(
        w['conv1.0'].item()['conv1.0.weight'].shape).astype(np.float32)
    quantizer.strategy.adaptor.set_tensor(q_model, {'conv1.0.weight': data})
    changed_tensor = q_model.get_weight('conv1.weight')
    scales = changed_tensor.q_per_channel_scales()
    changed_tensor_fp32 = torch.dequantize(changed_tensor)
    self.assertTrue(
        np.allclose(data, changed_tensor_fp32.numpy(), atol=2 / np.min(scales.numpy())))
    quantizer.strategy.adaptor.inspect_tensor(
        q_model,
        dataloader,
        op_list=['conv1.0', 'layer1.0.conv1.0'],
        iteration_list=[1, 2],
        inspect_type='all',
        save_to_disk=False)
def test_numerical_consistency(self):
    '''Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op'''
    np.random.seed(NP_RANDOM_SEED)
    fake_quantize_per_tensor_affine_forward = torch.ops.quantized.fake_quantize_per_tensor_affine_forward

    scale = 3
    zero_point = 2
    quant_min, quant_max = 0, 255
    X = np.random.rand(20, 20) * 125
    X_torch = torch.from_numpy(X).float()
    Y = torch.dequantize(
        torch.quantize_linear(X_torch, scale, zero_point, torch.qint8))
    Y_prime = _intrinsic.fq_per_tensor_affine_forward(
        X=X_torch,
        scale=scale,
        zero_point=zero_point,
        quant_min=quant_min,
        quant_max=quant_max,
        quant_delay=0,
        iter=0)
    tolerance = 1e-6
    np.testing.assert_allclose(Y, Y_prime, rtol=tolerance, atol=tolerance)
def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
    layer = torch.dequantize(layer)
    return layer
def dequantize(*, input, input_tensor_meta):
    """`input_tensor_meta` carries the extra quantization parameters (e.g. scale and
    zero_point) and will be used when lowering the dequantize op to TensorRT.
    """
    return torch.dequantize(input)
# torch.full/full_like
torch.full((2, 3), 3.141592)
torch.full_like(torch.full((2, 3), 3.141592), 2.71828)

# torch.quantize_per_tensor
torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8)

# torch.quantize_per_channel
x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
quant = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]),
                                   torch.tensor([10, 0]), 0, torch.quint8)

# torch.dequantize
torch.dequantize(x)

# torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
torch.complex(real, imag)

# torch.polar
abs = torch.tensor([1, 2], dtype=torch.float64)
pi = torch.acos(torch.zeros(1)).item() * 2
angle = torch.tensor([pi / 2, 5 * pi / 4], dtype=torch.float64)
torch.polar(abs, angle)

# torch.heaviside
inp = torch.tensor([-1.5, 0, 2.0])
values = torch.tensor([0.5])
reveal_type(torch.full_like(torch.full((2, 3), 3.141592), 2.71828))  # E: {Tensor}

# torch.quantize_per_tensor
reveal_type(
    torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10,
                              torch.quint8))  # E: {Tensor}

# torch.quantize_per_channel
x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
quant = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]),
                                   torch.tensor([10, 0]), 0, torch.quint8)
reveal_type(x)  # E: {Tensor}

# torch.dequantize
reveal_type(torch.dequantize(x))  # E: {Tensor}

# torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
reveal_type(torch.complex(real, imag))  # E: {Tensor}

# torch.polar
abs = torch.tensor([1, 2], dtype=torch.float64)
pi = torch.acos(torch.zeros(1)).item() * 2
angle = torch.tensor([pi / 2, 5 * pi / 4], dtype=torch.float64)
reveal_type(torch.polar(abs, angle))  # E: {Tensor}

# torch.heaviside
inp = torch.tensor([-1.5, 0, 2.0])
values = torch.tensor([0.5])
reveal_type(torch.full_like(torch.full((2, 3), 3.141592), 2.71828))  # E: torch.tensor.Tensor

# torch.quantize_per_tensor
reveal_type(
    torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10,
                              torch.quint8))  # E: torch.tensor.Tensor

# torch.quantize_per_channel
x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
quant = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]),
                                   torch.tensor([10, 0]), 0, torch.quint8)
reveal_type(x)  # E: torch.tensor.Tensor

# torch.dequantize
reveal_type(torch.dequantize(x))  # E: torch.tensor.Tensor

# torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
reveal_type(torch.complex(real, imag))  # E: torch.tensor.Tensor

# torch.polar
abs = torch.tensor([1, 2], dtype=torch.float64)
pi = torch.acos(torch.zeros(1)).item() * 2
angle = torch.tensor([pi / 2, 5 * pi / 4], dtype=torch.float64)
reveal_type(torch.polar(abs, angle))  # E: torch.tensor.Tensor

# torch.heaviside
inp = torch.tensor([-1.5, 0, 2.0])
values = torch.tensor([0.5])
def test_qlinear_packed_params(self, allow_non_zero_zero_points=False):
    # copied from https://pytorch.org/docs/stable/sparse.html#csr-tensor-operations,
    # so row/col block indices match that example, but with blocks and
    # scaled rows
    weight_fp32 = torch.Tensor([
        [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
        [6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ])

    row_block_size = 1
    col_block_size = 4
    out_features = weight_fp32.shape[0]
    in_features = weight_fp32.shape[1]

    scales = [2.0, 6.0, 12.0]
    zero_points = [((i + 1) if allow_non_zero_zero_points else 0) for i in range(out_features)]
    dtype = torch.qint8

    wide_weight_fp32 = torch.zeros((3, 4008))  # 4000 is tile width for Fbgemm
    wide_weight_fp32[0][0] = 4
    wide_weight_fp32[0][4004] = 6
    wide_weight_fp32[1][0] = 8

    per_tensor_small = (
        torch.quantize_per_tensor(weight_fp32, scales[0], zero_points[0], dtype),
        True,
        [0, 1, 3, 3],
        [2, 0, 1],
        [
            x + (1 if allow_non_zero_zero_points else 0)
            for x in [1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6]
        ],
    )

    per_channel_small = (
        torch.quantize_per_channel(
            weight_fp32,
            torch.Tensor(scales),
            torch.Tensor(zero_points).to(torch.int),
            0,  # axis = 0
            dtype,
        ),
        False,
        [0, 1, 3, 3],
        [2, 0, 1],
        [
            x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0)
            for (i, x) in enumerate([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2])
        ],
    )

    per_tensor_large = (
        torch.quantize_per_tensor(
            wide_weight_fp32,
            scales[0],
            zero_points[0],
            dtype,
        ),
        True,
        [0, 2, 3, 3],
        [0, 1001, 0],
        [
            x + (1 if allow_non_zero_zero_points else 0)
            for x in [2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]
        ],
    )

    for (weight, is_per_tensor_quantized, expected_row_block_indices,
         expected_col_block_indices, expected_weights) in [
             per_tensor_small, per_channel_small, per_tensor_large
         ]:
        lin = Linear(
            out_features=weight.shape[0],
            in_features=weight.shape[1],
            row_block_size=row_block_size,
            col_block_size=col_block_size,
            bias=True,
            dtype=dtype,
        )

        bias = torch.ones(size=(weight.shape[0], ))

        lin.set_weight_bias(weight, bias, row_block_size, col_block_size)

        serialized = lin._packed_params._packed_params.__getstate__()

        (
            _,  # version
            bias_,
            out_features_block_size_,
            in_features_block_size_,
            weight_scales_,
            weight_zero_points_,
            quantization_scheme_,
            row_block_indices_,
            col_block_indices_,
            weights_,
            output_channels_,
            input_channels_) = serialized[0]

        # Test Serialization
        self.assertEqual(bias_, bias)
        self.assertEqual(out_features_block_size_, row_block_size)
        self.assertEqual(in_features_block_size_, col_block_size)
        self.assertEqual(
            weight_scales_,
            [scales[0]] if is_per_tensor_quantized else scales)
        self.assertEqual(
            weight_zero_points_,
            [zero_points[0]] if is_per_tensor_quantized else zero_points)
        self.assertEqual(quantization_scheme_, is_per_tensor_quantized)
        self.assertEqual(row_block_indices_, expected_row_block_indices)
        self.assertEqual(col_block_indices_, expected_col_block_indices)
        self.assertEqual(
            weights_.tolist(),
            [v + 128 for v in expected_weights])  # weights are serialized as +128
        self.assertEqual(output_channels_, weight.shape[0])
        self.assertEqual(input_channels_, weight.shape[1])

        # Test Unpacking
        (weights_, bias_, out_features_block_size_, in_features_block_size_) = lin._weight_bias()
        self.assertEqual(torch.dequantize(weights_), torch.dequantize(weight))
        self.assertEqual(bias_, bias)
        self.assertEqual(out_features_block_size_, row_block_size)
        self.assertEqual(in_features_block_size_, col_block_size)

        # Test Deserialization
        with tempfile.TemporaryFile() as file_buff:
            torch.save(lin, file_buff)
            file_buff.seek(0)
            lin2 = torch.load(file_buff)
            self.assertEqual(lin._weight_bias(), lin2._weight_bias())
            # Serialize -> Deserialize -> Serialize should match Serialize
            self.assertEqual(serialized, lin2._packed_params._packed_params.__getstate__())

            # Test that op output is preserved by serialize -> deserialize
            if qengine_is_qnnpack():
                x = torch.rand(size=(1, weight.shape[1]))
                y1 = lin(x)
                y2 = lin2(x)
                self.assertEqual(y1, y2)
def post_training_sparse_quantize(model,
                                  data_sparsifier_class,
                                  sparsify_first=True,
                                  select_embeddings: List[nn.Module] = None,
                                  **sparse_config):
    """Takes in a model and applies sparsification and quantization only to its
    embeddings and embedding bags. The quantization step can happen before or after
    sparsification depending on the `sparsify_first` argument.

    Args:
        - model (nn.Module)
            model whose embeddings need to be sparsified
        - data_sparsifier_class (type of data sparsifier)
            Type of sparsification that needs to be applied to the model
        - sparsify_first (bool)
            if true, sparsifies first and then quantizes;
            otherwise, quantizes first and then sparsifies.
        - select_embeddings (List of Embedding modules)
            List of embedding modules in the model to be sparsified & quantized.
            If None, all embedding modules will be sparsified
        - sparse_config (Dict)
            config that will be passed to the constructor of the data sparsifier object.

    Note:
        1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
            - before sparsifying, the embedding layers are dequantized.
            - scales and zero-points are saved
            - embedding layers are sparsified and `squash_mask` is applied
            - embedding weights are requantized using the saved scales and zero-points
        2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
            - embeddings are sparsified first
            - quantization is applied on the sparsified embeddings
    """
    data_sparsifier = data_sparsifier_class(**sparse_config)

    # if select_embeddings is None, perform it on all embeddings
    if select_embeddings is None:
        embedding_modules = _fetch_all_embeddings(model)
    else:
        embedding_modules = []
        assert isinstance(select_embeddings, List), \
            "the embedding_modules must be a list of embedding modules"
        for emb in select_embeddings:
            assert type(emb) in SUPPORTED_MODULES, \
                "the embedding_modules list must be an embedding or embedding bags"
            fqn_name = module_to_fqn(model, emb)
            assert fqn_name is not None, \
                "the embedding modules must be part of input model"
            embedding_modules.append((fqn_name, emb))

    if sparsify_first:
        # sparsify
        for name, emb_module in embedding_modules:
            valid_name = name.replace('.', '_')
            data_sparsifier.add_data(name=valid_name, data=emb_module)

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.quantization.prepare(model, inplace=True)
        torch.quantization.convert(model, inplace=True)
    else:
        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.quantization.prepare(model, inplace=True)
        torch.quantization.convert(model, inplace=True)

        # retrieve scale & zero_points
        quantize_params: Dict[str, Dict] = {
            'scales': {},
            'zero_points': {},
            'dequant_weights': {},
            'axis': {},
            'dtype': {}
        }

        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy

            quantized_weight = quantized_emb.weight()  # type: ignore[operator]
            quantize_params['scales'][name] = quantized_weight.q_per_channel_scales()
            quantize_params['zero_points'][name] = quantized_weight.q_per_channel_zero_points()
            quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight)
            quantize_params['axis'][name] = quantized_weight.q_per_channel_axis()
            quantize_params['dtype'][name] = quantized_weight.dtype

            # attach data to sparsifier
            data_sparsifier.add_data(
                name=name.replace('.', '_'),
                data=quantize_params['dequant_weights'][name])

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy
            requantized_vector = torch.quantize_per_channel(
                quantize_params['dequant_weights'][name],
                scales=quantize_params['scales'][name],
                zero_points=quantize_params['zero_points'][name],
                dtype=quantize_params['dtype'][name],
                axis=quantize_params['axis'][name])

            quantized_emb.set_weight(requantized_vector)  # type: ignore[operator]
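# A minimal usage sketch, mirroring how the unit tests above exercise this helper.
# `Model` and `DataNormSparsifier` come from the surrounding test code; the
# sparsity level and block shape are the values those tests use.
model = Model()
sparse_config = {'sparsity_level': 0.8, 'sparse_block_shape': (1, 1)}

# quantize first, then sparsify the dequantized embedding weights
post_training_sparse_quantize(model, DataNormSparsifier,
                              sparsify_first=False, **sparse_config)

# or: sparsify first, then quantize, restricted to selected embeddings
# post_training_sparse_quantize(model, DataNormSparsifier, sparsify_first=True,
#                               select_embeddings=[model.emb1, model.embbag1],
#                               **sparse_config)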
def dequantize(*, input):
    return torch.dequantize(input)