class TestFakeQuantize(TestCase):
    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(
               2,
               5,
           ),
                                   qparams=hu.qparams(dtypes=torch.qint8)))
    def test_fq_module_per_channel(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer,
                                 quant_min,
                                 quant_max,
                                 ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(
            X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(),
                                   Y_prime.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

        # Test backward
        dout = torch.rand_like(X, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    def test_fq_serializable_per_channel(self):
        observer = default_per_channel_weight_observer
        quant_min = -128
        quant_max = 127
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor(
            [[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]],
            dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
        self.assertEqual(state_dict['zero_point'], [0, 0])
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
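The reference helper _fake_quantize_per_channel_affine_reference (and its gradient counterpart) used above is defined elsewhere in the PyTorch test suite. As a hedged, illustrative sketch of the per-channel fake-quantize math it implements (the function name and broadcasting details below are ours, not the actual helper):

import torch

def fake_quantize_per_channel_ref(X, scale, zero_point, axis, quant_min, quant_max):
    # Reshape scale/zero_point so they broadcast along the channel axis.
    shape = [1] * X.dim()
    shape[axis] = -1
    scale = scale.to(torch.float).reshape(shape)
    zero_point = zero_point.to(torch.float).reshape(shape)
    # quantize -> clamp -> dequantize, all in floating point
    q = torch.round(X / scale) + zero_point
    q = torch.clamp(q, quant_min, quant_max)
    return (q - zero_point) * scale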
Example #2
class TestQuantizedTensor(TestCase):
    def test_qtensor(self):
        num_elements = 10
        scale = 1.0
        zero_point = 2
        for device in get_supported_device_types():
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                r = torch.ones(num_elements, dtype=torch.float, device=device)
                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
                self.assertEqual(qr.q_scale(), scale)
                self.assertEqual(qr.q_zero_point(), zero_point)
                self.assertTrue(qr.is_quantized)
                self.assertFalse(r.is_quantized)
                self.assertEqual(qr.qscheme(), torch.per_tensor_affine)
                self.assertTrue(isinstance(qr.qscheme(), torch.qscheme))
                # slicing and int_repr
                int_repr = qr.int_repr()
                for num in int_repr:
                    self.assertEqual(num, 3)
                for num in qr[2:].int_repr():
                    self.assertEqual(num, 3)
                # dequantize
                rqr = qr.dequantize()
                for i in range(num_elements):
                    self.assertEqual(r[i], rqr[i])
                # we can also print a qtensor
                empty_r = torch.ones((0, 1), dtype=torch.float, device=device)
                empty_qr = torch.quantize_per_tensor(empty_r, scale,
                                                     zero_point, dtype)

                device_msg = "" if device == 'cpu' else "device='" + device + ":0', "
                dtype_msg = str(dtype) + ", "
                self.assertEqual(
                    ' '.join(str(empty_qr).split()), "tensor([], " +
                    device_msg + "size=(0, 1), dtype=" + dtype_msg +
                    "quantization_scheme=torch.per_tensor_affine, " +
                    "scale=1.0, zero_point=2)")

    def test_qtensor_float_assignment(self):
        # Scalar Tensor
        # item
        scale = 1.0
        zero_point = 2
        r = torch.ones(1, dtype=torch.float)
        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
            self.assertEqual(qr.item(), 1)
            self.assertEqual(qr[0].item(), 1)
            # assignment
            self.assertTrue(qr[0].is_quantized)
            qr[0] = 11.3  # float assignment
            self.assertEqual(qr.item(), 11)
            x = torch.ones(1, dtype=torch.float) * 15.3
            # Copying from a float Tensor
            qr[:] = x
            self.assertEqual(qr.item(), 15)

            dtype_msg = str(dtype) + ", "
            self.assertEqual(
                ' '.join(str(qr).split()), "tensor([15.], size=(1,), dtype=" +
                dtype_msg + "quantization_scheme=torch.per_tensor_affine, " +
                "scale=1.0, zero_point=2)")

    def test_qtensor_quant_dequant(self):
        scale = 0.02
        zero_point = 2
        for device in get_supported_device_types():
            r = torch.rand(3, 2, 4, 5, dtype=torch.float,
                           device=device) * 4 - 2
            for memory_format in [
                    torch.contiguous_format, torch.channels_last
            ]:
                r = r.contiguous(memory_format=memory_format)
                for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                    qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
                    rqr = qr.dequantize()
                    self.assertTrue(
                        np.allclose(r.cpu().numpy(),
                                    rqr.cpu().numpy(),
                                    atol=2 / scale))
        # Also check 5D tensors work.
        for device in get_supported_device_types():
            r = torch.rand(3, 2, 4, 5, 6, dtype=torch.float,
                           device=device) * 4 - 2
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
                rqr = qr.dequantize()
                self.assertTrue(
                    np.allclose(r.cpu().numpy(),
                                rqr.cpu().numpy(),
                                atol=2 / scale))

    # legacy constructor/new doesn't support qtensors
    def test_qtensor_legacy_new_failure(self):
        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
        scale = 0.02
        zero_point = 2
        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
        self.assertRaises(RuntimeError, lambda: qr.new(device='cpu'))
        self.assertRaises(RuntimeError, lambda: qr.new(r.storage()))
        self.assertRaises(RuntimeError, lambda: qr.new(r))
        self.assertRaises(RuntimeError, lambda: qr.new(torch.Size([2, 3])))
        self.assertRaises(RuntimeError, lambda: qr.new([6]))

    def test_per_channel_qtensor_creation(self):
        numel = 10
        ch_axis = 0
        scales = torch.rand(numel)
        zero_points_int = torch.randint(0, 10, size=(numel, ))
        zero_points_float = torch.randn(numel)
        for dtype, zero_points in itertools.product(
            [torch.qint8, torch.quint8], [zero_points_float, zero_points_int]):
            q = torch._empty_per_channel_affine_quantized(
                [numel],
                scales=scales,
                zero_points=zero_points,
                axis=ch_axis,
                dtype=dtype)
            # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
            self.assertEqualIgnoreType(scales, q.q_per_channel_scales())
            self.assertEqual(zero_points, q.q_per_channel_zero_points())
            self.assertEqual(ch_axis, q.q_per_channel_axis())

        # create Tensor from uint8_t Tensor, scales and zero_points
        for zero_points in [zero_points_float, zero_points_int]:
            int_tensor = torch.randint(0,
                                       100,
                                       size=(numel, ),
                                       dtype=torch.uint8)
            q = torch._make_per_channel_quantized_tensor(
                int_tensor, scales, zero_points, ch_axis)
            self.assertEqual(int_tensor, q.int_repr())
            # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
            self.assertEqualIgnoreType(scales, q.q_per_channel_scales())
            self.assertEqual(zero_points, q.q_per_channel_zero_points())
            self.assertEqual(ch_axis, q.q_per_channel_axis())

    def test_qtensor_creation(self):
        scale = 0.5
        zero_point = 10
        numel = 10
        for device in get_supported_device_types():
            q = torch._empty_affine_quantized([numel],
                                              scale=scale,
                                              zero_point=zero_point,
                                              device=device,
                                              dtype=torch.quint8)
            self.assertEqual(scale, q.q_scale())
            self.assertEqual(zero_point, q.q_zero_point())

            # create Tensor from uint8_t Tensor, scale and zero_point
            int_tensor = torch.randint(0,
                                       100,
                                       size=(10, ),
                                       device=device,
                                       dtype=torch.uint8)
            q = torch._make_per_tensor_quantized_tensor(
                int_tensor, scale, zero_point)
            self.assertEqual(int_tensor, q.int_repr())
            self.assertEqual(scale, q.q_scale())
            self.assertEqual(zero_point, q.q_zero_point())

            # create via empty_like
            q = torch._empty_affine_quantized([numel],
                                              scale=scale,
                                              zero_point=zero_point,
                                              device=device,
                                              dtype=torch.quint8)
            q_el = torch.empty_like(q)
            self.assertEqual(q.q_scale(), q_el.q_scale())
            self.assertEqual(q.q_zero_point(), q_el.q_zero_point())
            self.assertEqual(q.dtype, q_el.dtype)

            # create via empty_like but change the dtype (currently not supported)
            with self.assertRaises(RuntimeError):
                torch.empty_like(q, dtype=torch.qint8)

    def test_qtensor_dtypes(self):
        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
        scale = 0.2
        zero_point = 2
        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.qint8)
        rqr = qr.dequantize()
        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
        rqr = qr.dequantize()
        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.qint32)
        rqr = qr.dequantize()
        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))

    def _test_quantize_per_channel(self, r, scales, zero_points, axis,
                                   float_params):
        def _quantize_per_channel_ref_nd(data, scales, zero_points,
                                         float_params):
            dims = data.size()
            data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
            res = torch.empty_like(data)
            quant_min, quant_max = 0, 255
            for i in range(res.size()[0]):
                for j in range(res.size()[1]):
                    for k in range(res.size()[2]):
                        if float_params:
                            inv_scale = 1.0 / scales[j]
                            res[i][j][k] = np.clip(
                                np.round(data[i][j][k] * inv_scale +
                                         zero_points[j]), quant_min, quant_max)
                        else:
                            res[i][j][k] = np.clip(
                                np.round(data[i][j][k] / scales[j]) +
                                zero_points[j], quant_min, quant_max)
            res = res.view(*dims)
            return res

        contig_format = torch.channels_last if r.ndim == 4 else torch.channels_last_3d
        for memory_format in [torch.contiguous_format, contig_format]:
            ref_res = _quantize_per_channel_ref_nd(r, scales, zero_points,
                                                   float_params)
            r_contig = r.contiguous(memory_format=memory_format)
            qr = torch.quantize_per_channel(r_contig, scales, zero_points,
                                            axis, torch.quint8)
            rqr = qr.dequantize()
            self.assertTrue(np.allclose(qr.int_repr(), ref_res))
            self.assertTrue(
                np.allclose(r.numpy(),
                            rqr.numpy(),
                            atol=2 / np.min(scales.numpy())))

    def test_qtensor_quantize_per_channel(self):
        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
        scales = torch.tensor([0.2, 0.03], dtype=torch.double)
        zero_points = torch.tensor([5, 10], dtype=torch.long)
        axis = 1

        def quantize_c(data, scales, zero_points):
            res = torch.empty((3, 2))
            quant_min, quant_max = 0, 255
            for i in range(3):
                for j in range(2):
                    res[i][j] = np.clip(
                        np.round(data[i][j] / scales[j]) + zero_points[j],
                        quant_min, quant_max)
            return res

        qr = torch.quantize_per_channel(r, scales, zero_points, axis,
                                        torch.quint8)
        rqr = qr.dequantize()
        self.assertTrue(
            np.allclose(qr.int_repr(), quantize_c(r, scales, zero_points)))
        self.assertTrue(
            np.allclose(r.numpy(),
                        rqr.numpy(),
                        atol=2 / np.min(scales.numpy())))

        # Check 4D tensor with 2 different memory formats.
        r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4 - 2
        scales = torch.tensor([0.2, 0.03], dtype=torch.double)
        zero_points = torch.tensor([5, 10], dtype=torch.long)
        self._test_quantize_per_channel(r, scales, zero_points, 1, False)

        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.double)
        zero_points = torch.tensor([5, 10, 7], dtype=torch.long)
        self._test_quantize_per_channel(r, scales, zero_points, 0, False)

        # Check 5D tensor.
        r = torch.rand(3, 2, 4, 5, 7, dtype=torch.float) * 4 - 2
        scales = torch.tensor([0.2, 0.03], dtype=torch.double)
        zero_points = torch.tensor([5, 10], dtype=torch.long)
        self._test_quantize_per_channel(r, scales, zero_points, 1, False)

        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.double)
        zero_points = torch.tensor([5, 10, 7], dtype=torch.long)
        self._test_quantize_per_channel(r, scales, zero_points, 0, False)

    def test_quantize_per_channel_float_qparams(self):
        r = torch.rand(3, 2, dtype=torch.float) * 4
        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
        axis = 1

        # Reference quantize function with FP zero_point.
        def quantize_ref(data, scales, zero_points):
            res = torch.empty((3, 2))
            quant_min, quant_max = 0, 255
            for i in range(3):
                for j in range(2):
                    inv_scale = 1.0 / scales[j]
                    res[i][j] = np.clip(
                        np.round(data[i][j] * inv_scale + zero_points[j]),
                        quant_min, quant_max)
            return res

        qr = torch.quantize_per_channel(r, scales, zero_points, axis,
                                        torch.quint8)
        dequant_tensor = qr.dequantize()
        ref = quantize_ref(r, scales, zero_points)
        self.assertTrue(np.allclose(qr.int_repr(), ref))
        self.assertTrue(np.allclose(r.numpy(), dequant_tensor.numpy(), atol=1))

        # Check 4D tensor with 2 different memory formats.
        r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
        self._test_quantize_per_channel(r, scales, zero_points, 1, True)

        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
        self._test_quantize_per_channel(r, scales, zero_points, 0, True)

        # Check 5D tensor.
        r = torch.rand(3, 2, 4, 5, 7, dtype=torch.float) * 4 - 2
        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
        self._test_quantize_per_channel(r, scales, zero_points, 1, True)

        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
        self._test_quantize_per_channel(r, scales, zero_points, 0, True)

    def test_qtensor_permute(self):
        scale = 0.02
        zero_point = 1
        for device in get_supported_device_types():
            r = torch.rand(10, 30, 2, 2, device=device,
                           dtype=torch.float) * 4 - 2
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                qr = torch.quantize_per_tensor(r,
                                               scale,
                                               zero_point,
                                               dtype=dtype)
                qr = qr.transpose(0, 1)
                rqr = qr.dequantize()
                # compare transpose + dequantized result with original transposed result
                self.assertTrue(
                    np.allclose(r.cpu().numpy().transpose([1, 0, 2, 3]),
                                rqr.cpu().numpy(),
                                atol=2 / scale))

                qr = torch.quantize_per_tensor(r,
                                               scale,
                                               zero_point,
                                               dtype=dtype)
                qr1 = qr.permute([1, 0, 2, 3])
                qr2 = qr.transpose(0, 1)
                # compare int representation after transformations
                self.assertEqual(qr1.int_repr(), qr2.int_repr())
                self.assertEqual(qr1.q_scale(), qr2.q_scale())
                self.assertEqual(qr1.q_zero_point(), qr2.q_zero_point())
                # compare dequantized result
                self.assertEqual(qr1.dequantize(), qr2.dequantize())
                # compare permuted + dequantized result with original transposed result
                self.assertTrue(
                    np.allclose(qr2.dequantize().cpu().numpy(),
                                r.cpu().numpy().transpose([1, 0, 2, 3]),
                                atol=2 / scale))
                # make permuted result contiguous
                self.assertEqual(qr2.contiguous().int_repr(), qr2.int_repr())

                # change memory format
                qlast = qr.contiguous(memory_format=torch.channels_last)
                self.assertEqual(qr.stride(),
                                 list(reversed(sorted(qr.stride()))))
                self.assertNotEqual(qlast.stride(),
                                    list(reversed(sorted(qlast.stride()))))
                self.assertEqual(qr.int_repr(), qlast.int_repr())
                self.assertEqual(qr.q_scale(), qlast.q_scale())
                self.assertEqual(qr.q_zero_point(), qlast.q_zero_point())
                self.assertEqual(qlast.dequantize(), qr.dequantize())

                # permuting larger tensors
                x = torch.randn(64, 64, device=device)
                qx = torch.quantize_per_tensor(x, 1.0, 0, dtype)
                # should work
                qx.permute([1, 0])

    def test_qtensor_per_channel_permute(self):
        r = torch.rand(20, 10, 2, 2, dtype=torch.float) * 4 - 2
        dtype = torch.qint8
        scales = torch.rand(10) * 0.02 + 0.01
        zero_points = torch.round(torch.rand(10) * 2 - 1).to(torch.long)
        qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)

        # we can't reorder the axis
        with self.assertRaises(RuntimeError):
            qr.transpose(0, 1)

        # but we can change memory format
        qlast = qr.contiguous(memory_format=torch.channels_last)
        self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride()))))
        self.assertNotEqual(qlast.stride(),
                            list(reversed(sorted(qlast.stride()))))
        self.assertEqual(qr.int_repr(), qlast.int_repr())
        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
        self.assertEqualIgnoreType(scales, qlast.q_per_channel_scales())
        self.assertEqual(zero_points, qlast.q_per_channel_zero_points())
        self.assertEqual(1, qlast.q_per_channel_axis())
        self.assertEqual(qlast.dequantize(), qr.dequantize())

    def test_qtensor_load_save(self):
        scale = 0.2
        zero_point = 10
        # storage is not accessible on CUDA right now
        device = "cpu"
        r = torch.rand(15, 2, dtype=torch.float32, device=device) * 2
        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
            qrv = qr[:, 1]
            with tempfile.NamedTemporaryFile() as f:
                # Serializing and Deserializing Tensor
                torch.save((qr, qrv), f)
                f.seek(0)
                qr2, qrv2 = torch.load(f)
                self.assertEqual(qr, qr2)
                self.assertEqual(qrv, qrv2)
                self.assertEqual(qr2.storage().data_ptr(),
                                 qrv2.storage().data_ptr())

    def test_qtensor_per_channel_load_save(self):
        r = torch.rand(20, 10, dtype=torch.float) * 4 - 2
        scales = torch.rand(10, dtype=torch.double) * 0.02 + 0.01
        zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long)
        # qint32 and CUDA are not supported yet
        for dtype in [torch.quint8, torch.qint8]:
            qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
            with tempfile.NamedTemporaryFile() as f:
                # Serializing and Deserializing Tensor
                torch.save(qr, f)
                f.seek(0)
                qr2 = torch.load(f)
                self.assertEqual(qr, qr2)

    def test_qtensor_copy(self):
        scale = 0.5
        zero_point = 10
        numel = 10
        for device in get_supported_device_types():
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                # copy from same scale and zero_point
                q = torch._empty_affine_quantized([numel],
                                                  scale=scale,
                                                  zero_point=zero_point,
                                                  device=device,
                                                  dtype=dtype)
                q2 = torch._empty_affine_quantized([numel],
                                                   scale=scale,
                                                   zero_point=zero_point,
                                                   device=device,
                                                   dtype=dtype)
                q.copy_(q2)
                self.assertEqual(q.int_repr(), q2.int_repr())
                self.assertEqual(q.q_scale(), q2.q_scale())
                self.assertEqual(q.q_zero_point(), q2.q_zero_point())
                # copying from different scale and zero_point
                scale = 3.2
                zero_point = 5
                q = torch._empty_affine_quantized([numel],
                                                  scale=scale,
                                                  zero_point=zero_point,
                                                  device=device,
                                                  dtype=dtype)
                # check original scale and zero_points are set correctly
                self.assertEqual(q.q_scale(), scale)
                self.assertEqual(q.q_zero_point(), zero_point)
                q.copy_(q2)
                # check scale and zero_points have been copied
                self.assertEqual(q, q2)
                # can't copy from quantized tensor to non-quantized tensor
                r = torch.empty([numel], dtype=torch.float)
                q = torch._empty_affine_quantized([numel],
                                                  scale=scale,
                                                  zero_point=zero_point,
                                                  dtype=torch.quint8)
                with self.assertRaisesRegex(RuntimeError,
                                            "please use dequantize"):
                    r.copy_(q)

    def test_torch_qtensor_deepcopy(self):
        # cuda is not supported yet
        device = "cpu"
        q_int = torch.randint(0, 100, [3, 5], device=device, dtype=torch.uint8)
        scale, zero_point = 2.0, 3
        q = torch._make_per_tensor_quantized_tensor(q_int,
                                                    scale=scale,
                                                    zero_point=zero_point)
        qc = deepcopy(q)
        self.assertEqual(qc, q)

    def test_qtensor_clone(self):
        numel = 10
        scale = 0.5
        zero_point = 10
        for device in get_supported_device_types():
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                q2 = torch._empty_affine_quantized([numel],
                                                   scale=scale,
                                                   zero_point=zero_point,
                                                   device=device,
                                                   dtype=dtype)
                q = q2.clone()
                # Check to make sure the scale and zero_point have been copied.
                self.assertEqual(q, q2)

    def test_qtensor_fill(self):
        numel = 10
        scale = 0.5
        zero_point = 10

        ones = torch.ones(numel).to(torch.float)

        types = [torch.qint8, torch.quint8, torch.qint32]
        fills = [-1, 1, 2**32]  # negative, positive, overflow

        # `fill_` uses `copy_(float)`, which doesn't support CUDA
        device = 'cpu'
        ones = ones.to(device)
        for qtype, fill_with in itertools.product(types, fills):
            q_filled = torch._empty_affine_quantized([numel],
                                                     scale=scale,
                                                     zero_point=zero_point,
                                                     device=device,
                                                     dtype=qtype)
            q_filled.fill_(fill_with)
            int_repr = torch.quantize_per_tensor(ones * fill_with, scale,
                                                 zero_point, qtype)
            fill_with = int_repr.dequantize()
            int_repr = int_repr.int_repr()

            self.assertEqual(q_filled.int_repr(), int_repr)
            self.assertEqual(q_filled.dequantize(), fill_with)
            # Make sure the scale and zero_point don't change
            self.assertEqual(q_filled.q_scale(), scale)
            self.assertEqual(q_filled.q_zero_point(), zero_point)

    def test_qtensor_view(self):
        scale, zero_point, dtype = 1.0, 2, torch.uint8
        for device in get_supported_device_types():
            q_int = torch.randint(0,
                                  100, [1, 2, 3],
                                  device=device,
                                  dtype=dtype)
            q = torch._make_per_tensor_quantized_tensor(q_int,
                                                        scale=scale,
                                                        zero_point=zero_point)
            q2 = q.view(1, 3, 2)
            self.assertEqual(q.numel(), q2.numel())
            # testing -1
            self.assertEqual(q, q2.view(1, -1, 3))

            a_int = torch.randint(0,
                                  100, [1, 2, 3, 4],
                                  device=device,
                                  dtype=dtype)
            a = torch._make_per_tensor_quantized_tensor(a_int,
                                                        scale=scale,
                                                        zero_point=zero_point)
            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
            c = a.view(1, 3, 2, 4)  # does not change tensor layout in memory
            self.assertEqual(b.size(), c.size())
            self.assertEqual(b.q_scale(), c.q_scale())
            self.assertEqual(b.q_zero_point(), c.q_zero_point())
            self.assertNotEqual(b.stride(), c.stride())
            # size is the same but the underlying data is different
            self.assertNotEqual(b.int_repr(), c.int_repr())
            # torch.equal is not supported for the cuda backend
            if device == 'cpu':
                self.assertFalse(torch.equal(b, c))
            else:
                self.assertRaises(RuntimeError, lambda: torch.equal(b, c))

            # a case where we can't view a non-contiguous Tensor
            a_int = torch.randint(0,
                                  100, [1, 2, 3, 4],
                                  device=device,
                                  dtype=dtype)
            a = torch._make_per_tensor_quantized_tensor(a_int,
                                                        scale=scale,
                                                        zero_point=zero_point)
            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
            err_str = "view size is not compatible with input tensor's size and stride*"
            with self.assertRaisesRegex(RuntimeError, err_str):
                b.view(1, 4, 2, 3)
            # view on contiguous tensor is fine
            b.contiguous().view(1, 4, 2, 3)

    def test_qtensor_resize(self):
        scale, zero_point, dtype = 1.0, 2, torch.uint8
        sizes1 = [1, 2, 3, 4]
        sizes2 = [1 * 2, 3 * 4]
        sizes3 = [1, 2 * 3, 4]
        sizes4 = [1 * 2 * 3 * 4]
        sizes5 = [1, 2, 1, 3, 1, 4]

        q1_int = torch.randint(0, 100, sizes1, dtype=dtype)
        q1 = torch._make_per_tensor_quantized_tensor(q1_int,
                                                     scale=scale,
                                                     zero_point=zero_point)
        q2 = q1.resize(*sizes2)
        q3 = q2.resize(*sizes3)
        q4 = q3.resize(*sizes4)
        q5 = q4.resize(*sizes5)

        self.assertEqual(q1.numel(), q2.numel())
        self.assertEqual(q1.numel(), q3.numel())
        self.assertEqual(q1.numel(), q4.numel())
        self.assertEqual(q1.numel(), q5.numel())

        # Compare original and post-transpose
        a_int = torch.randint(0, 100, sizes1, dtype=dtype)
        a = torch._make_per_tensor_quantized_tensor(a_int,
                                                    scale=scale,
                                                    zero_point=zero_point)
        b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
        c = b.resize(*sizes1)  # Change the sizes back to the original

        self.assertEqual(a.size(), c.size())
        self.assertEqual(b.q_scale(), c.q_scale())
        self.assertEqual(b.q_zero_point(), c.q_zero_point())
        self.assertNotEqual(b.stride(), c.stride())
        # size is the same but the underlying data is different
        self.assertNotEqual(b.int_repr(), c.int_repr())
        self.assertFalse(torch.equal(b, c))

        # Throws an error if numel is wrong
        q1_int = torch.randint(0, 100, sizes1, dtype=dtype)
        q1 = torch._make_per_tensor_quantized_tensor(a_int,
                                                     scale=scale,
                                                     zero_point=zero_point)
        err_str = "requested resize to*"
        with self.assertRaisesRegex(RuntimeError, err_str):
            q2 = q1.resize(*sizes1[:-1])
        # resize on both contiguous and non-contiguous tensor should be fine
        q3 = q1.resize(*sizes2)
        q4 = q1.contiguous().resize(*sizes2)

    def test_qtensor_reshape(self):
        scale, zero_point, dtype = 1.0, 2, torch.uint8
        for device in get_supported_device_types():
            q_int = torch.randint(0, 100, [3, 5], dtype=dtype, device=device)
            q = torch._make_per_tensor_quantized_tensor(q_int,
                                                        scale=scale,
                                                        zero_point=zero_point)
            q2 = q.reshape([15])
            self.assertEqual(q.numel(), q2.numel())
            self.assertEqual(q2.size(), [15])
            # testing -1
            self.assertEqual(q, q2.reshape([3, -1]))

            a_int = torch.randint(0,
                                  100, [1, 2, 3, 4],
                                  dtype=dtype,
                                  device=device)
            a = torch._make_per_tensor_quantized_tensor(a_int,
                                                        scale=scale,
                                                        zero_point=zero_point)
            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
            c = a.reshape(1, 3, 2, 4)  # does not change tensor layout
            self.assertEqual(b.size(), c.size())
            self.assertEqual(b.q_scale(), c.q_scale())
            self.assertEqual(b.q_zero_point(), c.q_zero_point())
            self.assertNotEqual(b.stride(), c.stride())
            self.assertNotEqual(b.int_repr(), c.int_repr())
            # torch.equal is not supported for the cuda backend
            if device == 'cpu':
                self.assertFalse(torch.equal(b, c))
            else:
                self.assertRaises(RuntimeError, lambda: torch.equal(b, c))

            # we can use reshape for non-contiguous Tensor
            a_int = torch.randint(0,
                                  100, [1, 2, 3, 4],
                                  dtype=dtype,
                                  device=device)
            a = torch._make_per_tensor_quantized_tensor(a_int,
                                                        scale=scale,
                                                        zero_point=zero_point)
            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
            c = b.reshape(1, 4, 2, 3)

    def test_qtensor_unsqueeze(self):
        x = torch.randn((1, 3, 4))
        qx = torch.quantize_per_tensor(x,
                                       scale=1.0,
                                       zero_point=0,
                                       dtype=torch.quint8)
        qy = qx.unsqueeze(2)
        self.assertEqual(qy.size(), (1, 3, 1, 4))
        qy = qy.squeeze(2)
        self.assertEqual(qy.size(), qx.size())

        # Per channel qtensor
        scales = torch.tensor([1.0])
        zero_points = torch.tensor([0])
        qx = torch.quantize_per_channel(x,
                                        scales=scales,
                                        zero_points=zero_points,
                                        dtype=torch.quint8,
                                        axis=0)
        qy = qx.unsqueeze(0)
        self.assertEqual(qy.size(), (1, 1, 3, 4))
        self.assertEqual(qy.q_per_channel_axis(), 1)

        qz = qy.squeeze(0)
        self.assertEqual(qz.size(), x.size())
        self.assertEqual(qz.q_per_channel_axis(), 0)
        with self.assertRaisesRegex(
                RuntimeError,
                "Squeeze is only possible on non-axis dimension for Per-Channel"
        ):
            qz = qy.squeeze(1)

        # squeeze without dim specified
        x = torch.randn((3, 1, 2, 1, 4))
        scales = torch.tensor([1.0, 1.0])
        zero_points = torch.tensor([0, 0])
        qx = torch.quantize_per_channel(x,
                                        scales=scales,
                                        zero_points=zero_points,
                                        dtype=torch.quint8,
                                        axis=2)
        qz = qx.squeeze()
        self.assertEqual(qz.size(), (3, 2, 4))
        self.assertEqual(qz.q_per_channel_axis(), 1)
        with self.assertRaisesRegex(
                RuntimeError,
                "Squeeze is only possible on non-axis dimension for Per-Channel"
        ):
            qz = qy.squeeze()

    def test_repeat(self):
        scale, zero_point, dtype = 1.0, 2, torch.uint8
        for device in get_supported_device_types():
            q_int = torch.randint(0, 100, [3], dtype=dtype, device=device)
            q_int_repeat = q_int.repeat(4, 2)
            q_ref = torch._make_per_tensor_quantized_tensor(
                q_int_repeat, scale=scale, zero_point=zero_point)

            q = torch._make_per_tensor_quantized_tensor(q_int,
                                                        scale=scale,
                                                        zero_point=zero_point)
            q_repeat = q.repeat(4, 2)
            self.assertEqual(q_ref, q_repeat)

    def test_qscheme_pickle(self):
        f = Foo()
        buf = io.BytesIO()
        torch.save(f, buf)

        buf.seek(0)
        f2 = torch.load(buf)

        self.assertEqual(f2.qscheme, torch.per_tensor_symmetric)

    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2,
                                              max_dims=4,
                                              min_side=1,
                                              max_side=10),
                       qparams=hu.qparams()),
           reduce_range=st.booleans())
    def test_choose_qparams(self, X, reduce_range):
        X, (scale, zero_point, torch_type) = X
        X = torch.from_numpy(X)
        X_scale, X_zp = _calculate_dynamic_qparams(X,
                                                   torch.quint8,
                                                   reduce_range=reduce_range)
        qparams = torch._choose_qparams_per_tensor(X, reduce_range)
        np.testing.assert_array_almost_equal(X_scale, qparams[0], decimal=3)
        self.assertEqual(X_zp, qparams[1])

    @unittest.skipIf(not torch.cuda.is_available() or TEST_WITH_ROCM,
                     'CUDA is not available')
    def test_cuda_cpu_implementation_consistency(self):
        numel, zero_point, scale = 100, 2, 0.02
        r = torch.rand(numel, dtype=torch.float32, device='cpu') * 25 - 4
        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
            qr_cpu = torch.quantize_per_tensor(r,
                                               scale,
                                               zero_point,
                                               dtype=dtype)
            qr_cuda = torch.quantize_per_tensor(r.cuda(),
                                                scale,
                                                zero_point,
                                                dtype=dtype)
            # int repr must be the same
            np.testing.assert_equal(qr_cpu.int_repr().numpy(),
                                    qr_cuda.int_repr().cpu().numpy())
            # dequantized values must be the same
            r_cpu, r_cuda = qr_cpu.dequantize().numpy(), qr_cuda.dequantize(
            ).cpu().numpy()
            np.testing.assert_almost_equal(r_cuda, r_cpu, decimal=5)

    @unittest.skipIf(not torch.cuda.is_available() or TEST_WITH_ROCM,
                     'CUDA is not available')
    def test_cuda_quantization_does_not_pin_memory(self):
        # Context - https://github.com/pytorch/pytorch/issues/41115
        x = torch.randn(3)
        self.assertEqual(x.is_pinned(), False)

        q_int = torch.randint(0,
                              100, [1, 2, 3],
                              device="cuda",
                              dtype=torch.uint8)
        q = torch._make_per_tensor_quantized_tensor(q_int,
                                                    scale=0.1,
                                                    zero_point=0)

        x = torch.randn(3)
        self.assertEqual(x.is_pinned(), False)

    def test_fp16_saturate_op(self):
        x = torch.ones(5, 5, dtype=torch.float32) * 65532
        x[0] = torch.ones(5) * -65532
        # range of fp16 values is [-65504, +65504]
        ref = torch.ones(5, 5) * 65504
        ref[0] = torch.ones(5) * -65504
        y = torch._saturate_weight_to_fp16(x)
        self.assertEqual(y, ref)
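The per-tensor affine scheme exercised throughout this class boils down to q = clamp(round(x / scale) + zero_point, qmin, qmax) on quantize and x' = (q - zero_point) * scale on dequantize. A minimal standalone sketch of the round trip, separate from the test class and with values chosen so nothing is clamped:

import torch

scale, zero_point = 0.02, 2
r = torch.rand(3, 2) * 4  # stays inside the representable quint8 range
qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
print(qr.int_repr())  # holds round(r / scale) + zero_point as uint8 values
# dequantize() returns (int_repr - zero_point) * scale, so the
# per-element round-trip error here is at most scale / 2
assert torch.allclose(r, qr.dequantize(), atol=scale / 2 + 1e-6)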
Example #3
class TestFakeQuantizePerTensor(TestCase):
    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_tensor(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale,
                                                       zero_point, quant_min,
                                                       quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y,
                                   Y_prime.cpu(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_backward_per_tensor(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale,
                                                       zero_point, quant_min,
                                                       quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        dout = torch.rand(X.shape, dtype=torch.float).to(device)
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, X, scale, zero_point, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    # https://github.com/pytorch/pytorch/issues/30604
    @unittest.skip("temporarily disable the test")
    def test_numerical_consistency_per_tensor(self, device, X):
        r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        # quantize_per_tensor and dequantize are only implemented on the CPU
        Y = torch.dequantize(
            torch.quantize_per_tensor(X.cpu(), scale, zero_point, torch_type))
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y,
                                   Y_prime.cpu(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(
        device=st.sampled_from(
            ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
        X=hu.tensor(shapes=hu.array_shapes(
            1,
            5,
        ),
                    qparams=hu.qparams(dtypes=[torch.quint8])),
    )
    def test_fq_module(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = torch.quantization.default_fake_quant().to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_tensor_affine_reference(X, fq_module.scale,
                                                       fq_module.zero_point,
                                                       quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(),
                                   Y_prime.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, X, fq_module.scale, fq_module.zero_point, quant_min,
            quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    def test_fq_serializable(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], 0.094488)
        self.assertEqual(state_dict['zero_point'], 53)
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
        loaded_fq_module.load_state_dict(loaded_dict)
        for key in state_dict:
            self.assertEqual(state_dict[key],
                             loaded_fq_module.state_dict()[key])

        self.assertEqual(loaded_fq_module.calculate_qparams(),
                         fq_module.calculate_qparams())

    def test_fake_quant_control(self):
        torch.manual_seed(42)
        X = torch.rand(20, 10, dtype=torch.float32)
        fq_module = torch.quantization.default_fake_quant()
        # Output of fake quant is not identical to input
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        torch.quantization.disable_fake_quant(fq_module)
        X = torch.rand(20, 10, dtype=torch.float32)
        Y = fq_module(X)
        # Fake quant is disabled, output is identical to input
        self.assertEqual(Y, X)
        scale = fq_module.scale
        zero_point = fq_module.zero_point
        torch.quantization.disable_observer(fq_module)
        torch.quantization.enable_fake_quant(fq_module)
        X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        # Observer is disabled, scale and zero-point do not change
        self.assertEqual(fq_module.scale, scale)
        self.assertEqual(fq_module.zero_point, zero_point)
        torch.quantization.enable_observer(fq_module)
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        # Observer is enabled, scale and zero-point are different
        self.assertNotEqual(fq_module.scale, scale)
        self.assertNotEqual(fq_module.zero_point, zero_point)
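The per-tensor reference helpers (_fake_quantize_per_tensor_affine_reference and its grad counterpart) are likewise defined outside this excerpt. A hedged sketch, assuming the usual straight-through-estimator backward that only passes gradient where the value was not clamped:

import torch

def fake_quantize_per_tensor_ref(X, scale, zero_point, quant_min, quant_max):
    q = torch.round(X / scale) + zero_point
    return (torch.clamp(q, quant_min, quant_max) - zero_point) * scale

def fake_quantize_per_tensor_grad_ref(dout, X, scale, zero_point, quant_min, quant_max):
    # Straight-through estimator: gradient flows only where the
    # quantized value stayed inside [quant_min, quant_max].
    q = torch.round(X / scale) + zero_point
    mask = (q >= quant_min) & (q <= quant_max)
    return dout * mask.to(dout.dtype)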
Example #4
class TestFakeQuantizePerChannel(TestCase):
    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_channel(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64,
                                                 device=device)
        Y = _fake_quantize_per_channel_affine_reference(
            X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y,
                                   Y_prime.cpu(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_backward_per_channel(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64,
                                                 device=device)
        X.requires_grad_()
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        dout = torch.rand(X.shape, dtype=torch.float).to(device)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, scale, zero_point, axis, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu().detach().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_numerical_consistency_per_channel(self, device, X):
        r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64,
                                                 device=device)
        # quantize_per_channel and dequantize are only implemented on the CPU
        Y = torch.dequantize(
            torch.quantize_per_channel(X.cpu(), scale.cpu(), zero_point.cpu(),
                                       axis, torch_type))
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y,
                                   Y_prime.cpu(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(
               2,
               5,
           ),
                                   qparams=hu.qparams(dtypes=torch.qint8)))
    def test_fq_module(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer,
                                 quant_min,
                                 quant_max,
                                 ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(
            X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(),
                                   Y_prime.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    def test_fq_serializable(self):
        observer = default_per_channel_weight_observer
        quant_min = -128
        quant_max = 127
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor(
            [[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]],
            dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
        self.assertEqual(state_dict['zero_point'], [0, 0])
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
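Outside the test harness, a per-channel FakeQuantize module like the one built in test_fq_module above can be driven directly. A hedged usage sketch; the torch.quantization import path is assumed from the torch.quantization.default_fake_quant() calls elsewhere in these examples:

import torch
from torch.quantization import FakeQuantize, default_per_channel_weight_observer

fq = FakeQuantize(default_per_channel_weight_observer,
                  quant_min=-128, quant_max=127, ch_axis=0)
w = torch.randn(4, 3)
y = fq(w)  # the observer records per-channel min/max, then w is fake-quantized
scale, zero_point = fq.calculate_qparams()
print(scale.shape, zero_point.shape)  # one qparam pair per channel along ch_axis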
Example #5
class TestQuantizedTensor(TestCase):
    def test_qtensor(self):
        num_elements = 10
        scale = 1.0
        zero_point = 2
        for device in get_supported_device_types():
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                r = torch.ones(num_elements, dtype=torch.float, device=device)
                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
                self.assertEqual(qr.q_scale(), scale)
                self.assertEqual(qr.q_zero_point(), zero_point)
                self.assertTrue(qr.is_quantized)
                self.assertFalse(r.is_quantized)
                self.assertEqual(qr.qscheme(), torch.per_tensor_affine)
                self.assertTrue(isinstance(qr.qscheme(), torch.qscheme))
                # slicing and int_repr
                int_repr = qr.int_repr()
                for num in int_repr:
                    self.assertEqual(num, 3)
                for num in qr[2:].int_repr():
                    self.assertEqual(num, 3)
                # dequantize
                rqr = qr.dequantize()
                for i in range(num_elements):
                    self.assertEqual(r[i], rqr[i])
                # we can also print a qtensor (here, an empty one)
                empty_r = torch.ones((0, 1), dtype=torch.float, device=device)
                empty_qr = torch.quantize_per_tensor(empty_r, scale, zero_point, dtype)

                device_msg = "" if device == 'cpu' else "device='" + device + ":0', "
                dtype_msg = str(dtype) + ", "
                self.assertEqual(' '.join(str(empty_qr).split()),
                                 "tensor([], " + device_msg + "size=(0, 1), dtype=" + dtype_msg +
                                 "quantization_scheme=torch.per_tensor_affine, " +
                                 "scale=1.0, zero_point=2)")

    def test_qtensor_float_assignment(self):
        # Scalar Tensor
        # item
        scale = 1.0
        zero_point = 2
        r = torch.ones(1, dtype=torch.float)
        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
            self.assertEqual(qr.item(), 1)
            self.assertEqual(qr[0].item(), 1)
            # assignment
            self.assertTrue(qr[0].is_quantized)
            qr[0] = 11.3  # float assignment
            self.assertEqual(qr.item(), 11)
            x = torch.ones(1, dtype=torch.float) * 15.3
            # Copying from a float Tensor
            qr[:] = x
            self.assertEqual(qr.item(), 15)

            dtype_msg = str(dtype) + ", "
            self.assertEqual(' '.join(str(qr).split()),
                             "tensor([15.], size=(1,), dtype=" + dtype_msg +
                             "quantization_scheme=torch.per_tensor_affine, " +
                             "scale=1.0, zero_point=2)")

    def test_qtensor_quant_dequant(self):
        scale = 0.02
        zero_point = 2
        for device in get_supported_device_types():
            r = torch.rand(3, 2, dtype=torch.float, device=device) * 4 - 2
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
                rqr = qr.dequantize()
                self.assertTrue(np.allclose(r.cpu().numpy(), rqr.cpu().numpy(), atol=2 / scale))

    # legacy constructor/new doesn't support qtensors
    def test_qtensor_legacy_new_failure(self):
        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
        scale = 0.02
        zero_point = 2
        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
        self.assertRaises(RuntimeError, lambda: qr.new(device='cpu'))
        self.assertRaises(RuntimeError, lambda: qr.new(r.storage()))
        self.assertRaises(RuntimeError, lambda: qr.new(r))
        self.assertRaises(RuntimeError, lambda: qr.new(torch.Size([2, 3])))
        self.assertRaises(RuntimeError, lambda: qr.new([6]))

    def test_per_channel_qtensor_creation(self):
        numel = 10
        ch_axis = 0
        scales = torch.rand(numel)
        zero_points = torch.randint(0, 10, size=(numel,))
        for dtype in [torch.qint8, torch.quint8]:
            q = torch._empty_per_channel_affine_quantized(
                [numel], scales=scales, zero_points=zero_points, axis=ch_axis, dtype=dtype)
            self.assertEqual(scales, q.q_per_channel_scales())
            self.assertEqual(zero_points, q.q_per_channel_zero_points())
            self.assertEqual(ch_axis, q.q_per_channel_axis())

        # create Tensor from uint8_t Tensor, scales and zero_points
        int_tensor = torch.randint(0, 100, size=(numel,), dtype=torch.uint8)
        q = torch._make_per_channel_quantized_tensor(int_tensor, scales, zero_points, ch_axis)
        self.assertEqual(int_tensor, q.int_repr())
        self.assertEqual(scales, q.q_per_channel_scales())
        self.assertEqual(zero_points, q.q_per_channel_zero_points())
        self.assertEqual(ch_axis, q.q_per_channel_axis())

    def test_qtensor_creation(self):
        scale = 0.5
        zero_point = 10
        numel = 10
        for device in get_supported_device_types():
            q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point,
                                              device=device, dtype=torch.quint8)
            self.assertEqual(scale, q.q_scale())
            self.assertEqual(zero_point, q.q_zero_point())

            # create Tensor from uint8_t Tensor, scale and zero_point
            int_tensor = torch.randint(0, 100, size=(10,), device=device, dtype=torch.uint8)
            q = torch._make_per_tensor_quantized_tensor(int_tensor, scale, zero_point)
            self.assertEqual(int_tensor, q.int_repr())
            self.assertEqual(scale, q.q_scale())
            self.assertEqual(zero_point, q.q_zero_point())

            # create via empty_like
            q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point,
                                              device=device, dtype=torch.quint8)
            q_el = torch.empty_like(q)
            self.assertEqual(q.q_scale(), q_el.q_scale())
            self.assertEqual(q.q_zero_point(), q_el.q_zero_point())
            self.assertEqual(q.dtype, q_el.dtype)

            # create via empty_like but change the dtype (currently not supported)
            with self.assertRaises(RuntimeError):
                torch.empty_like(q, dtype=torch.qint8)

    def test_qtensor_dtypes(self):
        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
        scale = 0.2
        zero_point = 2
        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.qint8)
        rqr = qr.dequantize()
        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
        rqr = qr.dequantize()
        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.qint32)
        rqr = qr.dequantize()
        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))

    def test_qtensor_quantize_per_channel(self):
        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
        scales = torch.tensor([0.2, 0.03], dtype=torch.double)
        zero_points = torch.tensor([5, 10], dtype=torch.long)
        axis = 1

        def quantize_c(data, scales, zero_points):
            res = torch.empty((3, 2))
            quant_min, quant_max = 0, 255
            for i in range(3):
                for j in range(2):
                    res[i][j] = np.clip(np.round(data[i][j] / scales[j]) + zero_points[j], quant_min, quant_max)
            return res
        qr = torch.quantize_per_channel(r, scales, zero_points, axis, torch.quint8)
        rqr = qr.dequantize()
        self.assertTrue(np.allclose(qr.int_repr(), quantize_c(r, scales, zero_points)))
        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy())))

    def test_qtensor_permute(self):
        scale = 0.02
        zero_point = 1
        for device in get_supported_device_types():
            r = torch.rand(10, 30, 2, 2, device=device, dtype=torch.float) * 4 - 2
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
                qr = qr.transpose(0, 1)
                rqr = qr.dequantize()
                # compare transpose + dequantized result with original transposed result
                self.assertTrue(np.allclose(r.cpu().numpy().transpose([1, 0, 2, 3]), rqr.cpu().numpy(), atol=2 / scale))

                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
                qr1 = qr.permute([1, 0, 2, 3])
                qr2 = qr.transpose(0, 1)
                # compare int representation after transformations
                self.assertEqual(qr1.int_repr(), qr2.int_repr())
                self.assertEqual(qr1.q_scale(), qr2.q_scale())
                self.assertEqual(qr1.q_zero_point(), qr2.q_zero_point())
                # compare dequantized result
                self.assertEqual(qr1.dequantize(), qr2.dequantize())
                # compare permuted + dequantized result with original transposed result
                self.assertTrue(np.allclose(qr2.dequantize().cpu().numpy(),
                                            r.cpu().numpy().transpose([1, 0, 2, 3]), atol=2 / scale))
                # make permuted result contiguous
                self.assertEqual(qr2.contiguous().int_repr(), qr2.int_repr())

                # change memory format
                qlast = qr.contiguous(memory_format=torch.channels_last)
                self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride()))))
                self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride()))))
                self.assertEqual(qr.int_repr(), qlast.int_repr())
                self.assertEqual(qr.q_scale(), qlast.q_scale())
                self.assertEqual(qr.q_zero_point(), qlast.q_zero_point())
                self.assertEqual(qlast.dequantize(), qr.dequantize())

                # permuting larger tensors
                x = torch.randn(64, 64, device=device)
                qx = torch.quantize_per_tensor(x, 1.0, 0, dtype)
                # should work
                qx.permute([1, 0])

    def test_qtensor_per_channel_permute(self):
        r = torch.rand(20, 10, 2, 2, dtype=torch.float) * 4 - 2
        dtype = torch.qint8
        scales = torch.rand(10) * 0.02 + 0.01
        zero_points = torch.round(torch.rand(10) * 2 - 1).to(torch.long)
        qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)

        # we can't reorder the axis
        with self.assertRaises(RuntimeError):
            qr.transpose(0, 1)

        # but we can change memory format
        qlast = qr.contiguous(memory_format=torch.channels_last)
        self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride()))))
        self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride()))))
        self.assertEqual(qr.int_repr(), qlast.int_repr())
        self.assertEqual(scales, qlast.q_per_channel_scales())
        self.assertEqual(zero_points, qlast.q_per_channel_zero_points())
        self.assertEqual(1, qlast.q_per_channel_axis())
        self.assertEqual(qlast.dequantize(), qr.dequantize())

    def test_qtensor_load_save(self):
        scale = 0.2
        zero_point = 10
        # storage is not accessible on CUDA right now
        device = "cpu"
        r = torch.rand(15, 2, dtype=torch.float32, device=device) * 2
        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
            qrv = qr[:, 1]
            with tempfile.NamedTemporaryFile() as f:
                # Serializing and Deserializing Tensor
                torch.save((qr, qrv), f)
                f.seek(0)
                qr2, qrv2 = torch.load(f)
                self.assertEqual(qr, qr2)
                self.assertEqual(qrv, qrv2)
                self.assertEqual(qr2.storage().data_ptr(), qrv2.storage().data_ptr())

    def test_qtensor_per_channel_load_save(self):
        r = torch.rand(20, 10, dtype=torch.float) * 4 - 2
        scales = torch.rand(10, dtype=torch.double) * 0.02 + 0.01
        zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long)
        # qint32 and CUDA are not supported yet
        for dtype in [torch.quint8, torch.qint8]:
            qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
            with tempfile.NamedTemporaryFile() as f:
                # Serializing and Deserializing Tensor
                torch.save(qr, f)
                f.seek(0)
                qr2 = torch.load(f)
                self.assertEqual(qr, qr2)

    def test_qtensor_copy(self):
        scale = 0.5
        zero_point = 10
        numel = 10
        for device in get_supported_device_types():
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                # copy from same scale and zero_point
                q = torch._empty_affine_quantized([numel], scale=scale,
                                                  zero_point=zero_point, device=device, dtype=dtype)
                q2 = torch._empty_affine_quantized([numel], scale=scale,
                                                   zero_point=zero_point, device=device, dtype=dtype)
                q.copy_(q2)
                self.assertEqual(q.int_repr(), q2.int_repr())
                self.assertEqual(q.q_scale(), q2.q_scale())
                self.assertEqual(q.q_zero_point(), q2.q_zero_point())
                # copying from different scale and zero_point
                scale = 3.2
                zero_point = 5
                q = torch._empty_affine_quantized([numel], scale=scale,
                                                  zero_point=zero_point, device=device, dtype=dtype)
                # check original scale and zero_points are set correctly
                self.assertEqual(q.q_scale(), scale)
                self.assertEqual(q.q_zero_point(), zero_point)
                q.copy_(q2)
                # check that scale and zero_point have been copied
                self.assertEqual(q, q2)
                # can't copy from quantized tensor to non-quantized tensor
                r = torch.empty([numel], dtype=torch.float)
                q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
                with self.assertRaisesRegex(RuntimeError, "please use dequantize"):
                    r.copy_(q)

    def test_torch_qtensor_deepcopy(self):
        # cuda is not supported yet
        device = "cpu"
        q_int = torch.randint(0, 100, [3, 5], device=device, dtype=torch.uint8)
        scale, zero_point = 2.0, 3
        q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
        qc = deepcopy(q)
        self.assertEqual(qc, q)

    def test_qtensor_clone(self):
        numel = 10
        scale = 0.5
        zero_point = 10
        for device in get_supported_device_types():
            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
                q2 = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point,
                                                   device=device, dtype=dtype)
                q = q2.clone()
                # Check to make sure the scale and zero_point have been copied.
                self.assertEqual(q, q2)

    def test_qtensor_view(self):
        scale, zero_point, dtype = 1.0, 2, torch.uint8
        for device in get_supported_device_types():
            q_int = torch.randint(0, 100, [1, 2, 3], device=device, dtype=dtype)
            q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
            q2 = q.view(1, 3, 2)
            self.assertEqual(q.numel(), q2.numel())
            # testing -1
            self.assertEqual(q, q2.view(1, -1, 3))

            a_int = torch.randint(0, 100, [1, 2, 3, 4], device=device, dtype=dtype)
            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
            c = a.view(1, 3, 2, 4)  # does not change tensor layout in memory
            self.assertEqual(b.size(), c.size())
            self.assertEqual(b.q_scale(), c.q_scale())
            self.assertEqual(b.q_zero_point(), c.q_zero_point())
            self.assertNotEqual(b.stride(), c.stride())
            # size is the same but the underlying data is different
            self.assertNotEqual(b.int_repr(), c.int_repr())
            # torch.equal is not supported for the cuda backend
            if device == 'cpu':
                self.assertFalse(torch.equal(b, c))
            else:
                self.assertRaises(RuntimeError, lambda: torch.equal(b, c))

            # a case where we can't view a non-contiguous Tensor
            a_int = torch.randint(0, 100, [1, 2, 3, 4], device=device, dtype=dtype)
            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
            err_str = "view size is not compatible with input tensor's size and stride*"
            with self.assertRaisesRegex(RuntimeError, err_str):
                b.view(1, 4, 2, 3)
            # view on contiguous tensor is fine
            b.contiguous().view(1, 4, 2, 3)

    def test_qtensor_resize(self):
        scale, zero_point, dtype = 1.0, 2, torch.uint8
        sizes1 = [1, 2, 3, 4]
        sizes2 = [1 * 2, 3 * 4]
        sizes3 = [1, 2 * 3, 4]
        sizes4 = [1 * 2 * 3 * 4]
        sizes5 = [1, 2, 1, 3, 1, 4]

        q1_int = torch.randint(0, 100, sizes1, dtype=dtype)
        q1 = torch._make_per_tensor_quantized_tensor(q1_int, scale=scale, zero_point=zero_point)
        q2 = q1.resize(*sizes2)
        q3 = q2.resize(*sizes3)
        q4 = q3.resize(*sizes4)
        q5 = q4.resize(*sizes5)

        self.assertEqual(q1.numel(), q2.numel())
        self.assertEqual(q1.numel(), q3.numel())
        self.assertEqual(q1.numel(), q4.numel())
        self.assertEqual(q1.numel(), q5.numel())

        # Compare original and post-transpose
        a_int = torch.randint(0, 100, sizes1, dtype=dtype)
        a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
        b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
        c = b.resize(*sizes1)  # Change the sizes back to the original

        self.assertEqual(a.size(), c.size())
        self.assertEqual(b.q_scale(), c.q_scale())
        self.assertEqual(b.q_zero_point(), c.q_zero_point())
        self.assertNotEqual(b.stride(), c.stride())
        # size is the same but the underlying data is different
        self.assertNotEqual(b.int_repr(), c.int_repr())
        self.assertFalse(torch.equal(b, c))

        # Throws an error if numel is wrong
        q1_int = torch.randint(0, 100, sizes1, dtype=dtype)
        q1 = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
        err_str = "requested resize to*"
        with self.assertRaisesRegex(RuntimeError, err_str):
            q2 = q1.resize(*sizes1[:-1])
        # resize on both contiguous and non-contiguous tensors should be fine
        q3 = q1.resize(*sizes2)
        q4 = q1.contiguous().resize(*sizes2)

    def test_qtensor_reshape(self):
        scale, zero_point, dtype = 1.0, 2, torch.uint8
        for device in get_supported_device_types():
            q_int = torch.randint(0, 100, [3, 5], dtype=dtype, device=device)
            q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
            q2 = q.reshape([15])
            self.assertEqual(q.numel(), q2.numel())
            self.assertEqual(q2.size(), [15])
            # testing -1
            self.assertEqual(q, q2.reshape([3, -1]))

            a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype, device=device)
            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
            c = a.reshape(1, 3, 2, 4)  # does not change tensor layout
            self.assertEqual(b.size(), c.size())
            self.assertEqual(b.q_scale(), c.q_scale())
            self.assertEqual(b.q_zero_point(), c.q_zero_point())
            self.assertNotEqual(b.stride(), c.stride())
            self.assertNotEqual(b.int_repr(), c.int_repr())
            # torch.equal is not supported for the cuda backend
            if device == 'cpu':
                self.assertFalse(torch.equal(b, c))
            else:
                self.assertRaises(RuntimeError, lambda: torch.equal(b, c))

            # we can use reshape for a non-contiguous Tensor
            a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype, device=device)
            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
            c = b.reshape(1, 4, 2, 3)

    def test_qscheme_pickle(self):
        f = Foo()
        buf = io.BytesIO()
        torch.save(f, buf)

        buf.seek(0)
        f2 = torch.load(buf)

        self.assertEqual(f2.qscheme, torch.per_tensor_symmetric)

    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=4,
                                              min_side=1, max_side=10),
                       qparams=hu.qparams()),
           reduce_range=st.booleans()
           )
    def test_choose_qparams(self, X, reduce_range):
        X, (scale, zero_point, torch_type) = X
        X = torch.from_numpy(X)
        X_scale, X_zp = _calculate_dynamic_qparams(X, torch.quint8, reduce_range=reduce_range)
        qparams = torch._choose_qparams_per_tensor(X, reduce_range)
        np.testing.assert_array_almost_equal(X_scale, qparams[0], decimal=3)
        self.assertEqual(X_zp, qparams[1])

    @unittest.skipIf(not torch.cuda.is_available() or TEST_WITH_ROCM, 'CUDA is not available')
    def test_cuda_cpu_implementation_consistency(self):
        numel, zero_point, scale = 100, 2, 0.02
        r = torch.rand(numel, dtype=torch.float32, device='cpu') * 25 - 4
        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
            qr_cpu = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
            qr_cuda = torch.quantize_per_tensor(r.cuda(), scale, zero_point, dtype=dtype)
            # int repr must be the same
            np.testing.assert_equal(qr_cpu.int_repr().numpy(), qr_cuda.int_repr().cpu().numpy())
            # dequantized values must be the same
            r_cpu, r_cuda = qr_cpu.dequantize().numpy(), qr_cuda.dequantize().cpu().numpy()
            np.testing.assert_almost_equal(r_cuda, r_cpu, decimal=5)
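
For quick reference, a minimal per-tensor round trip with the APIs exercised in this class looks roughly like the sketch below (values chosen purely for illustration):

import torch

x = torch.tensor([0.0, 0.5, 1.0, 1.5])
q = torch.quantize_per_tensor(x, scale=0.5, zero_point=0, dtype=torch.quint8)
print(q.int_repr())    # tensor([0, 1, 2, 3], dtype=torch.uint8)
print(q.dequantize())  # tensor([0.0000, 0.5000, 1.0000, 1.5000])
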
Example #6
class TestObserver(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from(
               (torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_per_tensor_observers(self, qdtype, qscheme, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric:
            reduce_range = False
        ObserverList = [
            MinMaxObserver(dtype=qdtype,
                           qscheme=qscheme,
                           reduce_range=reduce_range),
            MovingAverageMinMaxObserver(averaging_constant=0.5,
                                        dtype=qdtype,
                                        qscheme=qscheme,
                                        reduce_range=reduce_range)
        ]
        for myobs in ObserverList:
            # calculate_qparams() should return (with a warning) for observers that have seen no data
            qparams = myobs.calculate_qparams()
            if type(myobs) == MinMaxObserver:
                x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
            else:
                # The moving average of min/max for x and y matches the
                # extreme values of the x/y used for the MinMaxObserver above
                x = torch.tensor([0.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                y = torch.tensor([2.0, 5.0, 5.0, 6.0, 7.0, 10.0])

            result = myobs(x)
            result = myobs(y)
            self.assertEqual(result, y)
            self.assertEqual(myobs.min_val, 1.0)
            self.assertEqual(myobs.max_val, 8.0)
            qparams = myobs.calculate_qparams()
            if reduce_range:
                if qscheme == torch.per_tensor_symmetric:
                    ref_scale = 0.062745 * 255 / 127
                    ref_zero_point = 0 if qdtype is torch.qint8 else 128
                else:
                    ref_scale = 0.0313725 * 255 / 127
                    ref_zero_point = -64 if qdtype is torch.qint8 else 0
            else:
                if qscheme == torch.per_tensor_symmetric:
                    ref_scale = 0.062745
                    ref_zero_point = 0 if qdtype is torch.qint8 else 128
                else:
                    ref_scale = 0.0313725
                    ref_zero_point = -128 if qdtype is torch.qint8 else 0
            self.assertEqual(qparams[1].item(), ref_zero_point)
            self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in state_dict:
                self.assertEqual(state_dict[key], loaded_dict[key])
            loaded_obs = MinMaxObserver(dtype=qdtype,
                                        qscheme=qscheme,
                                        reduce_range=reduce_range)
            loaded_obs.load_state_dict(loaded_dict)
            loaded_qparams = loaded_obs.calculate_qparams()
            self.assertEqual(myobs.min_val, loaded_obs.min_val)
            self.assertEqual(myobs.max_val, loaded_obs.max_val)
            self.assertEqual(myobs.calculate_qparams(),
                             loaded_obs.calculate_qparams())

    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2,
                                              max_dims=4,
                                              min_side=1,
                                              max_side=10),
                       qparams=hu.qparams()),
           reduce_range=st.booleans())
    def test_per_tensor_dynamic_quant_observers(self, X, reduce_range):

        X, (scale, zero_point, torch_type) = X
        x = torch.from_numpy(X)

        obs = MinMaxDynamicQuantObserver(dtype=torch.quint8,
                                         reduce_range=reduce_range)

        result = obs(x)
        qparams = obs.calculate_qparams()
        ref = torch._choose_qparams_per_tensor(x, reduce_range)

        self.assertEqual(ref[0], qparams[0])
        self.assertEqual(ref[1], qparams[1])

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from(
               (torch.per_channel_affine, torch.per_channel_symmetric)),
           ch_axis=st.sampled_from((0, 1, 2, 3)),
           reduce_range=st.booleans())
    def test_per_channel_observers(self, qdtype, qscheme, ch_axis,
                                   reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
            reduce_range = False
        ObserverList = [
            PerChannelMinMaxObserver(reduce_range=reduce_range,
                                     ch_axis=ch_axis,
                                     dtype=qdtype,
                                     qscheme=qscheme),
            MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5,
                                                  reduce_range=reduce_range,
                                                  ch_axis=ch_axis,
                                                  dtype=qdtype,
                                                  qscheme=qscheme)
        ]

        for myobs in ObserverList:
            # Calculate qparams should work for empty observers
            qparams = myobs.calculate_qparams()
            x = torch.tensor([
                [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
                [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
            ])
            if type(myobs) == MovingAveragePerChannelMinMaxObserver:
                # Scaling the input tensor to model change in min/max values
                # across batches
                result = myobs(0.5 * x)
                result = myobs(1.5 * x)
                self.assertEqual(result, 1.5 * x)
            else:
                result = myobs(x)
                self.assertEqual(result, x)

            qparams = myobs.calculate_qparams()
            ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0],
                            [-4.0, -3.0]]
            ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
            per_channel_symmetric_ref_scales = [
                [0.04705882, 0.06274509],
                [0.03921569, 0.0627451],
                [0.04705882, 0.0627451],
                [0.05490196, 0.0627451],
            ]
            per_channel_affine_ref_scales = [
                [0.02352941, 0.04705882],
                [0.03529412, 0.03137255],
                [0.03921569, 0.03137255],
                [0.04313726, 0.04313726],
            ]
            per_channel_affine_qint8_zp = [
                [-128, -43],
                [-15, -128],
                [-26, -128],
                [-35, -58],
            ]
            per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0],
                                            [93, 70]]

            self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
            self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])
            if qscheme == torch.per_channel_symmetric:
                ref_scales = per_channel_symmetric_ref_scales[ch_axis]
                ref_zero_points = [0, 0
                                   ] if qdtype is torch.qint8 else [128, 128]
            else:
                ref_scales = per_channel_affine_ref_scales[ch_axis]
                ref_zero_points = (per_channel_affine_qint8_zp[ch_axis]
                                   if qdtype is torch.qint8 else
                                   per_channel_affine_quint8_zp[ch_axis])

            if reduce_range:
                ref_scales = [s * 255 / 127 for s in ref_scales]
                ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]

            self.assertTrue(
                torch.allclose(
                    qparams[0], torch.tensor(ref_scales,
                                             dtype=qparams[0].dtype)))
            self.assertTrue(
                torch.allclose(
                    qparams[1],
                    torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))

            # Test for serializability
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in state_dict:
                self.assertEqual(state_dict[key], loaded_dict[key])
            loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range,
                                                  ch_axis=ch_axis,
                                                  dtype=qdtype,
                                                  qscheme=qscheme)
            loaded_obs.load_state_dict(loaded_dict)
            loaded_qparams = loaded_obs.calculate_qparams()
            self.assertEqual(myobs.min_vals, loaded_obs.min_vals)
            self.assertEqual(myobs.max_vals, loaded_obs.max_vals)
            self.assertEqual(myobs.calculate_qparams(),
                             loaded_obs.calculate_qparams())

    def test_observer_scriptable(self):
        obs_list = [
            MinMaxObserver(),
            MovingAverageMinMaxObserver(),
            MinMaxDynamicQuantObserver()
        ]
        for obs in obs_list:
            scripted = torch.jit.script(obs)

            x = torch.rand(3, 4)
            obs(x)
            scripted(x)
            self.assertEqual(obs.calculate_qparams(),
                             scripted.calculate_qparams())

            buf = io.BytesIO()
            torch.jit.save(scripted, buf)
            buf.seek(0)
            loaded = torch.jit.load(buf)
            self.assertEqual(obs.calculate_qparams(),
                             loaded.calculate_qparams())

    # TODO: move this to test_quantize.py
    def test_no_qconfig_propagation(self):
        model = ModelWithNoQconfigPropagation()
        model.qconfig = torch.quantization.default_qconfig

        model = prepare(model)
        self.assertTrue(hasattr(model.fc1, 'qconfig'),
                        "QConfig is expected to propagate")
        self.assertFalse(hasattr(model.no_quant_module, 'qconfig'),
                         "QConfig is expected to NOT propagate")
class TestFakeQuantizePerTensor(TestCase):
    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_tensor(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale,
                                                       zero_point, quant_min,
                                                       quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y,
                                   Y_prime.cpu(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_backward_per_tensor(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale,
                                                       zero_point, quant_min,
                                                       quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        dout = torch.rand(X.shape, dtype=torch.float).to(device)
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, X, scale, zero_point, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    # https://github.com/pytorch/pytorch/issues/30604
    @unittest.skip("temporarily disable the test")
    def test_numerical_consistency_per_tensor(self, device, X):
        r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        # quantize_per_tensor and dequantize are only implemented on CPU
        Y = torch.dequantize(
            torch.quantize_per_tensor(X.cpu(), scale, zero_point, torch_type))
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y,
                                   Y_prime.cpu(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(
        device=st.sampled_from(
            ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
        X=hu.tensor(shapes=hu.array_shapes(
            1,
            5,
        ),
                    qparams=hu.qparams(dtypes=[torch.quint8])),
    )
    def test_fq_module(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = torch.quantization.default_fake_quant().to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_tensor_affine_reference(X, fq_module.scale,
                                                       fq_module.zero_point,
                                                       quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(),
                                   Y_prime.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, X, fq_module.scale, fq_module.zero_point, quant_min,
            quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    def test_fq_serializable(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], 0.094488)
        self.assertEqual(state_dict['zero_point'], 53)
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
        loaded_fq_module.load_state_dict(loaded_dict)
        for key in state_dict:
            self.assertEqual(state_dict[key],
                             loaded_fq_module.state_dict()[key])

        self.assertEqual(loaded_fq_module.calculate_qparams(),
                         fq_module.calculate_qparams())

    def test_fake_quant_control(self):
        torch.manual_seed(42)
        X = torch.rand(20, 10, dtype=torch.float32)
        fq_module = torch.quantization.default_fake_quant()
        # Output of fake quant is not identical to input
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        torch.quantization.disable_fake_quant(fq_module)
        X = torch.rand(20, 10, dtype=torch.float32)
        Y = fq_module(X)
        # Fake quant is disabled, so the output is identical to the input
        self.assertEqual(Y, X)

        # Explicit copy at this point in time, because FakeQuant keeps internal
        # state in mutable buffers.
        scale = fq_module.scale.clone().detach()
        zero_point = fq_module.zero_point.clone().detach()

        torch.quantization.disable_observer(fq_module)
        torch.quantization.enable_fake_quant(fq_module)
        X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        # Observer is disabled, scale and zero-point do not change
        self.assertEqual(fq_module.scale, scale)
        self.assertEqual(fq_module.zero_point, zero_point)
        torch.quantization.enable_observer(fq_module)
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        # Observer is enabled, scale and zero-point are different
        self.assertNotEqual(fq_module.scale, scale)
        self.assertNotEqual(fq_module.zero_point, zero_point)

    def test_fake_quant_preserves_qparam_shapes_for_activations(self):
        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.linear = nn.Linear(4, 4)

            def forward(self, x):
                x = self.linear(x)
                return x

        m = Model()

        m.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        torch.quantization.prepare_qat(m, inplace=True)

        scale_shape_before = m.linear.activation_post_process.scale.shape
        zero_point_shape_before = m.linear.activation_post_process.zero_point.shape

        x = torch.rand(4, 4, 4, 4)
        m(x)
        scale_shape_after = m.linear.activation_post_process.scale.shape
        zero_point_shape_after = m.linear.activation_post_process.zero_point.shape
        self.assertEqual(scale_shape_before,
                         scale_shape_after,
                         msg="FakeQuant scale shape must stay consistent")
        self.assertEqual(zero_point_shape_before,
                         zero_point_shape_after,
                         msg="FakeQuant zero_point shape must stay consistent")

    def test_fake_quant_scriptable(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        scripted_module = torch.jit.script(fq_module)

        X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)

        fq_module(X)
        scripted_module(X)
        self.assertEqual(fq_module.calculate_qparams(),
                         scripted_module.calculate_qparams())

        buf = io.BytesIO()
        torch.jit.save(scripted_module, buf)
        buf.seek(0)
        loaded_module = torch.jit.load(buf)
        self.assertEqual(fq_module.calculate_qparams(),
                         loaded_module.calculate_qparams())
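
For reference, the enable/disable toggles checked in test_fake_quant_control above can be driven directly on any FakeQuantize module. A minimal sketch using the same torch.quantization helpers that appear in the tests:

import torch

fq = torch.quantization.default_fake_quant()
x = torch.rand(4, 4)
fq(x)                                      # observe and fake-quantize
torch.quantization.disable_observer(fq)    # freeze scale / zero_point
torch.quantization.enable_fake_quant(fq)   # keep quantizing with the frozen qparams
y = fq(x)
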
Example #8
class TestFakeQuantizeOps(TestCase):
    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_tensor(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale,
                                                       zero_point, quant_min,
                                                       quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y,
                                   Y_prime.cpu(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_backward_per_tensor(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale,
                                                       zero_point, quant_min,
                                                       quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        dout = torch.rand_like(X, dtype=torch.float).to(device)
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, X, scale, zero_point, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    def test_forward_backward_per_tensor_with_amp(self):
        net = nn.Sequential(nn.Conv2d(1, 1, 3))
        net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        net_prep = torch.quantization.prepare_qat(net)

        with torch.cuda.amp.autocast():
            x = torch.randn(4, 1, 5, 5)
            out = net_prep(x).sum()
            out.backward()
            self.assertTrue(net_prep[0].weight.grad is not None)

    def test_forward_per_tensor_half_precision_numerics(self):
        scale = .1
        zero = 0
        maxi = 255
        mini = 0

        for i in range(20):
            X1 = torch.randn(5, 5).to(torch.float16)
            Y1 = torch.fake_quantize_per_tensor_affine(X1, scale, zero, mini,
                                                       maxi)
            Y1r = _fake_quantize_per_tensor_affine_reference(
                X1, scale, zero, mini, maxi)
            self.assertTrue(
                torch.allclose(Y1, Y1r, rtol=tolerance, atol=tolerance))

        # to force overflow
        X2 = torch.tensor(2**15 + .01).to(torch.float16)
        Y2 = torch.fake_quantize_per_tensor_affine(X2, scale, zero, mini, maxi)
        Y2r = _fake_quantize_per_tensor_affine_reference(
            X2, scale, zero, mini, maxi)
        self.assertTrue(torch.allclose(Y2, Y2r, rtol=tolerance,
                                       atol=tolerance))

        scale = 10

        # to force underflow
        X3 = torch.tensor(2**-24).to(torch.float16)
        Y3 = torch.fake_quantize_per_tensor_affine(X3, scale, zero, mini, maxi)
        Y3r = _fake_quantize_per_tensor_affine_reference(
            X3, scale, zero, mini, maxi)
        self.assertTrue(torch.allclose(Y3, Y3r, rtol=tolerance,
                                       atol=tolerance))

    def _test_forward_per_tensor_cachemask_impl(self, device):
        float_types = (torch.float32, torch.float16, torch.float64)
        torch_types = (torch.qint8, torch.quint8)
        Xs = (torch.randn(4, 8,
                          device=device), torch.randn(4, 16,
                                                      device=device)[:, ::2])
        for float_type, torch_type, X in itertools.product(
                float_types, torch_types, Xs):
            # pick the scale + zp so that some values get clipped
            X = X.to(float_type)
            obs = torch.quantization.MinMaxObserver(torch_type)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            scale, zero_point = float(scale), int(zero_point)
            quant_min, quant_max = obs._calculate_qmin_qmax()

            Y_test = torch.fake_quantize_per_tensor_affine(
                X, scale, zero_point, quant_min, quant_max)
            Y_ref = _fake_quantize_per_tensor_affine_reference(
                X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
            self.assertTrue(
                torch.allclose(Y_test, Y_ref, rtol=tolerance, atol=tolerance))
            self.assertTrue(Y_test.dtype == float_type)

    def test_forward_per_tensor_cachemask_cpu(self):
        device = torch.device('cpu')
        self._test_forward_per_tensor_cachemask_impl(device)

    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_forward_per_tensor_cachemask_cuda(self):
        device = torch.device('cuda')
        self._test_forward_per_tensor_cachemask_impl(device)

    def _test_backward_per_tensor_cachemask_impl(self, device):
        float_types = (torch.float32, torch.float16, torch.float64)
        torch_types = (torch.qint8, torch.quint8)
        for float_type, torch_type in itertools.product(
                float_types, torch_types):
            X = torch.randn(4, 8).to(device).to(float_type)
            X.requires_grad_()
            # pick the scale + zp so that some values get clipped
            obs = torch.quantization.MinMaxObserver(torch_type)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            scale, zero_point = float(scale), int(zero_point)
            quant_min, quant_max = obs._calculate_qmin_qmax()

            # forward pass
            Y_test = torch.fake_quantize_per_tensor_affine(
                X, scale, zero_point, quant_min, quant_max)
            Y_ref = _fake_quantize_per_tensor_affine_reference(
                X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
            self.assertTrue(
                torch.allclose(Y_test, Y_ref, rtol=tolerance, atol=tolerance))

            # backward pass
            dout = torch.rand_like(X, dtype=torch.float).to(device)
            dX = _fake_quantize_per_tensor_affine_grad_reference(
                dout, X, scale, zero_point, quant_min, quant_max)
            Y_test.backward(dout)
            self.assertTrue(torch.allclose(dX, X.grad))
            self.assertTrue(X.grad.dtype == float_type)

    def test_backward_per_tensor_cachemask_cpu(self):
        device = torch.device('cpu')
        self._test_backward_per_tensor_cachemask_impl(device)

    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_backward_per_tensor_cachemask_cuda(self):
        device = torch.device('cuda')
        self._test_backward_per_tensor_cachemask_impl(device)

    def _test_learnable_forward_per_tensor(self, X, device, scale_base,
                                           zero_point_base):
        X_base = torch.tensor(X).to(device)

        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2**n_bits - 1

            X = X_base.clone().float()
            scale_base = scale_base.to(device).float()
            zero_point_base = zero_point_base.to(dtype=torch.int64,
                                                 device=device)
            scale = scale_base.clone()
            zero_point = zero_point_base.clamp(quant_min, quant_max)

            Y = _fake_quantize_per_tensor_affine_reference(
                X, scale, zero_point, quant_min, quant_max).to(device)
            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_tensor_affine(
                    X, scale, zero_point, quant_min, quant_max,
                    grad_factor).to(device)
                self.assertTrue(
                    torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
                    "Expected kernel forward function to have results match the reference forward function"
                )

    @given(X=hu.tensor(shapes=hu.array_shapes(
        1,
        5,
    ),
                       elements=hu.floats(-1e3,
                                          1e3,
                                          allow_nan=False,
                                          allow_infinity=False),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_forward_per_tensor_cpu(self, X):
        X, (_, _, _) = X
        scale_base = torch.normal(mean=0, std=1, size=(1, )).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(1, ))
        self._test_learnable_forward_per_tensor(X, 'cpu', scale_base,
                                                zero_point_base)

    @given(X=hu.tensor(shapes=hu.array_shapes(
        1,
        5,
    ),
                       elements=hu.floats(-1e3,
                                          1e3,
                                          allow_nan=False,
                                          allow_infinity=False),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_learnable_forward_per_tensor_cuda(self, X):
        X, (_, _, _) = X
        scale_base = torch.normal(mean=0, std=1, size=(1, )).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(1, ))
        self._test_learnable_forward_per_tensor(X, 'cuda', scale_base,
                                                zero_point_base)

    def _test_learnable_backward_per_tensor(self, X, device, scale_base,
                                            zero_point_base):
        r"""Tests the backward method with additional backprop support for scale and zero point.
        """
        X_base = torch.tensor(X).to(device)

        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2**n_bits - 1

            X = X_base.clone().float().to(device)
            X.requires_grad_()
            scale_base = scale_base.to(device)
            zero_point_base = zero_point_base.to(device)
            scale = scale_base.clone()
            scale.requires_grad_()
            zero_point = zero_point_base.clone().clamp(quant_min, quant_max)
            zero_point.requires_grad_()
            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_tensor_affine(
                    X, scale, zero_point, quant_min, quant_max,
                    grad_factor).to(device)
                dout = torch.rand_like(X, dtype=torch.float).to(device)
                dX, dScale, dZeroPoint = _fake_quantize_learnable_per_tensor_affine_grad_reference(
                    dout, X, scale, zero_point, quant_min, quant_max, device)
                Y_prime.backward(dout)

                expected_dX = dX.to(device).detach()
                actual_dX = X.grad.to(device).detach()
                expected_dScale = dScale.to(device).detach()
                actual_dScale = scale.grad.to(device).detach()
                expected_dZeroPoint = dZeroPoint.to(device).detach()
                actual_dZeroPoint = zero_point.grad.to(device).detach()

                self.assertTrue(
                    torch.allclose(expected_dX,
                                   actual_dX,
                                   rtol=tolerance,
                                   atol=tolerance),
                    "Expected dX to match X.grad")
                self.assertTrue(
                    torch.allclose(expected_dScale * grad_factor,
                                   actual_dScale,
                                   rtol=tolerance,
                                   atol=tolerance),
                    "Expected dScale to match scale.grad")
                self.assertTrue(
                    torch.allclose(expected_dZeroPoint * grad_factor,
                                   actual_dZeroPoint,
                                   rtol=tolerance,
                                   atol=tolerance),
                    "Expected dZeroPoint to match zero_point.grad")
                X.grad.data.zero_()
                scale.grad.data.zero_()
                zero_point.grad.data.zero_()

    @given(X=hu.tensor(shapes=hu.array_shapes(
        1,
        5,
    ),
                       elements=hu.floats(-1e3,
                                          1e3,
                                          allow_nan=False,
                                          allow_infinity=False),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_backward_per_tensor_cpu(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, _) = X
        scale_base = torch.normal(mean=0, std=1, size=(1, )).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(1, ))
        self._test_learnable_backward_per_tensor(X, 'cpu', scale_base,
                                                 zero_point_base)

    @given(X=hu.tensor(shapes=hu.array_shapes(
        1,
        5,
    ),
                       elements=hu.floats(-1e3,
                                          1e3,
                                          allow_nan=False,
                                          allow_infinity=False),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
    def test_learnable_backward_per_tensor_cuda(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, _) = X
        scale_base = torch.normal(mean=0, std=1, size=(1, )).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(1, ))
        self._test_learnable_backward_per_tensor(X, 'cuda', scale_base,
                                                 zero_point_base)

    @given(
        device=st.sampled_from(
            ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
        X=hu.tensor(shapes=hu.array_shapes(
            1,
            5,
        ),
                    qparams=hu.qparams(dtypes=[torch.quint8])),
    )
    def test_fq_module_per_tensor(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = torch.quantization.default_fake_quant().to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_tensor_affine_reference(X, fq_module.scale,
                                                       fq_module.zero_point,
                                                       quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(),
                                   Y_prime.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

        # Test backward
        dout = torch.rand_like(X, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, X, fq_module.scale, fq_module.zero_point, quant_min,
            quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_fixed_qparams_fq_module(self, device, X):
        X, (scale, zero_point, torch_type) = X
        X = to_tensor(X, device)
        fq_module = default_affine_fixed_qparams_fake_quant()
        fixed_scale = fq_module.scale.clone()
        fixed_zero_point = fq_module.zero_point.clone()
        # run the fq module and make sure the quantization parameters do not change
        torch.quantization.enable_observer(fq_module)
        fq_module(X)
        self.assertEqual(fixed_scale, fq_module.scale)
        self.assertEqual(fixed_zero_point, fq_module.zero_point)

    def test_fq_serializable_per_tensor(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        for FakeQuantizeClass in [FakeQuantize, _LearnableFakeQuantize]:
            fq_module = FakeQuantizeClass(observer, quant_min, quant_max)
            X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
            y_ref = fq_module(X)
            state_dict = fq_module.state_dict()
            self.assertEqual(state_dict['scale'], 0.094488)
            self.assertEqual(state_dict['zero_point'], 53)
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            loaded_fq_module = FakeQuantizeClass(observer, quant_min,
                                                 quant_max)
            loaded_fq_module.load_state_dict(loaded_dict)
            for key in state_dict:
                self.assertEqual(state_dict[key],
                                 loaded_fq_module.state_dict()[key])

            self.assertEqual(loaded_fq_module.calculate_qparams(),
                             fq_module.calculate_qparams())

    def test_fake_quant_control(self):
        for fq_module in [
                torch.quantization.default_fake_quant(),
                _LearnableFakeQuantize.with_args(
                    observer=MovingAverageMinMaxObserver,
                    quant_min=0,
                    quant_max=255,
                    dtype=torch.quint8,
                    qscheme=torch.per_tensor_affine,
                    reduce_range=True)()
        ]:
            torch.manual_seed(42)
            X = torch.rand(20, 10, dtype=torch.float32)
            # Output of fake quant is not identical to input
            Y = fq_module(X)
            self.assertNotEqual(Y, X)
            if type(fq_module) == _LearnableFakeQuantize:
                fq_module.toggle_fake_quant(False)
            else:
                torch.quantization.disable_fake_quant(fq_module)
            X = torch.rand(20, 10, dtype=torch.float32)
            Y = fq_module(X)
            # Fake quant is disabled, output is identical to input
            self.assertEqual(Y, X)

            # Explicit copy at this point in time, because FakeQuant keeps internal
            # state in mutable buffers.
            scale = fq_module.scale.clone().detach()
            zero_point = fq_module.zero_point.clone().detach()
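            # (Hedged note) If the module updates its buffers in place during
            # forward, an un-cloned reference such as
            #   scale = fq_module.scale  # would alias the buffer
            # would silently track those updates, and the equality checks below
            # could pass even when the observer changed the qparams.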

            if type(fq_module) == _LearnableFakeQuantize:
                fq_module.toggle_observer_update(False)
                fq_module.toggle_fake_quant(True)
            else:
                torch.quantization.disable_observer(fq_module)
                torch.quantization.enable_fake_quant(fq_module)
            X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0
            Y = fq_module(X)
            self.assertNotEqual(Y, X)
            # Observer is disabled, scale and zero-point do not change
            self.assertEqual(fq_module.scale, scale)
            self.assertEqual(fq_module.zero_point, zero_point)
            if type(fq_module) == _LearnableFakeQuantize:
                fq_module.toggle_observer_update(True)
            else:
                torch.quantization.enable_observer(fq_module)
            Y = fq_module(X)
            self.assertNotEqual(Y, X)
            # Observer is enabled, scale and zero-point are different
            self.assertNotEqual(fq_module.scale, scale)
            self.assertNotEqual(fq_module.zero_point, zero_point)

    def test_fake_quant_preserves_qparam_shapes_for_activations(self):
        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.linear = nn.Linear(4, 4)

            def forward(self, x):
                x = self.linear(x)
                return x

        m = Model()

        m.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        torch.quantization.prepare_qat(m, inplace=True)

        scale_shape_before = m.linear.activation_post_process.scale.shape
        zero_point_shape_before = m.linear.activation_post_process.zero_point.shape

        x = torch.rand(4, 4, 4, 4)
        m(x)
        scale_shape_after = m.linear.activation_post_process.scale.shape
        zero_point_shape_after = m.linear.activation_post_process.zero_point.shape
        self.assertEqual(scale_shape_before,
                         scale_shape_after,
                         msg="FakeQuant scale shape must stay consistent")
        self.assertEqual(zero_point_shape_before,
                         zero_point_shape_after,
                         msg="FakeQuant zero_point shape must stay consistent")

    def test_fake_quant_scriptable(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        for FakeQuantizeClass in [FakeQuantize, _LearnableFakeQuantize]:
            fq_module = FakeQuantizeClass(observer, quant_min, quant_max)
            scripted_module = torch.jit.script(fq_module)

            X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)

            fq_module(X)
            scripted_module(X)
            self.assertEqual(fq_module.calculate_qparams(),
                             scripted_module.calculate_qparams())

            buf = io.BytesIO()
            torch.jit.save(scripted_module, buf)
            buf.seek(0)
            loaded_module = torch.jit.load(buf)
            self.assertEqual(fq_module.calculate_qparams(),
                             loaded_module.calculate_qparams())

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_channel(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64,
                                                 device=device)
        Y = _fake_quantize_per_channel_affine_reference(
            X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y,
                                   Y_prime.cpu(),
                                   rtol=tolerance,
                                   atol=tolerance)

    def _test_forward_per_channel_cachemask_impl(self, device):
        torch_types = (torch.qint8, torch.quint8)
        float_types = (torch.float32, torch.float16, torch.float64)
        for torch_type, float_type in itertools.product(
                torch_types, float_types):
            X = torch.randn(1, 2, 4, 4, dtype=float_type).to(device)
            # pick the scale + zp so that some values get clipped
            axis = 1
            obs = torch.quantization.PerChannelMinMaxObserver(
                axis, torch_type).to(device)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
            zero_point = zero_point.to(torch.int64)
            quant_min, quant_max = obs._calculate_qmin_qmax()

            Y = _fake_quantize_per_channel_affine_reference(
                X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min,
                quant_max)
            Y_prime = torch.fake_quantize_per_channel_affine(
                X, scale, zero_point, axis, quant_min, quant_max)
            np.testing.assert_allclose(Y,
                                       Y_prime.cpu(),
                                       rtol=tolerance,
                                       atol=tolerance)
            self.assertTrue(Y.dtype == float_type)

    def test_forward_per_channel_cachemask_cpu(self):
        self._test_forward_per_channel_cachemask_impl('cpu')

    @unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
    def test_forward_per_channel_cachemask_cuda(self):
        self._test_forward_per_channel_cachemask_impl('cuda')

    def test_forward_per_channel_half_precision_numerics(self):
        scale = torch.randn(5).abs()
        zero = torch.randn(5).to(dtype=torch.long)
        axis = 1
        mini = 0
        maxi = 255

        for i in range(20):
            X1 = torch.randn(4, 5).to(torch.float16)
            Y1 = torch.fake_quantize_per_channel_affine(
                X1, scale, zero, axis, mini, maxi)
            Y1r = _fake_quantize_per_channel_affine_reference(
                X1, scale, zero, axis, mini, maxi)
            self.assertTrue(
                torch.allclose(Y1, Y1r, rtol=tolerance, atol=tolerance))

        # to force overflow
        X2 = torch.randn(4, 5).to(torch.float16)
        X2[0, 0] = 2**15 + .01
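        # (Assumption) 2**15 is large enough that x / scale can exceed the
        # float16 maximum (~65504) for small scales, exercising the overflow
        # handling of the half-precision path.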
        Y2 = torch.fake_quantize_per_channel_affine(X2, scale, zero, axis,
                                                    mini, maxi)
        Y2r = _fake_quantize_per_channel_affine_reference(
            X2, scale, zero, axis, mini, maxi)
        self.assertTrue(torch.allclose(Y2, Y2r, rtol=tolerance,
                                       atol=tolerance))

        scale = torch.zeros(5) + 10

        # to force underflow
        X3 = torch.randn(4, 5).to(torch.float16)
        X3[0, 0] = 2**-24
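        # (Note) 2**-24 is the smallest positive subnormal in float16, so
        # x / scale (scale = 10 here) would underflow to zero if computed
        # naively in half precision.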
        Y3 = torch.fake_quantize_per_channel_affine(X3, scale, zero, axis,
                                                    mini, maxi)
        Y3r = _fake_quantize_per_channel_affine_reference(
            X3, scale, zero, axis, mini, maxi)
        self.assertTrue(torch.allclose(Y3, Y3r, rtol=tolerance,
                                       atol=tolerance))

    def _test_learnable_forward_per_channel(self, X_base, device, scale_base,
                                            zero_point_base, axis):
        r"""Tests the forward path of the learnable FakeQuantizePerTensorAffine op.
        """
        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2**(n_bits) - 1

            scale_base = scale_base.to(device)
            zero_point_base = zero_point_base.to(device)

            X_curr = X_base.clone()
            scale_curr = scale_base.clone()
            zero_point_curr = zero_point_base.clone()

            Y = _fake_quantize_per_channel_affine_reference(
                X_curr, scale_curr,
                zero_point_curr.round().clamp(quant_min, quant_max), axis,
                quant_min, quant_max).to(device)
            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                    X_curr, scale_curr, zero_point_curr, axis, quant_min,
                    quant_max, grad_factor).to(device)
                self.assertTrue(
                    torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
                    "Expected kernel forward function to have results match the reference forward function"
                )

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(
        1,
        5,
    ),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_forward_per_channel_cpu(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, axis, _) = X
        X_base = torch.tensor(X).to('cpu')
        channel_size = X_base.size(axis)
        scale_base = torch.normal(mean=0, std=1,
                                  size=(channel_size, )).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size, ))
        self._test_learnable_forward_per_channel(X_base, 'cpu', scale_base,
                                                 zero_point_base, axis)

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(
        1,
        5,
    ),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
    def test_learnable_forward_per_channel_cuda(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, axis, _) = X
        X_base = torch.tensor(X).to('cuda')
        channel_size = X_base.size(axis)
        scale_base = torch.normal(mean=0, std=1,
                                  size=(channel_size, )).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size, ))
        self._test_learnable_forward_per_channel(X_base, 'cuda', scale_base,
                                                 zero_point_base, axis)

    @given(device=st.sampled_from(
        ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(
               1,
               5,
           ),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_backward_per_channel(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64,
                                                 device=device)
        X.requires_grad_()
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        dout = torch.rand_like(X, dtype=torch.float).to(device)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, scale, zero_point, axis, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu().detach().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

    def _test_backward_per_channel_cachemask_impl(self, device):
        torch_types = (torch.qint8, torch.quint8)
        float_types = (torch.float32, torch.float16, torch.float64)
        for torch_type, float_type in itertools.product(
                torch_types, float_types):
            X = torch.randn(1, 2, 4, 4, dtype=float_type).to(device)
            # pick the scale + zp so that some values get clipped
            axis = 1
            obs = torch.quantization.PerChannelMinMaxObserver(
                axis, torch_type).to(device)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
            zero_point = zero_point.to(torch.int64)
            quant_min, quant_max = obs._calculate_qmin_qmax()
            X.requires_grad_()
            Y_prime = torch.fake_quantize_per_channel_affine(
                X, scale, zero_point, axis, quant_min, quant_max)
            dout = torch.rand_like(X, dtype=float_type).to(device)
            dX = _fake_quantize_per_channel_affine_grad_reference(
                dout, X, scale, zero_point, axis, quant_min, quant_max)
            Y_prime.backward(dout)
            np.testing.assert_allclose(dX.cpu().detach().numpy(),
                                       X.grad.cpu().detach().numpy(),
                                       rtol=tolerance,
                                       atol=tolerance)
            self.assertEqual(X.grad.dtype, float_type)

    def test_backward_per_channel_cachemask_cpu(self):
        self._test_backward_per_channel_cachemask_impl('cpu')

    @unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
    def test_backward_per_channel_cachemask_cuda(self):
        self._test_backward_per_channel_cachemask_impl('cuda')

    def _test_learnable_backward_per_channel(self, X_base, device, scale_base,
                                             zero_point_base, axis):
        r"""Tests the backward path of the learnable FakeQuantizePerTensorAffine op.
        """
        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2**n_bits - 1

            scale_base = scale_base.to(device)
            zero_point_base = zero_point_base.to(device=device)

            X_curr = X_base.clone()
            X_curr.requires_grad_()
            scale_curr = scale_base.clone()
            scale_curr.requires_grad_()
            zero_point_curr = zero_point_base.clone()
            zero_point_curr.requires_grad_()

            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                    X_curr, scale_curr, zero_point_curr, axis, quant_min,
                    quant_max, grad_factor).to(device)

                dout = torch.rand(X_curr.shape, dtype=torch.float).to(device)
                dX, dScale, dZeroPoint = _fake_quantize_learnable_per_channel_affine_grad_reference(
                    dout, X_curr, scale_curr, zero_point_curr, axis, quant_min,
                    quant_max, device)
                Y_prime.backward(dout)

                dX_expected = dX.to(device).detach()
                dX_actual = X_curr.to(device).grad.detach()
                dScale_expected = dScale.to(device).detach()
                dScale_actual = scale_curr.to(device).grad.detach()
                dZeroPoint_expected = dZeroPoint.to(device).detach()
                dZeroPoint_actual = zero_point_curr.to(device).grad.detach()
                tolerance = 1e-4  # local tolerance for these checks (shadows the module-level value)

                self.assertTrue(
                    torch.allclose(dX_expected,
                                   dX_actual,
                                   rtol=tolerance,
                                   atol=tolerance),
                    "Expected dX={} to match X.grad={}, X={}, s={}, z={}, dout={}, n_bits={}"
                    .format(dX_expected, dX_actual, X_curr, scale_curr,
                            zero_point_curr, dout, n_bits))
                self.assertTrue(
                    torch.allclose(dScale_expected * grad_factor,
                                   dScale_actual,
                                   rtol=tolerance,
                                   atol=tolerance),
                    "Expected dScale={} to match scale.grad={}, X={}, s={}, z={}, dout={}, n_bits={}"
                    .format(dScale_expected * grad_factor, dScale_actual,
                            X_curr, scale_curr, zero_point_curr, dout, n_bits))
                self.assertTrue(
                    torch.allclose(dZeroPoint_expected * grad_factor,
                                   dZeroPoint_actual,
                                   rtol=tolerance,
                                   atol=tolerance),
                    "Expected dZeroPoint={} to match zero_point.grad={}, X={}, s={}, z={}, dout={}, n_bits={}"
                    .format(dZeroPoint_expected * grad_factor,
                            dZeroPoint_actual, X_curr, scale_curr,
                            zero_point_curr, dout, n_bits))
                X_curr.grad.data.zero_()
                scale_curr.grad.data.zero_()
                zero_point_curr.grad.data.zero_()

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(
        2,
        5,
    ),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_backward_per_channel_cpu(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, axis, _) = X
        X_base = torch.tensor(X).to('cpu')
        channel_size = X_base.size(axis)
        scale_base = torch.normal(mean=0, std=1,
                                  size=(channel_size, )).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size, ))
        self._test_learnable_backward_per_channel(X_base, 'cpu', scale_base,
                                                  zero_point_base, axis)

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(
        2,
        5,
    ),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
    def test_learnable_backward_per_channel_cuda(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        X_base = torch.tensor(X).to('cuda')
        scale_base = to_tensor(scale, 'cuda')
        zero_point_base = to_tensor(zero_point, 'cuda')
        self._test_learnable_backward_per_channel(X_base, 'cuda', scale_base,
                                                  zero_point_base, axis)

    def test_numerical_consistency_per_tensor(self):
        self._test_numerical_consistency('per_tensor')

    def test_numerical_consistency_per_channel(self):
        self._test_numerical_consistency('per_channel')

    def _test_numerical_consistency(self, test_type):
        r"""Comparing numerical consistency between quantize/dequantize op and the fake quantize op across devices and dtypes
        """
        torch.random.manual_seed(NP_RANDOM_SEED)
        torch_types = [torch.qint8, torch.quint8]
        float_types = [torch.float, torch.float16, torch.float64]
        zero_types = [torch.long]
        devices = [torch.device('cpu'),
                   torch.device('cuda')
                   ] if torch.cuda.is_available() else [torch.device('cpu')]
        axis = 1
        for i in range(20):
            for torch_type, float_type, device, zero_type in itertools.product(
                    torch_types, float_types, devices, zero_types):
                X = torch.randn(3, 3, device=device).to(float_type)
                scales = (10 * torch.randn(3, device=device)).abs()
                scale = scales.mean().to(float).item()
                zeros = (10 * torch.randn(3, device=device)).abs().to(
                    dtype=zero_type)
                zero = zeros.max().view(1).item()
                quant_min = torch.iinfo(torch_type).min
                quant_max = torch.iinfo(torch_type).max

                test_was_run = False
                if test_type == "per_tensor":
                    test_was_run = True
                    Y = torch.dequantize(
                        torch.quantize_per_tensor(
                            X.to('cpu').to(torch.float), scale, zero,
                            torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_tensor_affine(
                        X, scale, zero, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime,
                        "Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor"
                    )

                if test_type == "per_channel":
                    test_was_run = True
                    Y = torch.dequantize(
                        torch.quantize_per_channel(
                            X.to('cpu').to(torch.float), scales.to('cpu'),
                            zeros.to('cpu'), axis,
                            torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_channel_affine(
                        X, scales, zeros, axis, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime,
                        "Difference found between dequant+quant_per_channel and fake_quantize_per_channel"
                    )
                self.assertTrue(test_was_run)