def test_per_channel_minmax_observer(self, qdtype, qscheme, ch_axis, reduce_range):
    """Check PerChannelMinMaxObserver statistics and qparams on a fixed input.

    A hand-written 4-D tensor is passed through the observer once. The
    per-channel min/max recorded by the observer and the scale / zero point
    returned by ``calculate_qparams`` are compared against reference tables
    indexed by ``ch_axis``.

    Args:
        qdtype: quantized dtype under test. The body distinguishes
            ``torch.qint8`` from everything else when picking zero points.
        qscheme: ``torch.per_channel_symmetric`` selects the symmetric
            references; any other value selects the affine references.
        ch_axis: channel axis given to the observer; also the index into the
            four-entry reference tables.
        reduce_range: forwarded to the observer (after the override below).
    """
    # reduce_range cannot be true for symmetric quantization with uint8
    if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
        reduce_range = False

    myobs = PerChannelMinMaxObserver(
        reduce_range=reduce_range,
        ch_axis=ch_axis,
        dtype=qdtype,
        qscheme=qscheme,
    )
    x = torch.tensor([
        [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
        [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
    ])

    # The observer is expected to hand its input back unchanged.
    result = myobs(x)
    self.assertEqual(result, x)
    qparams = myobs.calculate_qparams()

    # Reference tables: entry i holds the expected values for ch_axis == i.
    ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
    ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
    per_channel_symmetric_ref_scales = [
        [0.04705882, 0.06274509],
        [0.03921569, 0.0627451],
        [0.04705882, 0.0627451],
        [0.05490196, 0.0627451],
    ]
    per_channel_affine_ref_scales = [
        [0.02352941, 0.04705882],
        [0.03529412, 0.03137255],
        [0.03921569, 0.03137255],
        [0.04313726, 0.04313726],
    ]
    per_channel_affine_qint8_zp = [
        [-128, -43],
        [-15, -128],
        [-26, -128],
        [-35, -58],
    ]
    per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]

    # Use the singular buffer names, matching how test_per_channel_observers
    # in this file reads the same observer class.
    self.assertEqual(myobs.min_val, ref_min_vals[ch_axis])
    self.assertEqual(myobs.max_val, ref_max_vals[ch_axis])

    if qscheme == torch.per_channel_symmetric:
        ref_scales = per_channel_symmetric_ref_scales[ch_axis]
        ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
    else:
        ref_scales = per_channel_affine_ref_scales[ch_axis]
        ref_zero_points = (
            per_channel_affine_qint8_zp[ch_axis]
            if qdtype is torch.qint8
            else per_channel_affine_quint8_zp[ch_axis]
        )

    if reduce_range:
        # The tables above are for the full range; rescale them for the
        # reduced range (scale * 255 / 127, zero point halved and floored).
        ref_scales = [s * 255 / 127 for s in ref_scales]
        ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]

    self.assertTrue(
        torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype)))
    self.assertTrue(
        torch.allclose(
            qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))
