Example #1
    def test_quantize_per_channel_sub_byte(self):
        """ Tests the per channel quantization scheme for 4-bit qtensors.
        The scale and zero point for this have to be in floating point. """
        r = torch.rand(3, 2, dtype=torch.float) * 4
        scales = torch.tensor([0.2, 0.3, 0.1], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
        qr = torch.quantize_per_channel(r, scales, zero_points, 0,
                                        torch.quint4x2)
        dequant_tensor = qr.dequantize()

        def _get_qranges(bit_width):
            if bit_width == 4:
                return 0, 15

        def _quantize_per_channel_sub_byte_ref(data, scales, zero_points, axis,
                                               bit_width):
            dims = data.size()
            data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
            qtensor_size = math.ceil(data.numel() / 2)
            res = torch.empty(qtensor_size, dtype=torch.uint8)
            elem_per_byte = 8 / bit_width
            quant_min, quant_max = _get_qranges(bit_width)
            for i in range(data.size()[0]):
                for j in range(data.size()[1]):
                    for k in range(data.size()[2]):
                        inv_scale = 1.0 / scales[j]
                        index = i * data.size()[1] * data.size()[2] + j * data.size()[2] + k
                        qvalue = np.clip(
                            np.round(data[i][j][k] * inv_scale + zero_points[j]),
                            quant_min, quant_max).to(dtype=torch.int)
                        res_idx = int(index / elem_per_byte)
                        if (index % elem_per_byte == 0):
                            res[res_idx] = qvalue
                        else:
                            res[res_idx] |= (qvalue << (
                                (index % elem_per_byte) * bit_width))
            return res

        ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 0,
                                                     4)
        self.assertTrue(np.allclose(qr.int_repr(), ref_res))
        self.assertTrue(
            np.allclose(r.numpy(),
                        dequant_tensor.numpy(),
                        atol=1 / np.min(scales.numpy())))

        # Check 4D tensor with non-zero axis.
        r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
        qr = torch.quantize_per_channel(r,
                                        scales,
                                        zero_points,
                                        axis=1,
                                        dtype=torch.quint4x2)
        ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 1,
                                                     4)
        self.assertTrue(np.allclose(qr.int_repr(), ref_res))
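
A minimal round-trip sketch of the same scheme outside the test harness; the tensor values are illustrative, and the snippet assumes a CPU build of PyTorch with torch.quint4x2 support.

import torch

x = torch.rand(3, 2) * 4
scales = torch.tensor([0.2, 0.3, 0.1])       # one scale per slice along axis 0
zero_points = torch.tensor([0.1, 0.2, 0.3])  # float zero points, required for sub-byte dtypes
qx = torch.quantize_per_channel(x, scales, zero_points, 0, torch.quint4x2)
# Two 4-bit values are packed into each byte of the underlying storage.
print(qx.dequantize())  # approximate reconstruction of x
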
Example #2
    def test_qtensor_unsqueeze(self):
        x = torch.randn((1, 3, 4))
        qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
        qy = qx.unsqueeze(2)
        self.assertEqual(qy.size(), (1, 3, 1, 4))
        qy = qy.squeeze(2)
        self.assertEqual(qy.size(), qx.size())

        # Per channel qtensor
        scales = torch.tensor([1.0])
        zero_points = torch.tensor([0])
        qx = torch.quantize_per_channel(x, scales=scales, zero_points=zero_points, dtype=torch.quint8, axis=0)
        qy = qx.unsqueeze(0)
        self.assertEqual(qy.size(), (1, 1, 3, 4))
        self.assertEqual(qy.q_per_channel_axis(), 1)

        qz = qy.squeeze(0)
        self.assertEqual(qz.size(), x.size())
        self.assertEqual(qz.q_per_channel_axis(), 0)
        with self.assertRaisesRegex(RuntimeError, "Squeeze is only possible on non-axis dimension for Per-Channel"):
            qz = qy.squeeze(1)

        # squeeze without dim specified
        x = torch.randn((3, 1, 2, 1, 4))
        scales = torch.tensor([1.0, 1.0])
        zero_points = torch.tensor([0, 0])
        qx = torch.quantize_per_channel(x, scales=scales, zero_points=zero_points, dtype=torch.quint8, axis=2)
        qz = qx.squeeze()
        self.assertEqual(qz.size(), (3, 2, 4))
        self.assertEqual(qz.q_per_channel_axis(), 1)
        with self.assertRaisesRegex(RuntimeError, "Squeeze is only possible on non-axis dimension for Per-Channel"):
            qz = qy.squeeze()
Example #3
def _quantize_weight(weight: torch.Tensor, weight_qscheme: torch.qscheme,
                     weight_dtype: torch.dtype, weight_scale: torch.Tensor,
                     weight_zero_point: torch.Tensor,
                     weight_axis: torch.Tensor):
    if weight_dtype == torch.float16:
        weight = weight.to(weight_dtype)
        return weight

    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.quantize_per_tensor(weight, weight_scale,
                                               weight_zero_point, weight_dtype)
            return weight
    elif weight_qscheme in [
            torch.per_channel_affine, torch.per_channel_affine_float_qparams
    ]:
        if weight_dtype in [
                torch.quint8, torch.qint8, torch.quint4x2, torch.qint32
        ]:
            weight = torch.quantize_per_channel(
                weight, weight_scale, weight_zero_point, weight_axis.item(),
                weight_dtype)  # type: ignore[arg-type]
            return weight
    raise Exception(
        f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
Example #4
    def test_qtensor_quantize_per_channel(self):
        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
        scales = torch.tensor([0.2, 0.03], dtype=torch.double)
        zero_points = torch.tensor([5, 10], dtype=torch.long)
        axis = 1

        def quantize_c(data, scales, zero_points):
            res = torch.empty((3, 2))
            quant_min, quant_max = 0, 255
            for i in range(3):
                for j in range(2):
                    res[i][j] = np.clip(
                        np.round(data[i][j] / scales[j]) + zero_points[j],
                        quant_min, quant_max)
            return res

        qr = torch.quantize_per_channel(r, scales, zero_points, axis,
                                        torch.quint8)
        rqr = qr.dequantize()
        self.assertTrue(
            np.allclose(qr.int_repr(), quantize_c(r, scales, zero_points)))
        self.assertTrue(
            np.allclose(r.numpy(),
                        rqr.numpy(),
                        atol=2 / np.min(scales.numpy())))
Example #5
    def _test_quantize_per_channel(self, r, scales, zero_points, axis, float_params):

        def _quantize_per_channel_ref_nd(data, scales, zero_points, float_params):
            dims = data.size()
            data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
            res = torch.empty_like(data)
            quant_min, quant_max = 0, 255
            for i in range(res.size()[0]):
                for j in range(res.size()[1]):
                    for k in range(res.size()[2]):
                        if float_params:
                            inv_scale = 1.0 / scales[j]
                            res[i][j][k] = np.clip(
                                np.round(data[i][j][k] * inv_scale + zero_points[j]), quant_min, quant_max)
                        else:
                            res[i][j][k] = np.clip(
                                np.round(data[i][j][k] / scales[j]) + zero_points[j], quant_min, quant_max)
            res = res.view(*dims)
            return res

        contig_format = torch.channels_last if r.ndim == 4 else torch.channels_last_3d
        for memory_format in [torch.contiguous_format, contig_format]:
            ref_res = _quantize_per_channel_ref_nd(r, scales, zero_points, float_params)
            r_contig = r.contiguous(memory_format=memory_format)
            qr = torch.quantize_per_channel(r_contig, scales, zero_points, axis, torch.quint8)
            rqr = qr.dequantize()
            self.assertTrue(np.allclose(qr.int_repr(), ref_res))
            self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy())))
Example #6
    def _test_pickle_checkpoint_qtensor(self, device):
        with TemporaryFileName() as fname:

            class M(torch.jit.ScriptModule):
                __constants__ = ['fname']

                def __init__(self):
                    super(M, self).__init__()
                    self.fname = fname

                @torch.jit.script_method
                def forward(self, x, y):
                    torch.save((x, y), self.fname)
                    return y

            q = torch.quantize_per_tensor(torch.rand(2, 3, dtype=torch.float),
                                          scale=0.1,
                                          zero_point=10,
                                          dtype=torch.quint8).to(device)
            qc = torch.quantize_per_channel(
                torch.rand(2, 3, dtype=torch.float),
                scales=torch.tensor([0.1, 0.5, 0.01]),
                zero_points=torch.tensor([10, 0, 20]),
                axis=1,
                dtype=torch.quint8).to(device)
            m = M()
            m(q, qc)
            with open(fname, "rb") as handle:
                loaded_q, loaded_qc = torch.load(fname)
                self.assertEqual(loaded_q, q)
                self.assertEqual(loaded_qc, qc)
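
A smaller sketch of the same round trip, saving a per-channel quantized tensor to an in-memory buffer instead of going through a scripted module; the qparams and shapes are illustrative.

import io
import torch

qc = torch.quantize_per_channel(torch.rand(2, 3),
                                scales=torch.tensor([0.1, 0.5, 0.01]),
                                zero_points=torch.tensor([10, 0, 20]),
                                axis=1, dtype=torch.quint8)
buf = io.BytesIO()
torch.save(qc, buf)
buf.seek(0)
qc2 = torch.load(buf)
assert torch.equal(qc.int_repr(), qc2.int_repr())  # same quantized payload after the round trip
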
Example #7
    def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig):
        num_lengths = np.random.randint(1, 6)
        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

        obs = default_float_qparams_observer()
        obs(weights)
        qparams = obs.calculate_qparams()
        # Quantize the weights to 8 bits
        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
        qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        qemb.set_weight(qweight)
        qemb(indices)

        # Ensure the module has the correct weights
        self.assertEqual(qweight, qemb.weight())

        w_packed = qemb._packed_params._packed_weight
        module_out = qemb(indices)

        # Call the qembedding operator directly
        ref = torch.ops.quantized.embedding_byte(w_packed, indices, sparse=False)
        self.assertEqual(module_out, ref)
        self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False, is_emb_bag=False)
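
A self-contained sketch of the same flow; the row-wise float qparams computed inline stand in for what default_float_qparams_observer would produce, and all sizes and values are illustrative.

import torch
import torch.nn.quantized as nnq

num_embeddings, embedding_dim = 10, 12
weights = torch.randn(num_embeddings, embedding_dim)

# Per-row float qparams so that each row maps onto the full [0, 255] range.
w_min = weights.min(dim=1).values
w_max = weights.max(dim=1).values
scales = (w_max - w_min).clamp(min=1e-6) / 255.0
zero_points = -w_min / scales  # float zero points select the float-qparams scheme

qweight = torch.quantize_per_channel(weights, scales, zero_points, axis=0, dtype=torch.quint8)
emb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
emb.set_weight(qweight)
print(emb(torch.tensor([0, 3, 7])).shape)  # torch.Size([3, 12])
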
Example #8
    def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
        r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8
        """
        num_lengths = np.random.randint(1, 6)
        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

        offsets = lengths_to_offsets(lengths)
        # include the last offset
        offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

        obs = default_float_qparams_observer()
        obs(weights)
        # Get the scale and zero point for the weight tensor
        qparams = obs.calculate_qparams()
        # Quantize the weights to 8 bits
        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
        qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                include_last_offset=True, mode='sum', _weight=qweight)
        qemb(indices, offsets)

        # Ensure the module has the correct weights
        self.assertEqual(qweight, qemb.weight())

        w_packed = qemb._packed_params._packed_weight
        module_out = qemb(indices, offsets)

        # Call the qembedding_bag operator directly
        ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
                                                     per_sample_weights=None,
                                                     include_last_offset=True)
        self.assertEqual(module_out, ref)
        self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag=True)
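
lengths_to_offsets is a test helper that is not shown here; assuming it returns the running start offset of each bag, the sketch below illustrates the offsets layout that the helper plus the concatenation above end up producing.

import numpy as np
import torch

lengths = np.array([2, 0, 3], dtype=np.int64)
# Start offset of each bag, followed by the total number of indices (the "last offset").
offsets = torch.from_numpy(np.concatenate(([0], np.cumsum(lengths))))
print(offsets)  # tensor([0, 2, 2, 5])
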
Example #9
    def test_conv_api(self, use_bias, per_channel):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation.
        """

        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128
        stride = (1, 1)
        i_padding = (0, 0)
        dilation = (1, 1)

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        qX = torch.quantize_per_tensor(X,
                                       scale=scale,
                                       zero_point=128,
                                       dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

        if per_channel:
            scale_tensor = torch.ones(oC, dtype=torch.double)
            zero_point_tensor = torch.zeros(oC, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0

            qw = torch.quantize_per_channel(w,
                                            scales=scale_tensor,
                                            zero_points=zero_point_tensor,
                                            axis=0,
                                            dtype=torch.qint8)
        else:
            qw = torch.quantize_per_tensor(w,
                                           scale=scale,
                                           zero_point=0,
                                           dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None
        q_filters_ref = torch.ops.quantized.conv_prepack(
            qw, b, stride, i_padding, dilation, g)

        ref_result = torch.ops.quantized.conv2d(qX, q_filters_ref, stride,
                                                i_padding, dilation, g, scale,
                                                zero_point)

        q_result = torch.nn.quantized.functional.conv2d(qX,
                                                        qw,
                                                        bias=b,
                                                        scale=scale,
                                                        zero_point=zero_point,
                                                        stride=stride,
                                                        padding=i_padding,
                                                        dilation=dilation,
                                                        groups=g,
                                                        dtype=torch.quint8)

        self.assertEqual(ref_result, q_result)
Example #10
def test_qparams_conversion(tensor, num_bits, distiller_mode, torch_dtype,
                            per_channel, reduce_range):
    if reduce_range:
        if num_bits != 8:
            return True
        if quantization.is_linear_quant_mode_symmetric(
                distiller_mode) and torch_dtype == torch.quint8:
            return True

    # Calculate quantization parameters with Distiller for number of bits BEFORE reduce_range
    signed = distiller_mode != quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED
    distiller_scale, distiller_zp = _get_quant_params_from_tensor(
        tensor, num_bits, distiller_mode, per_channel=per_channel)

    # Convert parameters to PyTorch
    converted_scale, converted_zp = quantization.distiller_qparams_to_pytorch(
        distiller_scale, distiller_zp, num_bits, distiller_mode, torch_dtype,
        reduce_range)

    # Quantize tensor with Distiller
    # If reduce_range is set, then we actually quantize with num_bits-1
    if reduce_range:
        num_bits -= 1
        distiller_scale, distiller_zp = _get_quant_params_from_tensor(
            tensor, num_bits, distiller_mode, per_channel=per_channel)
    restrict = distiller_mode == quantization.LinearQuantMode.SYMMETRIC_RESTRICTED
    clamp_min, clamp_max = quantization.get_quantized_range(
        num_bits, signed=signed, signed_restrict_qrange=restrict)
    distiller_q_t = quantization.linear_quantize_clamp(tensor, distiller_scale,
                                                       distiller_zp, clamp_min,
                                                       clamp_max)

    # Quantize with PyTorch
    if per_channel:
        pytorch_q_t = torch.quantize_per_channel(tensor, converted_scale,
                                                 converted_zp, 0, torch_dtype)
    else:
        pytorch_q_t = torch.quantize_per_tensor(tensor, converted_scale,
                                                converted_zp, torch_dtype)

    # Dequantize
    distiller_q_dq_t = quantization.linear_dequantize(distiller_q_t,
                                                      distiller_scale,
                                                      distiller_zp)
    pytorch_q_dq_t = pytorch_q_t.dequantize()

    # Compare - allow a difference of up to one quantized "bin" between the tensors
    if per_channel:
        for idx, scale in enumerate(converted_scale):
            torch.testing.assert_allclose(distiller_q_dq_t[idx],
                                          pytorch_q_dq_t[idx],
                                          atol=scale,
                                          rtol=1e-05)
    else:
        torch.testing.assert_allclose(pytorch_q_dq_t,
                                      distiller_q_dq_t,
                                      atol=converted_scale,
                                      rtol=1e-05)
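
The "one quantized bin" tolerance used above can be seen on a plain PyTorch round trip as well; this sketch uses illustrative values and no Distiller calls, and prints the per-channel reconstruction error next to half a bin.

import torch

x = torch.randn(4, 8).clamp(-3, 3)
scales = torch.rand(4) * 0.02 + 0.03           # keep the qint8 range wide enough to avoid clamping
zero_points = torch.zeros(4, dtype=torch.int64)
qx = torch.quantize_per_channel(x, scales, zero_points, 0, torch.qint8)
err = (qx.dequantize() - x).abs().max(dim=1).values
print(err)         # per-channel max round-trip error...
print(scales / 2)  # ...stays within half a quantization bin per channel
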
Example #11
def _quantize_weight(float_wt, observer):
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch.quantize_per_tensor(
            float_wt,
            float(wt_scale), int(wt_zp), torch.qint8)
    elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
        wt_axis = observer.ch_axis
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, torch.qint8)
    elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, torch.quint8)
    else:
        raise ValueError("Unexpected qscheme " + observer.qscheme)
    return qweight
Example #12
 def __init__(self, per_channel):
     super(SimpleQTensor, self).__init__()
     x = torch.rand(5, 5).float()
     if not per_channel:
         x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8)
     else:
         s = torch.rand(5, dtype=torch.float64) + 0.1
         zp = torch.randint(5, 15, (5, ))
         x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
     self.register_buffer('x', x_q)
Example #13
def _make_conv_test_input(
    batch_size, in_channels_per_group, input_feature_map_size,
    out_channels_per_group, groups, kernel_size, X_scale, X_zero_point, W_scale,
    W_zero_point, use_bias, use_channelwise,
):
    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups

    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min, X_value_max,
        (batch_size, in_channels,) + input_feature_map_size)
    X = X_scale * (X_init - X_zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)

    W_scale = W_scale * out_channels
    W_zero_point = W_zero_point * out_channels
    # Resize the W_scale and W_zero_point lists to have out_channels elements
    W_scale = W_scale[:out_channels]
    W_zero_point = W_zero_point[:out_channels]
    # For testing, we use small values for weights and activations so that no
    # overflow occurs in the vpmaddubsw instruction. If overflow occurs in the
    # qconv implementation but not in the reference, we cannot exactly match
    # the results with the reference.
    # Please see the comment in qconv implementation file
    #   aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
    (W_value_min, W_value_max) = (-5, 5)
    # The operator expects them in the format
    # (out_channels, in_channels/groups,) + kernel_size
    W_init = torch.randint(
        W_value_min, W_value_max,
        (out_channels, in_channels_per_group,) + kernel_size)
    b_init = torch.randint(0, 10, (out_channels,))

    if use_channelwise:
        W_shape = (-1, 1) + (1,) * len(kernel_size)
        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
        W = W_scales_tensor.reshape(*W_shape) * (
            W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
        b = X_scale * W_scales_tensor * b_init.float()
        W_q = torch.quantize_per_channel(
            W, W_scales_tensor, W_zero_points_tensor.long(), 0,
            dtype=torch.qint8)
    else:
        W = W_scale[0] * (W_init - W_zero_point[0]).float()
        b = X_scale * W_scale[0] * b_init.float()
        W_q = torch.quantize_per_tensor(
            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8)

    return (X, X_q, W, W_q, b if use_bias else None)
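
A hypothetical invocation of the helper above, just to make the expected argument shapes concrete; every number here is illustrative, and the helper is assumed to be in scope.

X, X_q, W, W_q, b = _make_conv_test_input(
    batch_size=2, in_channels_per_group=4, input_feature_map_size=(8, 8),
    out_channels_per_group=8, groups=1, kernel_size=(3, 3),
    X_scale=0.1, X_zero_point=2, W_scale=[0.05], W_zero_point=[0],
    use_bias=True, use_channelwise=True)
print(W_q.q_per_channel_axis())  # 0: one scale/zero point per output channel
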
Example #14
 def test_qtensor_per_channel_load_save(self):
     r = torch.rand(20, 10, dtype=torch.float) * 4 - 2
     scales = torch.rand(10) * 0.02 + 0.01
     zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long)
     # quint32 is not supported yet
     for dtype in [torch.quint8, torch.qint8]:
         qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
         with tempfile.NamedTemporaryFile() as f:
             # Serializing and Deserializing Tensor
             torch.save(qr, f)
             f.seek(0)
             qr2 = torch.load(f)
             self.assertEqual(qr, qr2)
Example #15
def _quantize_weight(float_wt, observer):
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [
            torch.per_tensor_symmetric, torch.per_tensor_affine
    ]:
        qweight = torch.quantize_per_tensor(float_wt, float(wt_scale),
                                            int(wt_zp), torch.qint8)
    else:
        qweight = torch.quantize_per_channel(float_wt,
                                             wt_scale.to(torch.double),
                                             wt_zp.to(torch.int64), 0,
                                             torch.qint8)
    return qweight
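
A sketch of feeding the helper above a real per-channel observer; the observer configuration is illustrative, and in newer releases torch.quantization is also exposed as torch.ao.quantization.

import torch
from torch.quantization import PerChannelMinMaxObserver

w = torch.randn(8, 16)
obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                               qscheme=torch.per_channel_symmetric)
obs(w)  # record per-channel min/max statistics
qw = _quantize_weight(w, obs)
print(qw.q_per_channel_scales().shape)  # torch.Size([8])
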
Example #16
    def _test_numerical_consistency(self, test_type):
        r"""Comparing numerical consistency between quantize/dequantize op and the fake quantize op across devices and dtypes
        """
        torch.random.manual_seed(NP_RANDOM_SEED)
        torch_types = [torch.qint8, torch.quint8]
        float_types = [torch.float, torch.float16, torch.float64]
        zero_types = [torch.long]
        devices = [torch.device('cpu'),
                   torch.device('cuda')
                   ] if torch.cuda.is_available() else [torch.device('cpu')]
        axis = 1
        for i in range(20):
            for torch_type, float_type, device, zero_type in itertools.product(
                    torch_types, float_types, devices, zero_types):
                X = torch.randn(3, 3, device=device).to(float_type)
                scales = (10 * torch.randn(3, device=device)).abs()
                scale = scales.mean().to(float).item()
                zeros = (10 * torch.randn(3, device=device)).abs().to(
                    dtype=zero_type)
                zero = zeros.max().view(1).item()
                quant_min = torch.iinfo(torch_type).min
                quant_max = torch.iinfo(torch_type).max

                test_was_run = False
                if test_type == "per_tensor":
                    test_was_run = True
                    Y = torch.dequantize(
                        torch.quantize_per_tensor(
                            X.to('cpu').to(torch.float), scale, zero,
                            torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_tensor_affine(
                        X, scale, zero, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime,
                        "Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor"
                    )

                if test_type == "per_channel":
                    test_was_run = True
                    Y = torch.dequantize(
                        torch.quantize_per_channel(
                            X.to('cpu').to(torch.float), scales.to('cpu'),
                            zeros.to('cpu'), axis,
                            torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_channel_affine(
                        X, scales, zeros, axis, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime,
                        "Difference found between dequant+quant_per_channel and fake_quantize_per_channel"
                    )
                self.assertTrue(test_was_run)
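
A compact, unparametrized version of the per-channel branch of the check above; the random values are illustrative and quant_min/quant_max are hard-coded for quint8.

import torch

x = torch.randn(3, 4)
scales = torch.rand(3) * 0.1 + 0.05
zero_points = torch.randint(0, 10, (3,))
axis = 0
y_qdq = torch.dequantize(
    torch.quantize_per_channel(x, scales, zero_points, axis, torch.quint8))
y_fake = torch.fake_quantize_per_channel_affine(x, scales, zero_points, axis, 0, 255)
print(torch.allclose(y_qdq, y_fake))  # expected: True
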
Example #17
 def tensor_creation_ops(self):
     i = torch.tensor([[0, 1, 1], [2, 0, 2]])
     v = torch.tensor([3, 4, 5], dtype=torch.float32)
     real = torch.tensor([1, 2], dtype=torch.float32)
     imag = torch.tensor([3, 4], dtype=torch.float32)
     inp = torch.tensor([-1.5, 0.0, 2.0])
     values = torch.tensor([0.5])
     quantized = torch.quantize_per_channel(
         torch.tensor([[-1.0, 0.0], [1.0, 2.0]]),
         torch.tensor([0.1, 0.01]),
         torch.tensor([10, 0]),
         0,
         torch.quint8,
     )
     return (
         torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]),
         # torch.sparse_coo_tensor(i, v, [2, 3]), # not work for iOS
         torch.as_tensor([1, 2, 3]),
         torch.as_strided(torch.randn(3, 3), (2, 2), (1, 2)),
         torch.zeros(2, 3),
         torch.zeros((2, 3)),
         torch.zeros([2, 3], out=i),
         torch.zeros(5),
         torch.zeros_like(torch.empty(2, 3)),
         torch.ones(2, 3),
         torch.ones((2, 3)),
         torch.ones([2, 3]),
         torch.ones(5),
         torch.ones_like(torch.empty(2, 3)),
         torch.arange(5),
         torch.arange(1, 4),
         torch.arange(1, 2.5, 0.5),
         torch.range(1, 4),
         torch.range(1, 4, 0.5),
         torch.linspace(3.0, 3.0, steps=1),
         torch.logspace(start=2, end=2, steps=1, base=2.0),
         torch.eye(3),
         torch.empty(2, 3),
         torch.empty_like(torch.empty(2, 3), dtype=torch.int64),
         torch.empty_strided((2, 3), (1, 2)),
         torch.full((2, 3), 3.141592),
         torch.full_like(torch.full((2, 3), 3.141592), 2.71828),
         torch.quantize_per_tensor(
             torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8
         ),
         torch.dequantize(quantized),
         torch.complex(real, imag),
         torch.polar(real, imag),
         torch.heaviside(inp, values),
     )
Example #18
    def test_numerical_consistency_per_channel(self, device, X):
        r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        # quantize_per_channel and dequantize are only implemented on CPU
        Y = torch.dequantize(torch.quantize_per_channel(X.cpu(), scale.cpu(), zero_point.cpu(), axis, torch_type))
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
Example #19
def _quantize_weight(weight: torch.Tensor, weight_qscheme: torch.qscheme,
                     weight_dtype: torch.dtype, weight_scale: torch.Tensor,
                     weight_zero_point: torch.Tensor,
                     weight_axis: torch.Tensor):
    if weight_qscheme == torch.per_tensor_affine:
        weight = torch.quantize_per_tensor(weight, weight_scale,
                                           weight_zero_point, weight_dtype)
    elif weight_qscheme in [
            torch.per_channel_affine, torch.per_channel_affine_float_qparams
    ]:
        weight = torch.quantize_per_channel(
            weight, weight_scale, weight_zero_point, weight_axis.item(),
            weight_dtype)  # type: ignore[arg-type]
    else:
        raise Exception(f"Unsupported qscheme: {weight_qscheme}")
    return weight
Example #20
def _quantize_weight(float_wt, observer):
    if observer is None: # allow dummy observer that leads to as-is quantization
        return float_wt
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch.quantize_per_tensor(
            float_wt,
            float(wt_scale), int(wt_zp), observer.dtype)
    elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
        wt_axis = observer.ch_axis
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, observer.dtype)
    else:
        raise ValueError("Unexpected qscheme " + observer.qscheme)
    qweight = qweight.dequantize()
    return qweight
Example #21
    def test_quantize_per_channel_float_qparams(self):
        r = torch.rand(3, 2, dtype=torch.float) * 4
        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
        axis = 1

        # Reference quantize function with FP zero_point.
        def quantize_ref(data, scales, zero_points):
            res = torch.empty((3, 2))
            quant_min, quant_max = 0, 255
            for i in range(3):
                for j in range(2):
                    inv_scale = 1.0 / scales[j]
                    res[i][j] = np.clip(
                        np.round(data[i][j] * inv_scale + zero_points[j]),
                        quant_min, quant_max)
            return res

        qr = torch.quantize_per_channel(r, scales, zero_points, axis,
                                        torch.quint8)
        dequant_tensor = qr.dequantize()
        ref = quantize_ref(r, scales, zero_points)
        self.assertTrue(np.allclose(qr.int_repr(), ref))
        self.assertTrue(np.allclose(r.numpy(), dequant_tensor.numpy(), atol=1))

        # Check 4D tensor with 2 different memory formats.
        r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
        self._test_quantize_per_channel(r, scales, zero_points, 1, True)

        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
        self._test_quantize_per_channel(r, scales, zero_points, 0, True)

        # Check 5D tensor.
        r = torch.rand(3, 2, 4, 5, 7, dtype=torch.float) * 4 - 2
        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
        self._test_quantize_per_channel(r, scales, zero_points, 1, True)

        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
        zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
        self._test_quantize_per_channel(r, scales, zero_points, 0, True)
Example #22
    def test_qtensor_per_channel_permute(self):
        r = torch.rand(20, 10, 2, 2, dtype=torch.float) * 4 - 2
        scales = torch.rand(10) * 0.02 + 0.01
        zero_points = torch.round(torch.rand(10) * 2 - 1).to(torch.long)
        qr = torch.quantize_per_channel(r, scales, zero_points, 1, torch.qint8)

        # we can't reorder the axis
        with self.assertRaises(RuntimeError):
            qr.transpose(0, 1)

        # but we can change memory format
        qlast = qr.contiguous(memory_format=torch.channels_last)
        self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride()))))
        self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride()))))
        self.assertEqual(qr.int_repr(), qlast.int_repr())
        self.assertEqual(scales, qlast.q_per_channel_scales())
        self.assertEqual(zero_points, qlast.q_per_channel_zero_points())
        self.assertEqual(1, qlast.q_per_channel_axis())
        self.assertEqual(qlast.dequantize(), qr.dequantize())
Example #23
def minmax_symmetric_quantize(weight, min_vals, max_vals):
    """
    Mimic pytorch's _ObserverBase.per_channel_symmetric quantization
    """
    qmax = 127
    qmin = -128

    zero_points = torch.zeros(min_vals.size(), dtype=torch.int64)
    if torch.equal(max_vals, min_vals):
        scales = torch.ones(min_vals.size(), dtype=torch.float)
    else:
        max_vals = torch.max(-min_vals, max_vals)
        scales = max_vals / ((qmax - qmin) / 2)
        scales = torch.max(scales, torch.tensor([1e-8],
                                                device=scales.device,
                                                dtype=scales.dtype))

    return torch.quantize_per_channel(weight.data.cpu(), scales.cpu(),
                                      zero_points, axis=0, dtype=torch.qint8)
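
An illustrative call to the helper above, with the per-channel min/max computed over each row; it assumes the function defined above is in scope.

import torch

w = torch.randn(6, 5)
min_vals = w.min(dim=1).values
max_vals = w.max(dim=1).values
qw = minmax_symmetric_quantize(w, min_vals, max_vals)
print(qw.q_per_channel_zero_points())  # all zeros under the symmetric scheme
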
Example #24
def _quantize_and_dequantize_weight(weight: torch.Tensor,
                                    weight_qscheme: torch.qscheme,
                                    weight_dtype: torch.dtype,
                                    weight_scale: torch.Tensor,
                                    weight_zero_point: torch.Tensor,
                                    weight_axis: torch.Tensor):
    """ Quantize and then dequantize the weight based on
    the quantization parameters
    """
    if weight_qscheme == torch.per_tensor_affine:
        weight = torch.quantize_per_tensor(weight, weight_scale,
                                           weight_zero_point, weight_dtype)
        weight_dequant = weight.dequantize()
    elif weight_qscheme == torch.per_channel_affine:
        weight = torch.quantize_per_channel(
            weight, weight_scale, weight_zero_point, weight_axis.item(),
            weight_dtype)  # type: ignore[arg-type]
        weight_dequant = weight.dequantize()
    else:
        weight_dequant = weight
    return weight_dequant
Example #25
def regular_serialization():
    test_cases = {}
    for dtype, device in itertools.product(all_dtypes, all_devices):
        base_name = f'regular_serialization_{dtype_name(dtype)}_{device}'

        test_cases[f'{base_name}_0'] = [
            make_tensor((3, 5), device=device, dtype=dtype, low=-9, high=9)
        ]

        a = make_tensor((15, 5, 5), device=device, dtype=dtype, low=-9, high=9)
        test_cases[f'{base_name}_1'] = [
            get_storage(a),
            a.view((5, 3, 25)),
            a,
            a[1:],
        ]

        if dtype.is_floating_point or dtype.is_complex:
            m = torch.nn.Linear(50, 10, dtype=dtype, device=device)
            test_cases[f'{base_name}_module_0'] = [m]

        # Quantization
        if dtype == torch.float and device == 'cpu':
            for qdtype in [
                    torch.quint8, torch.qint8, torch.qint32, torch.quint4x2
            ]:
                a = make_tensor((10, 3, 8, 2, 4),
                                device=device,
                                dtype=dtype,
                                low=-9,
                                high=9)
                q = torch.quantize_per_tensor(a, 1.0, 2, qdtype)
                test_cases[f'{base_name}_quant_0_{dtype_name(qdtype)}'] = [q]
                test_cases[f'{base_name}_quant_1_{dtype_name(qdtype)}'] = [
                    a, q
                ]

                # TODO: For some reason, qint32 throws an illegal instruction
                # error, for both master and local branch. Either I'm doing
                # something wrong or it's an actual problem. Either way,
                # I should file an issue
                if qdtype == torch.qint32:
                    continue

                a = make_tensor((10, 3, 8, 2, 4),
                                device=device,
                                dtype=dtype,
                                low=-9,
                                high=9)
                scales = make_tensor((8, ),
                                     device=device,
                                     dtype=dtype,
                                     low=-9,
                                     high=9)
                zero_points = make_tensor((8, ),
                                          device=device,
                                          dtype=dtype,
                                          low=-9,
                                          high=9)
                q = torch.quantize_per_channel(a, scales, zero_points, 2,
                                               qdtype)
                test_cases[
                    f'{base_name}_quant_channel_0_{dtype_name(qdtype)}'] = [q]
                test_cases[
                    f'{base_name}_quant_channel_1_{dtype_name(qdtype)}'] = [
                        a, q
                    ]

        # TODO: test sparse COO
        # TODO: test packaging

    return test_cases
Example #26
    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_fused, per_channel, qengine):
        """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu"""
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            per_channel = False
        with override_quantized_engine(qengine):
            W = torch.rand(out_features, in_features).float()
            if per_channel:
                scale_tensor = torch.ones(out_features, dtype=torch.double)
                zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
                for i in range(len(scale_tensor)):
                    scale_tensor[i] = (i + 1.0) / 255.0
                W_q = torch.quantize_per_channel(W,
                                                 scales=scale_tensor,
                                                 zero_points=zero_point_tensor,
                                                 axis=0,
                                                 dtype=torch.qint8)
            else:
                W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

            X = torch.rand(batch_size, in_features).float()
            X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
            B = torch.rand(out_features).float() if use_bias else None
            scale = 0.5
            zero_point = 3
            if use_fused:
                qlinear = nnq_fused.LinearReLU(in_features, out_features)
            else:
                qlinear = nnq.Linear(in_features, out_features)

            # Run module with default-initialized parameters.
            # This tests that the constructor is correct.
            qlinear(X_q)

            qlinear.set_weight_bias(W_q, B)
            # Simple round-trip test to exercise the weight()/set_weight() API
            self.assertEqual(qlinear.weight(), W_q, atol=1e-5)
            W_pack = qlinear._packed_params._packed_params

            qlinear.scale = float(scale)
            qlinear.zero_point = int(zero_point)
            Z_q = qlinear(X_q)
            # Check if the module implementation matches calling the
            # ops directly
            if use_fused:
                Z_ref = torch.ops.quantized.linear_relu(
                    X_q, W_pack, scale, zero_point)

                self.assertTrue('QuantizedLinearReLU' in str(qlinear))
            else:
                Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale,
                                                   zero_point)

                self.assertTrue('QuantizedLinear' in str(qlinear))
            self.assertEqual(Z_ref, Z_q)

            # Test serialization of quantized Linear Module using state_dict
            model_dict = qlinear.state_dict()
            self.assertEqual(model_dict['_packed_params.weight'], W_q)
            if use_bias:
                self.assertEqual(model_dict['_packed_params.bias'], B)
            b = io.BytesIO()
            torch.save(model_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in model_dict:
                self.assertEqual(model_dict[key], loaded_dict[key])
            if use_fused:
                loaded_qlinear = nnq_fused.LinearReLU(in_features,
                                                      out_features)
            else:
                loaded_qlinear = nnq.Linear(in_features, out_features)
            loaded_qlinear.load_state_dict(loaded_dict)

            linear_unpack = torch.ops.quantized.linear_unpack
            self.assertEqual(
                linear_unpack(qlinear._packed_params._packed_params),
                linear_unpack(loaded_qlinear._packed_params._packed_params))
            if use_bias:
                self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
            self.assertEqual(qlinear.scale, loaded_qlinear.scale)
            self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
            self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
            self.assertTrue(hasattr(qlinear, '_packed_params'))
            self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
            self.assertTrue(hasattr(qlinear, '_weight_bias'))
            self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
            self.assertEqual(qlinear._weight_bias(),
                             loaded_qlinear._weight_bias())
            self.assertEqual(
                qlinear._weight_bias(),
                torch.ops.quantized.linear_unpack(
                    qlinear._packed_params._packed_params))
            Z_q2 = loaded_qlinear(X_q)
            self.assertEqual(Z_q, Z_q2)

            # The below check is meant to ensure that `torch.save` and `torch.load`
            # serialization works, however it is currently broken by the following:
            # https://github.com/pytorch/pytorch/issues/24045
            #
            # Instead, we currently check that the proper exception is thrown on save.
            # <start code>
            # b = io.BytesIO()
            # torch.save(qlinear, b)
            # b.seek(0)
            # loaded = torch.load(b)
            # self.assertEqual(qlinear.weight(), loaded.weight())
            # self.assertEqual(qlinear.scale, loaded.scale)
            # self.assertEqual(qlinear.zero_point, loaded.zero_point)
            # <end code>
            with self.assertRaisesRegex(
                    RuntimeError,
                    r'torch.save\(\) is not currently supported'):
                b = io.BytesIO()
                torch.save(qlinear, b)

            # Test JIT
            self.checkScriptable(qlinear,
                                 list(zip([X_q], [Z_ref])),
                                 check_save_load=True)

            # Test from_float.
            float_linear = torch.nn.Linear(in_features, out_features).float()
            float_linear.qconfig = torch.quantization.default_qconfig
            torch.quantization.prepare(float_linear, inplace=True)
            float_linear(X.float())
            # Sequential allows swapping using "convert".
            quantized_float_linear = torch.nn.Sequential(float_linear)
            quantized_float_linear = torch.quantization.convert(
                quantized_float_linear, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X_q)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
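
A stripped-down sketch of the per-channel path above, outside the test harness; the sizes and qparams are illustrative, and it assumes the fbgemm backend, which supports per-channel qint8 weights.

import torch
import torch.nn.quantized as nnq

in_features, out_features = 8, 4
w = torch.randn(out_features, in_features)
scales = torch.rand(out_features, dtype=torch.double) * 0.1 + 0.05
zero_points = torch.zeros(out_features, dtype=torch.long)
w_q = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)

qlinear = nnq.Linear(in_features, out_features)
qlinear.set_weight_bias(w_q, torch.zeros(out_features))
qlinear.scale = 0.5
qlinear.zero_point = 3

x_q = torch.quantize_per_tensor(torch.randn(2, in_features), 0.2, 10, torch.quint8)
print(qlinear(x_q).dequantize().shape)  # torch.Size([2, 4])
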
Example #27
    def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel):
        """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu"""
        if torch.backends.quantized.engine == 'qnnpack':
            per_channel = False
        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        if use_fused:
            qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            qlinear = nnq.Linear(in_features, out_features)

        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to exercise the weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)
        W_pack = qlinear._packed_params._packed_params

        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)
        # Check if the module implementation matches calling the
        # ops directly
        if use_fused:
            Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)

            self.assertTrue('QuantizedLinearReLU' in str(qlinear))
        else:
            Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)

            self.assertTrue('QuantizedLinear' in str(qlinear))
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            if isinstance(model_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                self.assertEqual(w_model, w_loaded)
                self.assertEqual(b_model, b_loaded)
            else:
                self.assertEqual(model_dict[key], loaded_dict[key])
        if use_fused:
            loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            loaded_qlinear = nnq.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                         linear_unpack(loaded_qlinear._packed_params._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        loaded = torch.load(b)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

        # Test from_float.
        float_linear = torch.nn.Linear(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear, inplace=True)
        float_linear(X.float())
        # Sequential allows swapping using "convert".
        quantized_float_linear = torch.nn.Sequential(float_linear)
        quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
Example #28
    def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel):
        if torch.backends.quantized.engine == 'qnnpack':
            per_channel = False

        # use_fused -> quantized class
        class_map = {
            True: nniq.LinearReLU,
            False: nnq.Linear,
        }

        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        qlinear = class_map[use_fused](in_features, out_features)

        qlinear_copy = copy.deepcopy(qlinear)
        self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to exercise the weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)

        # testing packed param implementation
        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)

        # Check if the module implementation matches calling the
        # ops directly
        W_pack = qlinear._packed_params._packed_params
        if use_fused:
            Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
        else:
            Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)

        self.assertEqual(Z_ref, Z_q)
        self.assertTrue(
            ("QuantizedLinearReLU" if use_fused else "QuantizedLinear") in str(qlinear))

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            if isinstance(model_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                self.assertEqual(w_model, w_loaded)
                self.assertEqual(b_model, b_loaded)
            else:
                self.assertEqual(model_dict[key], loaded_dict[key])

        loaded_qlinear = class_map[use_fused](
            in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)
        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                         linear_unpack(loaded_qlinear._packed_params._packed_params))
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        # scripting will add __overloads__ to __dict__, which is why we script a copy
        # to be able to do the check in the next line
        self.checkScriptable(copy.deepcopy(loaded_qlinear), [[X_q]], check_save_load=True)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        # Test serialization
        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        loaded = torch.load(b)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test copy and deepcopy
        copied_linear = copy.copy(qlinear)
        self.assertEqual(copied_linear.bias(), qlinear.bias())
        self.assertEqual(copied_linear.scale, qlinear.scale)
        self.assertEqual(copied_linear.zero_point,
                         qlinear.zero_point)
        Y_copied = copied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

        deepcopied_linear = copy.deepcopy(qlinear)
        self.assertEqual(deepcopied_linear.bias(), qlinear.bias())
        self.assertEqual(deepcopied_linear.scale, qlinear.scale)
        self.assertEqual(deepcopied_linear.zero_point,
                         qlinear.zero_point)
        Y_deepcopied = deepcopied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

        # Test JIT
        self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

        # Make sure `from_float` works for all linear variants
        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]

        for mut in modules_under_test:
            # Test from_float.
            float_linear = mut(in_features, out_features).float()
            float_linear.qconfig = torch.quantization.default_qconfig
            torch.quantization.prepare(float_linear, inplace=True)
            float_linear(X.float())
            # Sequential allows swapping using "convert".
            quantized_float_linear = torch.nn.Sequential(float_linear)
            quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X_q)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
Example #29
    def test_serialize_graph(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.e = torch.rand(4)
                self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

            def forward(self, a, b, c):
                add_1 = a + b
                conv1 = self.conv(c)
                linear = self.linear(add_1 + conv1)
                add_2 = linear + self.e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        c = torch.rand(3, 3, 2, 2)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, c])

        partitioner = Partitioner()
        devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
        partitioner_config = PartitionerConfig(devices,
                                               PartitionMode.sparse_nn)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        # Fix for now to add type/shape to output
        for node in traced.graph.nodes:
            if node.op == "output":
                node.meta['tensor_meta'] = extract_tensor_metadata(a)
        for mod in module_with_submodules.modules():
            if isinstance(mod, GraphModule):
                for node in mod.graph.nodes:
                    node.meta['tensor_meta'] = extract_tensor_metadata(a)
        for node in module_with_submodules.graph.nodes:
            node.meta['tensor_meta'] = extract_tensor_metadata(a)

        weights1 = {}
        weights2 = {}
        serialized_graph1 = graph_manipulation.serialize_module(
            traced, weights1)
        serialized_graph2 = graph_manipulation.serialize_module(
            module_with_submodules, weights2)
        assert len(weights1) == 4
        assert len(weights2) == 4
        assert len(serialized_graph1["nodes"]) == 10
        assert len(serialized_graph1["weights"]) == 4
        assert len(serialized_graph1["modules"]) == 0
        assert len(serialized_graph2["nodes"]) == 6
        assert len(serialized_graph2["weights"]) == 4
        assert len(serialized_graph2["modules"]) == 1
        assert serialized_graph1["weights"]["linear.weight"][
            "shape"] == "[4, 4]"
        assert (serialized_graph1["weights"]["linear.weight"]["dtype"] ==
                "torch.float32")
        assert (serialized_graph1["weights"]["linear.weight"]["is_quantized"]
                is False)
        assert serialized_graph1["nodes"][0]["shape"] == "[4]"
        assert serialized_graph1["nodes"][0]["dtype"] == "torch.float32"
        assert serialized_graph1["nodes"][0]["target"] == "a"
        assert serialized_graph1["nodes"][0]["op_code"] == "placeholder"
        assert serialized_graph1["nodes"][0]["name"] == "a"
        assert serialized_graph1["nodes"][6]["args"][0]["name"] == "add_1"
        assert serialized_graph1["nodes"][6]["args"][0]["is_node"] is True

        # Test quantization info serialization.
        x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
        q_tensor_channel = torch.quantize_per_channel(
            x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0,
            torch.quint8)
        result = graph_manipulation.serialize_tensor_quantization(q_tensor)
        result2 = graph_manipulation.serialize_tensor_quantization(
            q_tensor_channel)
        assert result["qscheme"] == "torch.per_tensor_affine"
        assert result["q_scale"] == 1.0
        assert result2["qscheme"] == "torch.per_channel_affine"
        assert len(result2["q_per_channel_scales"]) == 2
Example #30
reveal_type(torch.empty_like(torch.empty(2, 3),
                             dtype=torch.int64))  # E: {Tensor}
reveal_type(torch.empty_strided((2, 3), (1, 2)))  # E: {Tensor}

# torch.full/full_like
reveal_type(torch.full((2, 3), 3.141592))  # E: {Tensor}
reveal_type(torch.full_like(torch.full((2, 3), 3.141592),
                            2.71828))  # E: {Tensor}

# torch.quantize_per_tensor
reveal_type(
    torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10,
                              torch.quint8))  # E: {Tensor}

# torch.quantize_per_channel
x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
quant = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]),
                                   torch.tensor([10, 0]), 0, torch.quint8)
reveal_type(x)  # E: {Tensor}

# torch.dequantize
reveal_type(torch.dequantize(x))  # E: {Tensor}

# torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
reveal_type(torch.complex(real, imag))  # E: {Tensor}

# torch.polar
abs = torch.tensor([1, 2], dtype=torch.float64)
pi = torch.acos(torch.zeros(1)).item() * 2
angle = torch.tensor([pi / 2, 5 * pi / 4], dtype=torch.float64)
reveal_type(torch.polar(abs, angle))  # E: {Tensor}