Ejemplo n.º 1
0
    def forward(self, X):
        """Refresh the quantization parameters, then optionally fake-quantize ``X``.

        Parameter refresh:
            * If ``static_enabled[0] == 1``, a detached copy of ``X`` is passed
              to ``activation_post_process`` and the values returned by its
              ``calculate_qparams()`` are copied in place into ``self.scale``
              and ``self.zero_point``.
            * Otherwise ``self.scale`` is clamped in place to be at least
              ``self.eps``.

        Fake quantization (only if ``fake_quant_enabled[0] == 1``):
            * For symmetric qschemes the zero point is reset to zero in place.
            * ``grad_factor`` is ``1 / sqrt(X.numel() * quant_max)`` when
              ``use_grad_scaling`` is set, else ``1.0``.
            * Per-channel qschemes dispatch to the learnable per-channel
              kernel (using ``ch_axis``); every other qscheme dispatches to the
              learnable per-tensor kernel.

        Returns:
            The fake-quantized tensor, or ``X`` itself when fake quantization
            is disabled.
        """
        if self.static_enabled[0] == 1:  # type: ignore[index]
            observer = self.activation_post_process
            # The observer only collects statistics; keep it out of autograd.
            observer(X.detach())
            observed_scale, observed_zero_point = observer.calculate_qparams()
            self.scale.data.copy_(observed_scale.to(self.scale.device))
            self.zero_point.data.copy_(
                observed_zero_point.to(self.zero_point.device))
        else:
            min_scale = self.eps.item()  # type: ignore[operator]
            self.scale.data.clamp_(min=min_scale)

        fake_quant_on = self.fake_quant_enabled[0] == 1
        if not fake_quant_on:
            return X

        symmetric_schemes = (torch.per_channel_symmetric,
                             torch.per_tensor_symmetric)
        if self.qscheme in symmetric_schemes:
            self.zero_point.data.zero_()

        grad_factor = (1.0 / (X.numel() * self.quant_max)**0.5
                       if self.use_grad_scaling else 1.0)

        per_channel_schemes = (torch.per_channel_symmetric,
                               torch.per_channel_affine)
        if self.qscheme in per_channel_schemes:
            return torch._fake_quantize_learnable_per_channel_affine(
                X, self.scale, self.zero_point, self.ch_axis,
                self.quant_min, self.quant_max, grad_factor)

        return torch._fake_quantize_learnable_per_tensor_affine(
            X, self.scale, self.zero_point, self.quant_min,
            self.quant_max, grad_factor)
Ejemplo n.º 2
0
    def _test_learnable_forward_per_channel(self, X_base, device, scale_base,
                                            zero_point_base, axis):
        r"""Tests the forward path of the learnable per-channel affine fake-quantize op.

        For 4- and 8-bit unsigned ranges, the output of
        ``torch._fake_quantize_learnable_per_channel_affine`` is compared with
        ``_fake_quantize_per_channel_affine_reference`` evaluated on a rounded,
        clamped zero point. The comparison is repeated for several
        ``grad_factor`` values and uses the same reference output each time.
        """
        # Device placement does not depend on the bit width, so do it once.
        scale_base = scale_base.to(device)
        zero_point_base = zero_point_base.to(device)

        for n_bits in (4, 8):
            quant_min = 0
            quant_max = 2**(n_bits) - 1

            X_curr = X_base.clone()
            scale_curr = scale_base.clone()
            zero_point_curr = zero_point_base.clone()

            # The reference receives an already rounded and clamped zero
            # point, while the kernel below is handed the raw one.
            rounded_zero_point = zero_point_curr.round().clamp(
                quant_min, quant_max)
            Y = _fake_quantize_per_channel_affine_reference(
                X_curr, scale_curr, rounded_zero_point, axis, quant_min,
                quant_max).to(device)

            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                    X_curr, scale_curr, zero_point_curr, axis, quant_min,
                    quant_max, grad_factor).to(device)
                outputs_match = torch.allclose(
                    Y, Y_prime, rtol=tolerance, atol=tolerance)
                self.assertTrue(
                    outputs_match,
                    "Expected kernel forward function to have results match the reference forward function"
                )