def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
    """Exercise the quantized EmbeddingBag module for quint8 and quint4x2 weights.

    Random indices, offsets and float weights are generated once. For each
    weight dtype the weights are quantized per channel, wrapped in
    ``nnq.EmbeddingBag``, and the module output is compared with a direct call
    to the matching ``torch.ops.quantized`` operator. Finally the module is
    handed to ``self.checkEmbeddingSerialization``.

    Args:
        num_embeddings: number of rows in the embedding table.
        embedding_dim: width of each embedding row.
        num_offsets: accepted but not referenced in the body.
        set_qconfig: forwarded to ``checkEmbeddingSerialization``.
    """
    # Random bag lengths; their sum is the number of indices to draw.
    bag_count = np.random.randint(1, 6)
    bag_lengths = np.random.randint(0, 21, size=bag_count).astype(np.int32)
    total_indices = np.sum(bag_lengths)
    index_array = np.random.randint(
        low=0, high=num_embeddings, size=total_indices, dtype=np.int64)
    indices = torch.from_numpy(index_array)

    # Offsets derived from the lengths, with the final offset appended because
    # the module is built with include_last_offset=True.
    last_offset = torch.tensor([indices.size(0)], dtype=torch.long)
    offsets = torch.cat([lengths_to_offsets(bag_lengths), last_offset], dim=0)

    float_table = np.random.random_sample((num_embeddings, embedding_dim)) + 1
    weights = torch.from_numpy(float_table.astype(np.float32))

    for qdtype in (torch.quint8, torch.quint4x2):
        observer = PerChannelMinMaxObserver(
            dtype=qdtype,
            qscheme=torch.per_channel_affine_float_qparams,
            ch_axis=0,
        )
        observer(weights)

        # Scale and zero point for the weight tensor, one pair per row.
        qparams = observer.calculate_qparams()
        scales = qparams[0]
        zero_points = qparams[1]

        # Quantize the weights to the dtype of this iteration.
        qweight = torch.quantize_per_channel(
            weights, scales, zero_points, axis=0, dtype=qdtype)

        qemb = nnq.EmbeddingBag(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            include_last_offset=True,
            mode='sum',
            _weight=qweight,
            dtype=qdtype,
        )
        qemb(indices, offsets)

        # The module must report back exactly the weights it was given.
        self.assertEqual(qweight, qemb.weight())

        packed_weight = qemb._packed_params._packed_weight
        module_out = qemb(indices, offsets)

        # Invoke the underlying embedding-bag operator directly for reference.
        bag_op = (
            torch.ops.quantized.embedding_bag_byte
            if qdtype == torch.quint8
            else torch.ops.quantized.embedding_bag_4bit
        )
        expected = bag_op(
            packed_weight,
            indices,
            offsets,
            mode=0,
            per_sample_weights=None,
            include_last_offset=True,
        )
        self.assertEqual(module_out, expected)

        self.checkEmbeddingSerialization(
            qemb,
            num_embeddings,
            embedding_dim,
            indices,
            offsets,
            set_qconfig,
            is_emb_bag=True,
            dtype=qdtype,
        )
