Example #1
    def test_forward_per_channel_half_precision_numerics(self):
        scale = torch.randn(5).abs()
        zero = torch.randn(5).to(dtype=torch.int)
        axis = 1
        mini = 0
        maxi = 255

        for _ in range(20):
            X1 = torch.randn(4, 5).to(torch.float16)
            Y1 = torch.fake_quantize_per_channel_affine(X1, scale, zero, axis, mini, maxi)
            Y1r = _fake_quantize_per_channel_affine_reference(X1, scale, zero, axis, mini, maxi)
            self.assertTrue(torch.allclose(Y1, Y1r, rtol=tolerance, atol=tolerance))

        # large fp16 input value, to force overflow
        X2 = torch.randn(4, 5).to(torch.float16)
        X2[0, 0] = 2**15 + .01
        Y2 = torch.fake_quantize_per_channel_affine(X2, scale, zero, axis, mini, maxi)
        Y2r = _fake_quantize_per_channel_affine_reference(X2, scale, zero, axis, mini, maxi)
        self.assertTrue(torch.allclose(Y2, Y2r, rtol=tolerance, atol=tolerance))

        # use a larger per-channel scale for the underflow case
        scale = torch.full((5,), 10.0)

        # 2**-24 is the smallest positive subnormal fp16 value, to force underflow
        X3 = torch.randn(4, 5).to(torch.float16)
        X3[0, 0] = 2**-24
        Y3 = torch.fake_quantize_per_channel_affine(X3, scale, zero, axis, mini, maxi)
        Y3r = _fake_quantize_per_channel_affine_reference(X3, scale, zero, axis, mini, maxi)
        self.assertTrue(torch.allclose(Y3, Y3r, rtol=tolerance, atol=tolerance))
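The helper _fake_quantize_per_channel_affine_reference is used throughout these examples but is not shown. As a rough guide, here is a minimal sketch of such a reference, assuming the standard quantize-clamp-dequantize formula with scale and zero_point broadcast along the channel axis; the real test helper may differ in details such as rounding mode and dtype handling.

import torch

def _fake_quantize_per_channel_affine_reference(X, scale, zero_point, axis,
                                                quant_min, quant_max):
    # Broadcast the per-channel qparams along `axis`, quantize with rounding,
    # clamp to [quant_min, quant_max], then dequantize back to floating point.
    # Intermediate math runs in float32 so fp16 inputs keep their precision.
    shape = [1] * X.dim()
    shape[axis] = -1
    s = scale.to(torch.float32).reshape(shape)
    zp = zero_point.to(torch.float32).reshape(shape)
    Xq = torch.clamp(torch.round(X.to(torch.float32) / s) + zp, quant_min, quant_max)
    return ((Xq - zp) * s).to(X.dtype)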
Example #2
    def test_fq_module_per_channel(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer,
                                 quant_min,
                                 quant_max,
                                 ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(
            X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(),
                                   Y_prime.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

        # Test backward
        dout = torch.rand_like(X, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)
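The backward check in this example relies on _fake_quantize_per_channel_affine_grad_reference, which is also not shown. It is presumably a straight-through estimator: the upstream gradient passes through wherever the quantized value lands inside [quant_min, quant_max] and is zeroed where the value would be clamped. A minimal sketch under that assumption:

import torch

def _fake_quantize_per_channel_affine_grad_reference(dY, X, scale, zero_point,
                                                     axis, quant_min, quant_max):
    # Straight-through estimator: gradient flows only where the rounded,
    # shifted value stays inside the representable range [quant_min, quant_max].
    shape = [1] * X.dim()
    shape[axis] = -1
    s = scale.to(torch.float32).reshape(shape)
    zp = zero_point.to(torch.float32).reshape(shape)
    Xq = torch.round(X.to(torch.float32) / s) + zp
    mask = (Xq >= quant_min) & (Xq <= quant_max)
    return dY * mask.to(dY.dtype)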
Example #3
    def _test_learnable_forward_per_channel(self, X_base, device, scale_base,
                                            zero_point_base, axis):
        r"""Tests the forward path of the learnable FakeQuantizePerTensorAffine op.
        """
        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2 ** n_bits - 1

            scale_base = scale_base.to(device)
            zero_point_base = zero_point_base.to(device)

            X_curr = X_base.clone()
            scale_curr = scale_base.clone()
            zero_point_curr = zero_point_base.clone()

            Y = _fake_quantize_per_channel_affine_reference(
                X_curr, scale_curr,
                zero_point_curr.round().clamp(quant_min, quant_max), axis,
                quant_min, quant_max).to(device)
            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                    X_curr, scale_curr, zero_point_curr, axis, quant_min,
                    quant_max, grad_factor).to(device)
                self.assertTrue(
                    torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
                    "Expected kernel forward function to have results match the reference forward function"
                )
Example #4
    def _test_forward_per_channel_cachemask_impl(self, device):
        torch_types = (torch.qint8, torch.quint8)
        float_types = (torch.float32, torch.float16, torch.float64)
        for torch_type, float_type in itertools.product(
                torch_types, float_types):
            X = torch.randn(1, 2, 4, 4, dtype=float_type).to(device)
            # pick the scale + zp so that some values get clipped
            axis = 1
            obs = torch.quantization.PerChannelMinMaxObserver(
                ch_axis=axis, dtype=torch_type).to(device)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
            zero_point = zero_point.to(torch.int64)
            quant_min, quant_max = obs._calculate_qmin_qmax()

            Y = _fake_quantize_per_channel_affine_reference(
                X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min,
                quant_max)
            Y_prime = torch.fake_quantize_per_channel_affine(
                X, scale, zero_point, axis, quant_min, quant_max)
            np.testing.assert_allclose(Y,
                                       Y_prime.cpu(),
                                       rtol=tolerance,
                                       atol=tolerance)
            self.assertEqual(Y.dtype, float_type)
Example #5
    def test_forward_per_channel(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int32, device=device)
        Y = _fake_quantize_per_channel_affine_reference(X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
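Outside a test class, torch.fake_quantize_per_channel_affine can be called directly. A minimal standalone usage sketch, with shapes and qparams chosen arbitrarily for illustration:

import torch

# Fake-quantize a 4x5 tensor along axis=1, so each of the 5 columns
# gets its own scale / zero_point pair.
x = torch.randn(4, 5)
scale = torch.full((5,), 0.1)
zero_point = torch.zeros(5, dtype=torch.int32)
y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, 0, 255)
print(y)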