import io

import numpy as np
import torch
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
# NOTE: these helper imports follow the torch.quantization namespace this
# file was written against; newer PyTorch exposes them under
# torch.ao.quantization (and renames float_qparams_dynamic_qconfig to
# float_qparams_weight_only_qconfig).
from torch.quantization import (
    default_float_qparams_observer,
    float_qparams_dynamic_qconfig,
    prepare_dynamic,
)


def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
    r"""Test execution and serialization for quantized embedding_bag modules on int8."""
    num_lengths = np.random.randint(1, 6)
    lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
    num_indices = np.sum(lengths)
    indices = torch.from_numpy(np.random.randint(
        low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

    offsets = lengths_to_offsets(lengths)
    # include the last offset
    offsets = torch.cat(
        (offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
    weights = torch.from_numpy((np.random.random_sample(
        (num_embeddings, embedding_dim)) + 1).astype(np.float32))

    obs = default_float_qparams_observer()
    obs(weights)
    # Get the scale and zero point for the weight tensor
    qparams = obs.calculate_qparams()
    # Quantize the weights to 8 bits
    qweight = torch.quantize_per_channel(
        weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
    qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings,
                            embedding_dim=embedding_dim,
                            include_last_offset=True,
                            mode='sum',
                            _weight=qweight)
    qemb(indices, offsets)

    # Ensure the module has the correct weights
    self.assertEqual(qweight, qemb.weight())

    w_packed = qemb._packed_params._packed_weight
    module_out = qemb(indices, offsets)

    # Call the qembedding_bag operator directly
    ref = torch.ops.quantized.embedding_bag_byte(
        w_packed, indices, offsets, mode=0,
        per_sample_weights=None, include_last_offset=True)
    self.assertEqual(module_out, ref)
    self.checkEmbeddingSerialization(
        qemb, num_embeddings, embedding_dim, indices, offsets,
        set_qconfig, is_emb_bag=True)
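
# ---------------------------------------------------------------------------
# `lengths_to_offsets` is defined elsewhere in the original file; the code in
# this section only relies on it turning per-bag lengths into begin offsets.
# A minimal sketch follows, assuming begin offsets and an int64 result (both
# assumptions; the `offset_dtype` parameter is illustrative, not necessarily
# part of the real helper).
# ---------------------------------------------------------------------------
def lengths_to_offsets(lengths, offset_dtype=np.int64):
    """Convert an array of bag lengths into begin offsets via a cumsum.

    E.g. lengths [2, 0, 3] -> offsets [0, 2, 2]; callers append the final
    boundary offset themselves when include_last_offset is used.
    """
    padded = np.zeros(len(lengths) + 1, dtype=offset_dtype)
    padded[1:] = lengths
    return torch.from_numpy(np.cumsum(padded, dtype=offset_dtype))[:-1]
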
def init(self, num_embeddings: int, embedding_dim: int, num_offsets: int,
         enable_per_sample_weights: bool, include_last_offset: bool,
         is_pruned_weights: bool, use_32bit_indices: bool,
         use_32bit_offsets: bool, op_func):
    """Set up randomized indices, offsets, and byte-prepacked weights for the
    quantized embedding_bag operator under test."""
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    self.num_offsets = num_offsets
    self.enable_per_sample_weights = enable_per_sample_weights
    self.include_last_offset = include_last_offset
    self.max_segment_length = 20
    self.num_lengths = np.random.randint(1, num_offsets + 1)
    self.lengths = np.random.randint(
        0, self.max_segment_length + 1, size=self.num_lengths).astype(np.int32)
    self.is_pruned_weights = is_pruned_weights
    self.use_32bit_indices = use_32bit_indices
    self.use_32bit_offsets = use_32bit_offsets

    self.num_indices = np.sum(self.lengths)
    self.offsets = lengths_to_offsets(self.lengths)
    # Draw indices once, then narrow to 32 bits if requested. (Re-drawing the
    # indices after this point would silently discard the conversion.)
    self.indices = torch.from_numpy(np.random.randint(
        low=0, high=num_embeddings, size=self.num_indices, dtype=np.int64))
    self.indices = self.indices.int() if self.use_32bit_indices else self.indices
    self.offsets = self.offsets.int() if self.use_32bit_offsets else self.offsets

    if include_last_offset:
        # Append the final boundary offset, matching the offsets' dtype so a
        # 32-bit offsets tensor is not promoted back to int64 by the cat.
        self.offsets = torch.cat(
            (self.offsets,
             torch.tensor([self.indices.size(0)], dtype=self.offsets.dtype)),
            0)

    self.weights = torch.from_numpy((np.random.random_sample(
        (self.num_embeddings, self.embedding_dim)) + 1).astype(np.float32))

    self.prepack_func = torch.ops.quantized.embedding_bag_byte_prepack
    self.prepacked_weights = self.prepack_func(self.weights)
    self.per_sample_weights = (
        torch.from_numpy(np.random.uniform(
            low=0.01, high=0.5, size=[len(self.indices)]).astype(np.float32))
        if self.enable_per_sample_weights else None)

    self.compressed_indices = None
    if self.is_pruned_weights:
        self.prepacked_weights, self.compressed_indices = \
            get_pruned_weights_and_mapping(self.prepacked_weights)

    self.op_func = op_func
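
# ---------------------------------------------------------------------------
# `get_pruned_weights_and_mapping` also lives elsewhere in the original file.
# A minimal sketch of the idea it is assumed to implement: drop a random
# subset of rows from the byte-prepacked weights and return a compressed
# indices mapping (new row index for kept rows, -1 for pruned rows), matching
# the convention of the ops' compressed_indices_mapping argument. The
# keep_prob parameter is illustrative, not part of the real helper.
# ---------------------------------------------------------------------------
def get_pruned_weights_and_mapping(prepacked_weights, keep_prob=0.5):
    num_rows = prepacked_weights.shape[0]
    keep_mask = torch.from_numpy(np.random.uniform(size=num_rows) < keep_prob)
    keep_mask[0] = True  # keep at least one row so lookups have data
    pruned_weights = prepacked_weights[keep_mask]
    mapping = torch.full((num_rows,), -1, dtype=torch.int32)
    mapping[keep_mask] = torch.arange(int(keep_mask.sum()), dtype=torch.int32)
    return pruned_weights, mapping
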
def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
    r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8."""
    num_lengths = np.random.randint(1, 6)
    lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
    num_indices = np.sum(lengths)
    indices = torch.from_numpy(np.random.randint(
        low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

    offsets = lengths_to_offsets(lengths)
    # include the last offset
    offsets = torch.cat(
        (offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
    weights = torch.from_numpy((np.random.random_sample(
        (num_embeddings, embedding_dim)) + 1).astype(np.float32))

    obs = default_float_qparams_observer()
    obs(weights)
    # Get the scale and zero point for the weight tensor
    qparams = obs.calculate_qparams()
    # Quantize the weights to 8 bits
    qweight = torch.quantize_per_channel(
        weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
    qemb = nnqd.EmbeddingBag(num_embeddings=num_embeddings,
                             embedding_dim=embedding_dim,
                             include_last_offset=True,
                             mode='sum')
    qemb.set_weight(qweight)
    qemb(indices, offsets)

    # Ensure the module has the correct weights
    self.assertEqual(qweight, qemb.weight())

    w_packed = qemb._packed_params._packed_weight
    module_out = qemb(indices, offsets)

    # Call the qembedding_bag operator directly
    ref = torch.ops.quantized.embedding_bag_byte(
        w_packed, indices, offsets, mode=0,
        per_sample_weights=None, include_last_offset=True)
    self.assertEqual(module_out, ref)

    # Test serialization of the dynamic EmbeddingBag module using state_dict
    emb_dict = qemb.state_dict()
    b = io.BytesIO()
    torch.save(emb_dict, b)
    b.seek(0)
    loaded_dict = torch.load(b)
    embedding_unpack = torch.ops.quantized.embedding_bag_unpack
    # Check unpacked weight values explicitly
    for key in emb_dict:
        if isinstance(emb_dict[key], torch._C.ScriptObject):
            assert isinstance(loaded_dict[key], torch._C.ScriptObject)
            emb_weight = embedding_unpack(emb_dict[key])
            loaded_weight = embedding_unpack(loaded_dict[key])
            self.assertEqual(emb_weight, loaded_weight)

    # Check state_dict serialization and the torch.save APIs
    loaded_qemb = nnqd.EmbeddingBag(num_embeddings=num_embeddings,
                                    embedding_dim=embedding_dim,
                                    include_last_offset=True,
                                    mode='sum')
    self.check_eager_serialization(qemb, loaded_qemb, [indices, offsets])

    loaded_qemb.load_state_dict(loaded_dict)
    self.assertEqual(
        embedding_unpack(qemb._packed_params._packed_weight),
        embedding_unpack(loaded_qemb._packed_params._packed_weight))

    # Test JIT serialization
    self.checkScriptable(qemb, [[indices, offsets]], check_save_load=True)

    # Test the from_float call
    float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings,
                                            embedding_dim=embedding_dim,
                                            include_last_offset=True,
                                            scale_grad_by_freq=False,
                                            mode='sum')
    if set_qconfig:
        float_embedding.qconfig = float_qparams_dynamic_qconfig

    prepare_dynamic(float_embedding)
    float_embedding(indices, offsets)
    q_embeddingbag = nnqd.EmbeddingBag.from_float(float_embedding)
    q_embeddingbag(indices, offsets)

    self.assertTrue('DynamicQuantizedEmbeddingBag' in str(q_embeddingbag))
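
# ---------------------------------------------------------------------------
# Hedged usage sketch of the raw operator path this section exercises:
# prepack float weights into int8 rows, then do a bag lookup with
# torch.ops.quantized.embedding_bag_byte_rowwise_offsets. The function name
# here is illustrative, and the exact keyword set is assumed from the calls
# above; verify against your PyTorch version.
# ---------------------------------------------------------------------------
def example_byte_embedding_bag_lookup():
    weights = torch.randn(10, 16)  # 10 embedding rows, 16-dim float32
    packed = torch.ops.quantized.embedding_bag_byte_prepack(weights)
    indices = torch.tensor([0, 1, 2, 5, 9], dtype=torch.int64)
    # Two bags ([0, 1] and [2, 5, 9]) plus the trailing boundary offset,
    # matching include_last_offset=True.
    offsets = torch.tensor([0, 2, 5], dtype=torch.int64)
    return torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
        packed, indices, offsets,
        mode=0,                     # 0 == 'sum'
        per_sample_weights=None,
        include_last_offset=True)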