# Assumed imports for this snippet (it is a method on a quantization test
# case class; in newer PyTorch releases these modules live under torch.ao.*):
import numpy as np
import torch
import torch.nn.quantized as nnq
from torch.quantization import default_float_qparams_observer

def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
        r"""Test execution and serialization of dynamically quantized int8 EmbeddingBag modules.
        """
        num_lengths = np.random.randint(1, 6)
        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

        offsets = lengths_to_offsets(lengths)
        # include the last offset
        offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

        obs = default_float_qparams_observer()
        obs(weights)
        # Get the scale and zero point for the weight tensor
        qparams = obs.calculate_qparams()
        # Quantize the weights to 8 bits per channel using the observed qparams
        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
        qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                include_last_offset=True, mode='sum', _weight=qweight)
        # Run a forward pass to exercise the quantized module
        qemb(indices, offsets)

        # Ensure the module has the correct weights
        self.assertEqual(qweight, qemb.weight())

        w_packed = qemb._packed_params._packed_weight
        module_out = qemb(indices, offsets)

        # Call the qembedding_bag operator directly
        ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
                                                     per_sample_weights=None,
                                                     include_last_offset=True)
        self.assertEqual(module_out, ref)
        self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag=True)
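
# The lengths_to_offsets helper used above is defined elsewhere in the test
# suite. A minimal sketch of the assumed behavior (cumulative sum of the bag
# lengths, returning the begin offset of each bag):
def lengths_to_offsets(lengths, dtype=np.int64):
    offsets = np.zeros(len(lengths) + 1, dtype=dtype)
    np.cumsum(lengths, dtype=dtype, out=offsets[1:])
    return torch.from_numpy(offsets[:-1])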
Example #2
# Assumed context for this snippet: it is the init method of an operator-
# benchmark class; `import numpy`, `import torch`, and
# `import torch.nn.quantized as nnq` are expected to be in scope.
def init(self, embeddingbags, dim, mode, input_size, offset, sparse, include_last_offset, device):
    self.embedding = nnq.EmbeddingBag(
        num_embeddings=embeddingbags,
        embedding_dim=dim,
        mode=mode,
        include_last_offset=include_last_offset).to(device=device)
    # Fixed seed so every run draws the same random indices
    numpy.random.seed((1 << 32) - 1)
    self.input = torch.tensor(numpy.random.randint(0, embeddingbags, input_size), device=device).long()
    # torch.LongTensor does not accept non-CPU devices; build the offsets with
    # torch.tensor so the benchmark also runs on CUDA
    offsets = torch.tensor([offset], dtype=torch.long, device=device)
    # include_last_offset=True expects the final offset to equal the number of indices
    self.offset = torch.cat((offsets, torch.tensor([self.input.size(0)], dtype=torch.long, device=device)), 0)
    self.set_module_name('qEmbeddingBag')
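
# The matching forward pass is not part of the snippet. A minimal sketch,
# assuming the benchmark harness simply calls the module on the prepared
# inputs (the signature here is hypothetical):
def forward(self):
    return self.embedding(self.input, self.offset)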