def forward(self, X):
    if self.static_enabled[0] == 1:  # type: ignore[index]
        # Static estimation: run the observer on a detached copy of the
        # input and load the resulting qparams into the learnable
        # scale / zero_point parameters.
        self.activation_post_process(X.detach())
        _scale, _zero_point = self.activation_post_process.calculate_qparams()
        _scale = _scale.to(self.scale.device)
        _zero_point = _zero_point.to(self.zero_point.device)
        self.scale.data.copy_(_scale)
        self.zero_point.data.copy_(_zero_point)
    else:
        # Learnable path: keep the scale strictly positive.
        self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]

    if self.fake_quant_enabled[0] == 1:
        if self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
            # Symmetric schemes pin the zero point at 0.
            self.zero_point.data.zero_()

        if self.use_grad_scaling:
            grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
        else:
            grad_factor = 1.0
        if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine):
            X = torch._fake_quantize_learnable_per_channel_affine(
                X, self.scale, self.zero_point, self.ch_axis,
                self.quant_min, self.quant_max, grad_factor)
        else:
            X = torch._fake_quantize_learnable_per_tensor_affine(
                X, self.scale, self.zero_point,
                self.quant_min, self.quant_max, grad_factor)

    return X
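# A minimal standalone sketch of the learnable per-tensor path taken by
# forward() above (the tensor shapes and qparams here are illustrative
# assumptions, not values from the module):
#
# import torch
#
# X = torch.randn(4, 8)
# scale = torch.tensor([0.1], requires_grad=True)
# zero_point = torch.tensor([0.0], requires_grad=True)
# quant_min, quant_max = 0, 255
#
# # Same grad scaling rule forward() applies when use_grad_scaling is set.
# grad_factor = 1.0 / (X.numel() * quant_max) ** 0.5
#
# Y = torch._fake_quantize_learnable_per_tensor_affine(
#     X, scale, zero_point, quant_min, quant_max, grad_factor)
# Y.sum().backward()
# # scale.grad and zero_point.grad now hold the LSQ-style gradients,
# # scaled by grad_factor.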
def _test_learnable_forward_per_channel(self, X_base, device, scale_base,
                                        zero_point_base, axis):
    r"""Tests the forward path of the learnable FakeQuantizePerChannelAffine op."""
    for n_bits in (4, 8):
        quant_min, quant_max = 0, 2 ** n_bits - 1

        scale_base = scale_base.to(device)
        zero_point_base = zero_point_base.to(device)
        X_curr = X_base.clone()
        scale_curr = scale_base.clone()
        zero_point_curr = zero_point_base.clone()

        # Reference result from the non-learnable per-channel fake quantize;
        # the zero point is rounded and clamped to the quantized range to
        # match the kernel's behavior.
        Y = _fake_quantize_per_channel_affine_reference(
            X_curr, scale_curr,
            zero_point_curr.round().clamp(quant_min, quant_max),
            axis, quant_min, quant_max).to(device)
        for grad_factor in [0.1, 1.0, 10.0]:
            # grad_factor only affects gradients, so the forward output must
            # match the reference for every value.
            Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                X_curr, scale_curr, zero_point_curr, axis,
                quant_min, quant_max, grad_factor).to(device)
            self.assertTrue(
                torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
                "Expected kernel forward function to have results match the "
                "reference forward function")
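# The forward test above compares against _fake_quantize_per_channel_affine_reference.
# A minimal sketch of what such a reference could look like (the exact test
# utility may differ; the broadcasting scheme below is an assumption):
#
# import torch
#
# def _per_channel_fake_quant_reference_sketch(X, scale, zero_point, axis,
#                                              quant_min, quant_max):
#     # Reshape the per-channel qparams so they broadcast along `axis`.
#     shape = [1] * X.dim()
#     shape[axis] = -1
#     s = scale.reshape(shape).to(X.dtype)
#     zp = zero_point.reshape(shape).to(X.dtype)
#     # Quantize, clamp to the representable range, then dequantize.
#     Xq = torch.clamp(torch.round(X / s + zp), quant_min, quant_max)
#     return (Xq - zp) * s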
def _test_learnable_backward_per_channel(self, X_base, device, scale_base,
                                         zero_point_base, axis):
    r"""Tests the backward path of the learnable FakeQuantizePerChannelAffine op."""
    for n_bits in (4, 8):
        quant_min, quant_max = 0, 2 ** n_bits - 1

        scale_base = scale_base.to(device)
        zero_point_base = zero_point_base.to(device=device)
        X_curr = X_base.clone()
        X_curr.requires_grad_()
        scale_curr = scale_base.clone()
        scale_curr.requires_grad_()
        zero_point_curr = zero_point_base.clone()
        zero_point_curr.requires_grad_()

        for grad_factor in [0.1, 1.0, 10.0]:
            Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                X_curr, scale_curr, zero_point_curr, axis,
                quant_min, quant_max, grad_factor).to(device)

            dout = torch.rand(X_curr.shape, dtype=torch.float).to(device)
            dX, dScale, dZeroPoint = _fake_quantize_learnable_per_channel_affine_grad_reference(
                dout, X_curr, scale_curr, zero_point_curr, axis,
                quant_min, quant_max, device)
            Y_prime.backward(dout)

            dX_expected = dX.to(device).detach()
            dX_actual = X_curr.grad.to(device).detach()
            dScale_expected = dScale.to(device).detach()
            dScale_actual = scale_curr.grad.to(device).detach()
            dZeroPoint_expected = dZeroPoint.to(device).detach()
            dZeroPoint_actual = zero_point_curr.grad.to(device).detach()
            tolerance = 1e-4

            # dX is independent of grad_factor; the scale and zero point
            # gradients are scaled by grad_factor inside the kernel.
            self.assertTrue(
                torch.allclose(dX_expected, dX_actual, rtol=tolerance, atol=tolerance),
                "Expected dX={} to match X.grad={}, X={}, s={}, z={}, dout={}, n_bits={}".format(
                    dX_expected, dX_actual, X_curr, scale_curr,
                    zero_point_curr, dout, n_bits))
            self.assertTrue(
                torch.allclose(dScale_expected * grad_factor, dScale_actual,
                               rtol=tolerance, atol=tolerance),
                "Expected dScale={} to match scale.grad={}, X={}, s={}, z={}, dout={}, n_bits={}".format(
                    dScale_expected * grad_factor, dScale_actual,
                    X_curr, scale_curr, zero_point_curr, dout, n_bits))
            self.assertTrue(
                torch.allclose(dZeroPoint_expected * grad_factor, dZeroPoint_actual,
                               rtol=tolerance, atol=tolerance),
                "Expected dZeroPoint={} to match zero_point.grad={}, X={}, s={}, z={}, dout={}, n_bits={}".format(
                    dZeroPoint_expected * grad_factor, dZeroPoint_actual,
                    X_curr, scale_curr, zero_point_curr, dout, n_bits))
            # Reset accumulated gradients before the next grad_factor run.
            X_curr.grad.data.zero_()
            scale_curr.grad.data.zero_()
            zero_point_curr.grad.data.zero_()
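# A quick standalone check (not part of the test class; shapes and qparams
# are illustrative assumptions) of the property the assertions above rely
# on: the kernel multiplies only the scale and zero point gradients by
# grad_factor, so they scale linearly with it while dX is unaffected.
#
# import torch
#
# X = torch.randn(2, 3, 4)
# scale = torch.ones(3, requires_grad=True)
# zero_point = torch.zeros(3, requires_grad=True)
#
# grads = []
# for grad_factor in (1.0, 10.0):
#     Y = torch._fake_quantize_learnable_per_channel_affine(
#         X, scale, zero_point, 1, 0, 255, grad_factor)
#     Y.sum().backward()
#     grads.append(scale.grad.clone())
#     scale.grad.zero_()
#     zero_point.grad.zero_()
#
# # Scale gradients should differ exactly by the ratio of the grad factors.
# assert torch.allclose(grads[1], grads[0] * 10.0)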
def fakeQuantizePerChannelLearnableKernel(input, scale, zero_point,
                                          axis: int, quant_min: int, quant_max: int):
    # Thin wrapper around the learnable per-channel kernel; grad_factor is
    # omitted and takes its default of 1.0.
    return torch._fake_quantize_learnable_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max)
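# Hypothetical invocation of the wrapper above (shapes and qparams are
# illustrative, not taken from any benchmark config):
#
# import torch
#
# X = torch.randn(16, 8)
# scale = torch.rand(8) + 0.1   # strictly positive per-channel scales
# zero_point = torch.zeros(8)   # float tensor: the op differentiates through it
# out = fakeQuantizePerChannelLearnableKernel(
#     X, scale, zero_point, axis=1, quant_min=0, quant_max=255)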
def forward(self):
    return torch._fake_quantize_learnable_per_channel_affine(
        self.input, self.scale, self.zero_point, self.axis,
        self.quant_min, self.quant_max)
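# A minimal sketch of a module that could own the forward() above (the
# attribute initialization here is an assumption; the real harness that
# defines self.input / self.scale / etc. may differ):
#
# import torch
# import torch.nn as nn
#
# class LearnablePerChannelFakeQuantSketch(nn.Module):
#     def __init__(self, num_channels=8):
#         super().__init__()
#         self.input = torch.randn(4, num_channels)
#         self.scale = nn.Parameter(torch.ones(num_channels))
#         self.zero_point = nn.Parameter(torch.zeros(num_channels))
#         self.axis = 1
#         self.quant_min, self.quant_max = 0, 255
#
#     def forward(self):
#         return torch._fake_quantize_learnable_per_channel_affine(
#             self.input, self.scale, self.zero_point, self.axis,
#             self.quant_min, self.quant_max)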