def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
    """Check both per-channel min/max observers: statistics, qparams, round trip.

    For ``PerChannelMinMaxObserver`` and
    ``MovingAveragePerChannelMinMaxObserver`` the test
      1. calls ``calculate_qparams`` before any data has been observed,
      2. feeds a fixed 4-D tensor and compares the recorded min/max and the
         resulting scale / zero point with reference tables indexed by
         ``ch_axis``,
      3. saves the observer's ``state_dict`` with ``torch.save``, reloads it
         into a fresh ``PerChannelMinMaxObserver`` and checks that min/max and
         qparams match.

    Args:
        qdtype: quantized dtype under test. The body distinguishes
            ``torch.qint8`` from everything else when picking zero points.
        qscheme: one of the per-channel schemes; symmetric and
            float-qparams get dedicated references, anything else uses the
            affine references.
        ch_axis: channel axis given to the observers; also the index into the
            four-entry reference tables.
        reduce_range: forwarded to the observers (after the overrides below).
    """
    # This test always runs the float-qparams scheme without reduce_range.
    if qscheme == torch.per_channel_affine_float_qparams:
        reduce_range = False
    # reduce_range cannot be true for symmetric quantization with uint8
    if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
        reduce_range = False

    observers = [
        PerChannelMinMaxObserver(
            reduce_range=reduce_range,
            ch_axis=ch_axis,
            dtype=qdtype,
            qscheme=qscheme,
        ),
        MovingAveragePerChannelMinMaxObserver(
            averaging_constant=0.5,
            reduce_range=reduce_range,
            ch_axis=ch_axis,
            dtype=qdtype,
            qscheme=qscheme,
        ),
    ]

    # Input and reference data do not depend on the observer, so they are
    # built once. Entry i of each table is the expectation for ch_axis == i.
    x = torch.tensor([
        [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
        [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
    ])
    ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
    ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
    per_channel_symmetric_ref_scales = [
        [0.04705882, 0.06274509],
        [0.03921569, 0.0627451],
        [0.04705882, 0.0627451],
        [0.05490196, 0.0627451],
    ]
    per_channel_affine_ref_scales = [
        [0.02352941, 0.04705882],
        [0.03529412, 0.03137255],
        [0.03921569, 0.03137255],
        [0.04313726, 0.04313726],
    ]
    per_channel_affine_qint8_zp = [
        [-128, -43],
        [-15, -128],
        [-26, -128],
        [-35, -58],
    ]
    per_channel_affine_float_qparams_ref_scales = [
        [0.0196, 0.0471],
        [0.0353, 0.0196],
        [0.0392, 0.0235],
        [0.0431, 0.0431],
    ]
    per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]

    for myobs in observers:
        # Calculate qparams should work for empty observers
        myobs.calculate_qparams()

        if isinstance(myobs, MovingAveragePerChannelMinMaxObserver):
            # Scaling the input tensor to model change in min/max values
            # across batches.
            # NOTE(review): presumably averaging 0.5 * x and 1.5 * x with
            # averaging_constant=0.5 lands on x's own min/max, which is why
            # the same reference tables apply -- confirm against the observer.
            result = myobs(0.5 * x)
            result = myobs(1.5 * x)
            self.assertEqual(result, 1.5 * x)
        else:
            result = myobs(x)
            self.assertEqual(result, x)

        qparams = myobs.calculate_qparams()

        self.assertEqual(myobs.min_val, ref_min_vals[ch_axis])
        self.assertEqual(myobs.max_val, ref_max_vals[ch_axis])

        if qscheme == torch.per_channel_symmetric:
            ref_scales = per_channel_symmetric_ref_scales[ch_axis]
            ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
        elif qscheme == torch.per_channel_affine_float_qparams:
            ref_scales = per_channel_affine_float_qparams_ref_scales[ch_axis]
            # Float zero points are derived from the reference min and scale.
            ref_zero_points = [
                -1 * lo / scale
                for lo, scale in zip(ref_min_vals[ch_axis], ref_scales)
            ]
        else:
            ref_scales = per_channel_affine_ref_scales[ch_axis]
            ref_zero_points = (
                per_channel_affine_qint8_zp[ch_axis]
                if qdtype is torch.qint8
                else per_channel_affine_quint8_zp[ch_axis]
            )

        if reduce_range:
            # The tables above are for the full range; rescale them for the
            # reduced range (scale * 255 / 127, zero point halved and floored).
            ref_scales = [s * 255 / 127 for s in ref_scales]
            ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]

        self.assertTrue(
            torch.allclose(
                qparams[0],
                torch.tensor(ref_scales, dtype=qparams[0].dtype),
                atol=0.0001))
        if qscheme == torch.per_channel_affine_float_qparams:
            # The float references are rounded to four decimals, so the
            # derived zero points are only compared to within 1.
            self.assertTrue(
                torch.allclose(
                    qparams[1],
                    torch.tensor(ref_zero_points, dtype=qparams[1].dtype),
                    atol=1))
        else:
            self.assertTrue(
                torch.allclose(
                    qparams[1],
                    torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))

        # Test for serializability
        state_dict = myobs.state_dict()
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])

        # The state is always restored into a plain PerChannelMinMaxObserver,
        # including when it came from the moving-average observer.
        loaded_obs = PerChannelMinMaxObserver(
            reduce_range=reduce_range,
            ch_axis=ch_axis,
            dtype=qdtype,
            qscheme=qscheme,
        )
        loaded_obs.load_state_dict(loaded_dict)
        loaded_qparams = loaded_obs.calculate_qparams()
        self.assertEqual(myobs.min_val, loaded_obs.min_val)
        self.assertEqual(myobs.max_val, loaded_obs.max_val)
        self.assertEqual(myobs.calculate_qparams(), loaded_qparams)