def test_fq_module_per_channel(self, device, X):
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, axis, torch_type) = X
    quant_min = torch.iinfo(torch_type).min
    quant_max = torch.iinfo(torch_type).max

    X = to_tensor(X, device)
    X.requires_grad_()
    fq_module = FakeQuantize(default_per_channel_weight_observer, quant_min, quant_max, ch_axis=axis).to(device)
    Y_prime = fq_module(X)
    assert fq_module.scale is not None
    assert fq_module.zero_point is not None
    Y = _fake_quantize_per_channel_affine_reference(
        X, fq_module.scale, fq_module.zero_point, axis, quant_min, quant_max)
    np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(),
                               rtol=tolerance, atol=tolerance)

    # Test backward
    dout = torch.rand_like(X, dtype=torch.float, device=device)
    Y_prime.backward(dout)
    dX = _fake_quantize_per_channel_affine_grad_reference(
        dout, X, fq_module.scale, fq_module.zero_point, axis, quant_min, quant_max)
    np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(),
                               rtol=tolerance, atol=tolerance)
def _test_backward_per_channel_cachemask_impl(self, device):
    torch_types = (torch.qint8, torch.quint8)
    float_types = (torch.float32, torch.float16, torch.float64)
    for torch_type, float_type in itertools.product(torch_types, float_types):
        X = torch.randn(1, 2, 4, 4, dtype=float_type).to(device)
        # pick the scale + zp so that some values get clipped
        axis = 1
        obs = torch.quantization.PerChannelMinMaxObserver(axis, torch_type).to(device)
        obs(X * 0.75)
        scale, zero_point = obs.calculate_qparams()
        # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
        zero_point = zero_point.to(torch.int64)
        quant_min, quant_max = obs._calculate_qmin_qmax()
        X.requires_grad_()
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        dout = torch.rand_like(X, dtype=float_type).to(device)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, scale, zero_point, axis, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(),
                                   rtol=tolerance, atol=tolerance)
        assert X.grad.dtype == float_type
def _fake_quantize_learnable_per_channel_affine_grad_reference(
        dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max, device):
    r"""References the following papers for back propagation on scale and zero point.

    - https://arxiv.org/pdf/1902.08153.pdf
    - https://arxiv.org/pdf/1903.08066.pdf
    """
    per_channel_zero_point = ((per_channel_zero_point.detach() + 0.5).clamp(quant_min, quant_max)).type(torch.int32)
    grad_X = _fake_quantize_per_channel_affine_grad_reference(
        dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max).to(device)
    per_channel_scale = per_channel_scale.detach().type(torch.float)

    grad_scale = torch.zeros([per_channel_scale.size(0)]).to(device)
    grad_zero_point = torch.zeros([per_channel_zero_point.size(0)]).to(device)

    X_flattened = torch.unbind(X, dim=axis)
    dY_flattened = torch.unbind(dY, dim=axis)

    for i, X_i in enumerate(X_flattened):
        scale_i = per_channel_scale[i]
        zero_point_i = per_channel_zero_point[i]
        dY_i = dY_flattened[i]

        Xq_i = ((X_i / scale_i) + zero_point_i).round()

        # Masks for values quantized below, inside, and above the [quant_min, quant_max] range
        indicate_small_scale_i = (Xq_i < quant_min).float().to(device)
        indicate_big_scale_i = (Xq_i > quant_max).float().to(device)
        indicate_middle_scale_i = torch.ones(indicate_small_scale_i.shape).to(device) - \
            indicate_small_scale_i - indicate_big_scale_i

        # Masks for saturated / unsaturated values, used for the zero point gradient
        indicate_saturate_zp_i = ((Xq_i < quant_min).float() +
                                  (Xq_i > quant_max).float()).to(device)
        indicate_unsaturate_zp_i = torch.ones(indicate_saturate_zp_i.shape).to(device) - \
            indicate_saturate_zp_i

        Xq_i = Xq_i.clamp(quant_min, quant_max)
        Xfq_i = (Xq_i - zero_point_i) * scale_i

        # Per-region gradients w.r.t. scale and zero point
        grad_small_scale_i = quant_min - zero_point_i
        grad_big_scale_i = quant_max - zero_point_i
        grad_middle_scale_i = ((Xfq_i - X_i) / scale_i).to(device)

        grad_saturate_zp_i = -scale_i.to(device)
        grad_unsaturate_zp_i = 0

        grad_scale_i = indicate_small_scale_i * grad_small_scale_i + \
            indicate_middle_scale_i * grad_middle_scale_i + \
            indicate_big_scale_i * grad_big_scale_i
        grad_zp_i = indicate_saturate_zp_i * grad_saturate_zp_i + \
            indicate_unsaturate_zp_i * grad_unsaturate_zp_i

        grad_scale_i = (grad_scale_i * dY_i).sum().unsqueeze(dim=0)
        grad_zp_i = (grad_zp_i * dY_i).sum().unsqueeze(dim=0)

        grad_scale[i] = grad_scale_i
        grad_zero_point[i] = grad_zp_i
    return grad_X, grad_scale, grad_zero_point
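# A minimal, self-contained sketch (not part of the original test suite) showing how the
# learnable per-channel reference above can be exercised directly. The shapes, qparams,
# and the helper name _example_learnable_per_channel_grad_reference_usage are illustrative
# assumptions, not values or names used by the actual tests.
def _example_learnable_per_channel_grad_reference_usage(device='cpu'):
    axis = 0
    X = torch.randn(3, 4, device=device)               # 3 channels along dim 0
    dY = torch.ones_like(X)                            # stand-in upstream gradient
    per_channel_scale = torch.tensor([0.1, 0.2, 0.3], device=device)
    per_channel_zero_point = torch.tensor([0., 5., 10.], device=device)
    quant_min, quant_max = 0, 255                      # e.g. quint8 range
    grad_X, grad_scale, grad_zp = _fake_quantize_learnable_per_channel_affine_grad_reference(
        dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max, device)
    # grad_X follows the straight-through estimator (zero where Xq saturates);
    # grad_scale / grad_zp follow the per-region formulas implemented above.
    return grad_X, grad_scale, grad_zp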
def test_backward_per_channel(self, device, X):
    r"""Tests the backward method.
    """
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, axis, torch_type) = X
    quant_min = torch.iinfo(torch_type).min
    quant_max = torch.iinfo(torch_type).max

    X = to_tensor(X, device)
    scale = to_tensor(scale, device)
    zero_point = torch.tensor(zero_point).to(dtype=torch.int32, device=device)
    X.requires_grad_()
    Y_prime = torch.fake_quantize_per_channel_affine(
        X, scale, zero_point, axis, quant_min, quant_max)
    dout = torch.rand_like(X, dtype=torch.float).to(device)
    dX = _fake_quantize_per_channel_affine_grad_reference(
        dout, X, scale, zero_point, axis, quant_min, quant_max)
    Y_prime.backward(dout)
    np.testing.assert_allclose(dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(),
                               rtol=tolerance, atol=tolerance)