Ejemplo n.º 3
0
    def _test_learnable_backward_per_channel(self, X_base, device, scale_base, zero_point_base, axis):
        r"""Tests the backward path of the learnable per-channel affine fake-quantize op.

        For 4- and 8-bit unsigned ranges and several ``grad_factor`` values, a
        random upstream gradient is backpropagated through
        ``torch._fake_quantize_learnable_per_channel_affine``. The resulting
        ``.grad`` of the input, scale and zero point are compared against
        ``_fake_quantize_learnable_per_channel_affine_grad_reference``.
        The reference scale and zero-point gradients are multiplied by
        ``grad_factor`` before comparison; the input gradient is compared
        unscaled. Gradients are zeroed after every ``grad_factor`` iteration so
        they do not accumulate across iterations.
        """
        tolerance = 1e-4

        # Device placement does not depend on the bit width, so do it once.
        scale_base = scale_base.to(device)
        zero_point_base = zero_point_base.to(device=device)

        for n_bits in (4, 8):
            quant_min = 0
            quant_max = 2 ** n_bits - 1

            # Fresh leaf tensors per bit width; requires_grad_() returns the
            # tensor it was called on.
            X_curr = X_base.clone().requires_grad_()
            scale_curr = scale_base.clone().requires_grad_()
            zero_point_curr = zero_point_base.clone().requires_grad_()

            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                    X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max, grad_factor).to(device)

                dout = torch.rand(X_curr.shape, dtype=torch.float).to(device)
                dX, dScale, dZeroPoint = _fake_quantize_learnable_per_channel_affine_grad_reference(
                    dout, X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max, device)
                Y_prime.backward(dout)

                # (expected, actual, failure-message template) per gradient.
                # Only the scale and zero-point references are scaled by
                # grad_factor.
                gradient_checks = (
                    (
                        dX.to(device).detach(),
                        X_curr.to(device).grad.detach(),
                        "Expected dX={} to match X.grad={}, X={}, s={}, z={}, dout={}, n_bits={}",
                    ),
                    (
                        dScale.to(device).detach() * grad_factor,
                        scale_curr.to(device).grad.detach(),
                        "Expected dScale={} to match scale.grad={}, X={}, s={}, z={}, dout={}, n_bits={}",
                    ),
                    (
                        dZeroPoint.to(device).detach() * grad_factor,
                        zero_point_curr.to(device).grad.detach(),
                        "Expected dZeroPoint={} to match zero_point.grad={}, X={}, s={}, z={}, dout={}, n_bits={}",
                    ),
                )
                for expected, actual, template in gradient_checks:
                    self.assertTrue(
                        torch.allclose(expected, actual, rtol=tolerance, atol=tolerance),
                        template.format(
                            expected, actual, X_curr, scale_curr, zero_point_curr, dout, n_bits))

                # Reset accumulated gradients before the next grad_factor.
                for leaf in (X_curr, scale_curr, zero_point_curr):
                    leaf.grad.data.zero_()
Ejemplo n.º 4
0
def fakeQuantizePerChannelLearnableKernel(input, scale, zero_point, axis: int,
                                          quant_min: int, quant_max: int):
    """Run the learnable per-channel affine fake-quantize op on ``input``.

    Forwards all six arguments positionally to
    ``torch._fake_quantize_learnable_per_channel_affine``. No ``grad_factor``
    is passed, so the op's own default applies.

    Returns:
        Whatever the underlying op returns for these arguments.
    """
    learnable_per_channel_op = torch._fake_quantize_learnable_per_channel_affine
    op_args = (input, scale, zero_point, axis, quant_min, quant_max)
    return learnable_per_channel_op(*op_args)
 def forward(self):
     """Apply the learnable per-channel affine fake-quantize op to ``self.input``.

     Reads ``input``, ``scale``, ``zero_point``, ``axis``, ``quant_min`` and
     ``quant_max`` from the instance and passes them positionally to
     ``torch._fake_quantize_learnable_per_channel_affine``. No
     ``grad_factor`` is passed, so the op's own default applies.
     """
     op_args = (
         self.input,
         self.scale,
         self.zero_point,
         self.axis,
         self.quant_min,
         self.quant_max,
     )
     return torch._fake_quantize_learnable_per_channel_affine(*op_args)