Example #1: per-channel qparams from PerChannelMinMaxObserver (affine and symmetric qschemes)
    def test_per_channel_minmax_observer(self, qdtype, qscheme, ch_axis,
                                         reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
            reduce_range = False
        myobs = PerChannelMinMaxObserver(reduce_range=reduce_range,
                                         ch_axis=ch_axis,
                                         dtype=qdtype,
                                         qscheme=qscheme)
        x = torch.tensor([
            [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
            [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
        ])
        result = myobs(x)
        self.assertEqual(result, x)
        qparams = myobs.calculate_qparams()
        ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
        ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
        per_channel_symmetric_ref_scales = [
            [0.04705882, 0.06274509],
            [0.03921569, 0.0627451],
            [0.04705882, 0.0627451],
            [0.05490196, 0.0627451],
        ]
        per_channel_affine_ref_scales = [
            [0.02352941, 0.04705882],
            [0.03529412, 0.03137255],
            [0.03921569, 0.03137255],
            [0.04313726, 0.04313726],
        ]
        per_channel_affine_qint8_zp = [
            [-128, -43],
            [-15, -128],
            [-26, -128],
            [-35, -58],
        ]
        per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]

        self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
        self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])
        if qscheme == torch.per_channel_symmetric:
            ref_scales = per_channel_symmetric_ref_scales[ch_axis]
            ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
        else:
            ref_scales = per_channel_affine_ref_scales[ch_axis]
            ref_zero_points = (per_channel_affine_qint8_zp[ch_axis]
                               if qdtype is torch.qint8 else
                               per_channel_affine_quint8_zp[ch_axis])

        if reduce_range:
            ref_scales = [s * 255 / 127 for s in ref_scales]
            ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]

        self.assertTrue(
            torch.allclose(qparams[0],
                           torch.tensor(ref_scales, dtype=qparams[0].dtype)))
        self.assertTrue(
            torch.allclose(
                qparams[1],
                torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))
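
A minimal standalone sketch of what this test exercises (assuming PerChannelMinMaxObserver is imported from torch.quantization, or torch.ao.quantization in newer releases): the observer tracks a running min/max per slice along ch_axis, and calculate_qparams derives the affine scale as (max - min) / (qmax - qmin) after widening the observed range to include zero.

import torch
from torch.quantization import PerChannelMinMaxObserver

# Sketch only: reproduces the ch_axis=0 affine qint8 reference values above.
obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                               qscheme=torch.per_channel_affine)
x = torch.tensor([
    [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
    [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
])
obs(x)
scales, zero_points = obs.calculate_qparams()
# channel 0: range [1, 6] widened to [0, 6] -> scale ~ 6 / 255 ~ 0.0235, zp = -128
# channel 1: range [-4, 8]                  -> scale ~ 12 / 255 ~ 0.0471, zp = -43
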
Example #2: dynamic quantized EmbeddingBag (quint8 and quint4x2)
    def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
        r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8
        """

        num_lengths = np.random.randint(1, 6)
        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

        offsets = lengths_to_offsets(lengths)
        # include the last offset
        offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

        for qdtype in [torch.quint8, torch.quint4x2]:
            obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
            obs(weights)
            # Get the scale and zero point for the weight tensor
            qparams = obs.calculate_qparams()
            # Quantize the weights (8-bit for quint8, 4-bit for quint4x2)
            qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qdtype)
            qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                    include_last_offset=True, mode='sum', _weight=qweight, dtype=qdtype)
            qemb(indices, offsets)

            # Ensure the module has the correct weights
            self.assertEqual(qweight, qemb.weight())

            w_packed = qemb._packed_params._packed_weight
            module_out = qemb(indices, offsets)

            # Call the qembedding_bag operator directly
            if qdtype == torch.quint8:
                ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
                                                             per_sample_weights=None,
                                                             include_last_offset=True)
            else:
                ref = torch.ops.quantized.embedding_bag_4bit(w_packed, indices, offsets, mode=0,
                                                             per_sample_weights=None,
                                                             include_last_offset=True)

            self.assertEqual(module_out, ref)
            self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices,
                                             offsets, set_qconfig, is_emb_bag=True, dtype=qdtype)
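
This snippet relies on a lengths_to_offsets helper (plus the usual aliases np for numpy and nnq for torch.nn.quantized) defined elsewhere in the test suite. A minimal sketch of such a helper, under the assumption that it turns per-bag lengths into begin offsets via a cumulative sum; the final offset is appended separately above because include_last_offset=True:

import numpy as np
import torch

def lengths_to_offsets(lengths):
    # Hypothetical stand-in for the helper used above:
    # per-bag lengths [2, 3, 1] become begin offsets [0, 2, 5].
    offsets = np.concatenate(([0], np.cumsum(lengths)[:-1]))
    return torch.from_numpy(offsets.astype(np.int64))
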
Example #3: PerChannelMinMaxObserver and MovingAveragePerChannelMinMaxObserver, including serialization
    def test_per_channel_observers(self, qdtype, qscheme, ch_axis,
                                   reduce_range):
        # reduce_range is not supported for float qparams, and cannot be
        # true for symmetric quantization with uint8
        if qscheme == torch.per_channel_affine_float_qparams:
            reduce_range = False
        if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
            reduce_range = False
        ObserverList = [
            PerChannelMinMaxObserver(reduce_range=reduce_range,
                                     ch_axis=ch_axis,
                                     dtype=qdtype,
                                     qscheme=qscheme),
            MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5,
                                                  reduce_range=reduce_range,
                                                  ch_axis=ch_axis,
                                                  dtype=qdtype,
                                                  qscheme=qscheme)
        ]

        for myobs in ObserverList:
            # calculate_qparams() should work even before any data is observed
            qparams = myobs.calculate_qparams()
            x = torch.tensor([
                [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
                [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
            ])
            if type(myobs) == MovingAveragePerChannelMinMaxObserver:
                # Scaling the input tensor to model change in min/max values
                # across batches
                result = myobs(0.5 * x)
                result = myobs(1.5 * x)
                self.assertEqual(result, 1.5 * x)
            else:
                result = myobs(x)
                self.assertEqual(result, x)

            qparams = myobs.calculate_qparams()
            ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0],
                            [-4.0, -3.0]]
            ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
            per_channel_symmetric_ref_scales = [
                [0.04705882, 0.06274509],
                [0.03921569, 0.0627451],
                [0.04705882, 0.0627451],
                [0.05490196, 0.0627451],
            ]
            per_channel_affine_ref_scales = [
                [0.02352941, 0.04705882],
                [0.03529412, 0.03137255],
                [0.03921569, 0.03137255],
                [0.04313726, 0.04313726],
            ]
            per_channel_affine_qint8_zp = [
                [-128, -43],
                [-15, -128],
                [-26, -128],
                [-35, -58],
            ]
            per_channel_affine_float_qparams_ref_scales = [
                [0.0196, 0.0471],
                [0.0353, 0.0196],
                [0.0392, 0.0235],
                [0.0431, 0.0431],
            ]
            per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0],
                                            [93, 70]]

            self.assertEqual(myobs.min_val, ref_min_vals[ch_axis])
            self.assertEqual(myobs.max_val, ref_max_vals[ch_axis])
            if qscheme == torch.per_channel_symmetric:
                ref_scales = per_channel_symmetric_ref_scales[ch_axis]
                ref_zero_points = ([0, 0] if qdtype is torch.qint8
                                   else [128, 128])
            elif qscheme == torch.per_channel_affine_float_qparams:
                ref_scales = per_channel_affine_float_qparams_ref_scales[ch_axis]
                ref_zero_points = [
                    -1 * ref_min_vals[ch_axis][i] / ref_scales[i]
                    for i in range(len(ref_scales))
                ]
            else:
                ref_scales = per_channel_affine_ref_scales[ch_axis]
                ref_zero_points = (per_channel_affine_qint8_zp[ch_axis]
                                   if qdtype is torch.qint8 else
                                   per_channel_affine_quint8_zp[ch_axis])

            if reduce_range:
                ref_scales = [s * 255 / 127 for s in ref_scales]
                ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
            self.assertTrue(
                torch.allclose(qparams[0],
                               torch.tensor(ref_scales,
                                            dtype=qparams[0].dtype),
                               atol=0.0001))
            if qscheme == torch.per_channel_affine_float_qparams:
                self.assertTrue(
                    torch.allclose(qparams[1],
                                   torch.tensor(ref_zero_points,
                                                dtype=qparams[1].dtype),
                                   atol=1))
            else:
                self.assertTrue(
                    torch.allclose(
                        qparams[1],
                        torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))

            # Test for serializability
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in state_dict:
                self.assertEqual(state_dict[key], loaded_dict[key])
            loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range,
                                                  ch_axis=ch_axis,
                                                  dtype=qdtype,
                                                  qscheme=qscheme)
            loaded_obs.load_state_dict(loaded_dict)
            loaded_qparams = loaded_obs.calculate_qparams()
            self.assertEqual(myobs.min_val, loaded_obs.min_val)
            self.assertEqual(myobs.max_val, loaded_obs.max_val)
            self.assertEqual(myobs.calculate_qparams(),
                             loaded_obs.calculate_qparams())
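
For the moving-average variant, the running statistics are assumed to follow the documented update rule min <- min + c * (batch_min - min), and likewise for max. With averaging_constant=0.5 and the two batches 0.5 * x and 1.5 * x fed in above, the running values land exactly on min(x) and max(x), which is why both observers can share the same reference tables. A small sketch to check this:

import torch
from torch.quantization import MovingAveragePerChannelMinMaxObserver

obs = MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5, ch_axis=0)
x = torch.tensor([
    [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
    [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
])
obs(0.5 * x)   # running stats become 0.5 * min(x), 0.5 * max(x)
obs(1.5 * x)   # 0.5*m + 0.5*(1.5*m) = m, so stats equal min(x), max(x)
print(obs.min_val)  # expected: tensor([ 1., -4.])
print(obs.max_val)  # expected: tensor([6., 8.])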