def forward(self, X):
    if self.static_enabled[0] == 1:  # type: ignore[index]
        # Static estimation: run the observer on a detached copy of the
        # input and adopt the qparams it calculates.
        self.activation_post_process(X.detach())
        _scale, _zero_point = self.activation_post_process.calculate_qparams()
        _scale = _scale.to(self.scale.device)
        _zero_point = _zero_point.to(self.zero_point.device)
        self.scale.data.copy_(_scale)
        self.zero_point.data.copy_(_zero_point)
    else:
        # Learned scale: keep it positive and bounded away from zero.
        self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]

    if self.fake_quant_enabled[0] == 1:
        if self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
            self.zero_point.data.zero_()

        if self.use_grad_scaling:
            # LSQ-style gradient scaling: 1 / sqrt(num_elements * quant_max).
            grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
        else:
            grad_factor = 1.0
        if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine):
            X = torch._fake_quantize_learnable_per_channel_affine(
                X, self.scale, self.zero_point, self.ch_axis,
                self.quant_min, self.quant_max, grad_factor)
        else:
            X = torch._fake_quantize_learnable_per_tensor_affine(
                X, self.scale, self.zero_point,
                self.quant_min, self.quant_max, grad_factor)
    return X
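# A minimal standalone sketch (not part of the module above) showing how the
# learnable per-tensor op behaves in isolation: scale and zero_point are
# ordinary trainable tensors that receive gradients through the op. The
# tensor shapes and values here are illustrative assumptions.
import torch

x = torch.randn(4, 8, requires_grad=True)
scale = torch.tensor([0.1], requires_grad=True)
zero_point = torch.tensor([0.0], requires_grad=True)
quant_min, quant_max = 0, 255

# Mirrors the use_grad_scaling branch above: 1 / sqrt(numel * quant_max).
grad_factor = 1.0 / (x.numel() * quant_max) ** 0.5

y = torch._fake_quantize_learnable_per_tensor_affine(
    x, scale, zero_point, quant_min, quant_max, grad_factor)
y.sum().backward()
print(scale.grad, zero_point.grad)  # both populated by the learnable op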
def _test_learnable_backward_per_tensor(self, X, device, scale_base, zero_point_base):
    r"""Tests the backward method with additional backprop support for
    scale and zero point.
    """
    X_base = torch.tensor(X).to(device)
    for n_bits in (4, 8):
        quant_min, quant_max = 0, 2 ** n_bits - 1
        X = X_base.clone().float().to(device)
        X.requires_grad_()
        scale_base = scale_base.to(device)
        zero_point_base = zero_point_base.to(device)
        scale = scale_base.clone()
        scale.requires_grad_()
        zero_point = zero_point_base.clone().clamp(quant_min, quant_max)
        zero_point.requires_grad_()
        for grad_factor in [0.1, 1.0, 10.0]:
            Y_prime = torch._fake_quantize_learnable_per_tensor_affine(
                X, scale, zero_point, quant_min, quant_max, grad_factor).to(device)
            dout = torch.rand_like(X, dtype=torch.float).to(device)
            dX, dScale, dZeroPoint = _fake_quantize_learnable_per_tensor_affine_grad_reference(
                dout, X, scale, zero_point, quant_min, quant_max, device)
            Y_prime.backward(dout)

            expected_dX = dX.to(device).detach()
            actual_dX = X.grad.to(device).detach()
            expected_dScale = dScale.to(device).detach()
            actual_dScale = scale.grad.to(device).detach()
            expected_dZeroPoint = dZeroPoint.to(device).detach()
            actual_dZeroPoint = zero_point.grad.to(device).detach()

            self.assertTrue(
                torch.allclose(expected_dX, actual_dX,
                               rtol=tolerance, atol=tolerance),
                "Expected dX to match X.grad")
            # The kernel multiplies the scale/zero_point gradients by
            # grad_factor, so the reference gradients are scaled to match.
            self.assertTrue(
                torch.allclose(expected_dScale * grad_factor, actual_dScale,
                               rtol=tolerance, atol=tolerance),
                "Expected dScale to match scale.grad")
            self.assertTrue(
                torch.allclose(expected_dZeroPoint * grad_factor, actual_dZeroPoint,
                               rtol=tolerance, atol=tolerance),
                "Expected dZeroPoint to match zero_point.grad")

            # Reset accumulated gradients before the next grad_factor pass.
            X.grad.data.zero_()
            scale.grad.data.zero_()
            zero_point.grad.data.zero_()
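# A hedged sketch (an assumption, not the test module's actual helper) of the
# analytic gradients that _fake_quantize_learnable_per_tensor_affine_grad_reference
# is expected to produce, following the LSQ/LSQ+ formulations
# (arXiv:1902.08153, arXiv:1912.09356). Rounding/clamping of zero_point is
# omitted for brevity; an already-integer zero_point tensor is assumed.
import torch

def learnable_fq_grad_sketch(dY, X, scale, zero_point, quant_min, quant_max):
    v = X / scale + zero_point          # pre-rounding quantized value
    Xq = torch.round(v)
    below = Xq < quant_min
    above = Xq > quant_max
    middle = ~(below | above)

    # dX: straight-through estimator; gradient passes only where unclamped.
    dX = dY * middle.float()

    # dScale: rounding error (Xq - v) inside the range, saturation value
    # (quant_min/max - zero_point) outside it.
    dscale_elem = torch.where(
        below, quant_min - zero_point,
        torch.where(above, quant_max - zero_point, Xq - v))

    # dZeroPoint: -scale where the input saturates, zero elsewhere.
    dzp_elem = torch.where(middle, torch.zeros_like(X), -scale)

    dScale = (dscale_elem * dY).sum().unsqueeze(0)
    dZeroPoint = (dzp_elem * dY).sum().unsqueeze(0)
    return dX, dScale, dZeroPoint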
def _test_learnable_forward_per_tensor(self, X, device, scale_base, zero_point_base):
    X_base = torch.tensor(X).to(device)
    for n_bits in (4, 8):
        quant_min, quant_max = 0, 2 ** n_bits - 1
        X = X_base.clone().float()
        scale_base = scale_base.to(device).float()
        zero_point_base = zero_point_base.to(dtype=torch.int32, device=device)
        scale = scale_base.clone()
        zero_point = zero_point_base.clamp(quant_min, quant_max)
        Y = _fake_quantize_per_tensor_affine_reference(
            X, scale, zero_point, quant_min, quant_max).to(device)
        # The forward result must be independent of grad_factor, which only
        # rescales gradients in the backward pass.
        for grad_factor in [0.1, 1.0, 10.0]:
            Y_prime = torch._fake_quantize_learnable_per_tensor_affine(
                X, scale, zero_point, quant_min, quant_max, grad_factor).to(device)
            self.assertTrue(
                torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
                "Expected kernel forward function to have results match "
                "the reference forward function")
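# A hedged sketch (assumed, not copied from the test module) of what the
# _fake_quantize_per_tensor_affine_reference helper used above computes:
# the standard per-tensor affine fake quantization, i.e. quantize to the
# integer grid, clamp to [quant_min, quant_max], then dequantize.
import torch

def fake_quantize_per_tensor_affine_reference_sketch(X, scale, zero_point,
                                                     quant_min, quant_max):
    # Map to the integer grid, clamp to the representable range, map back.
    Xq = torch.round(X / scale + zero_point).clamp(quant_min, quant_max)
    return (Xq - zero_point) * scale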
def fakeQuantizePerTensorLearnableKernel(input, scale, zero_point,
                                         quant_min: int, quant_max: int):
    # grad_factor is omitted here and falls back to its default of 1.0.
    return torch._fake_quantize_learnable_per_tensor_affine(
        input, scale, zero_point, quant_min, quant_max)
def forward(self):
    return torch._fake_quantize_learnable_per_tensor_affine(
        self.input, self.scale, self.zero_point, self.quant_min, self.quant_max)