Example #1
    def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
        r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8
        """
        num_lengths = np.random.randint(1, 6)
        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

        offsets = lengths_to_offsets(lengths)
        # include the last offset
        offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

        obs = default_float_qparams_observer()
        obs(weights)
        # Get the scale and zero point for the weight tensor
        qparams = obs.calculate_qparams()
        # Quantize the weights to 8 bits
        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
        qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                include_last_offset=True, mode='sum', _weight=qweight)
        qemb(indices, offsets)

        # Ensure the module has the correct weights
        self.assertEqual(qweight, qemb.weight())

        w_packed = qemb._packed_params._packed_weight
        module_out = qemb(indices, offsets)

        # Call the qembedding_bag operator directly
        ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
                                                     per_sample_weights=None,
                                                     include_last_offset=True)
        self.assertEqual(module_out, ref)
        self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag=True)
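The snippet above (and the other examples on this page) calls a lengths_to_offsets helper that lives in the test harness rather than in the snippet itself. A minimal sketch of the assumed behavior, mapping per-bag lengths to begin offsets:

    import numpy as np
    import torch

    def lengths_to_offsets(lengths):
        # offsets[i] is the start of bag i, e.g. lengths [2, 3, 1] -> offsets [0, 2, 5]
        offsets = np.cumsum(np.concatenate(([0], lengths)))[:-1]
        return torch.from_numpy(offsets.astype(np.int64))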
Example #2
    def init(self, num_embeddings: int, embedding_dim: int, num_offsets: int,
             enable_per_sample_weights: bool, include_last_offset: bool,
             is_pruned_weights: bool, use_32bit_indices: bool,
             use_32bit_offsets: bool, op_func):
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_offsets = num_offsets
        self.enable_per_sample_weights = enable_per_sample_weights
        self.include_last_offset = include_last_offset
        self.max_segment_length = 20
        # Draw a random number of bags and a random length for each bag.
        self.num_lengths = np.random.randint(1, num_offsets + 1)
        self.lengths = np.random.randint(
            0, self.max_segment_length + 1,
            size=self.num_lengths).astype(np.int32)
        self.is_pruned_weights = is_pruned_weights
        self.use_32bit_indices = use_32bit_indices
        self.use_32bit_offsets = use_32bit_offsets

        self.num_indices = np.sum(self.lengths)
        self.offsets = lengths_to_offsets(self.lengths)
        self.indices = torch.from_numpy(
            np.random.randint(low=0,
                              high=num_embeddings,
                              size=self.num_indices,
                              dtype=np.int64))

        # Optionally down-cast indices/offsets to exercise the 32-bit paths.
        if self.use_32bit_indices:
            self.indices = self.indices.int()
        if self.use_32bit_offsets:
            self.offsets = self.offsets.int()

        if include_last_offset:
            # Append the total number of indices as the final offset, matching
            # the dtype of the (possibly 32-bit) offsets tensor.
            self.offsets = torch.cat(
                (self.offsets,
                 torch.tensor([self.indices.size(0)],
                              dtype=self.offsets.dtype)), 0)

        self.weights = torch.from_numpy((np.random.random_sample(
            (self.num_embeddings, self.embedding_dim)) + 1).astype(np.float32))

        self.prepack_func = torch.ops.quantized.embedding_bag_byte_prepack

        self.prepacked_weights = self.prepack_func(self.weights)
        if self.enable_per_sample_weights:
            self.per_sample_weights = torch.from_numpy(np.random.uniform(
                low=0.01, high=0.5,
                size=[len(self.indices)]).astype(np.float32))
        else:
            self.per_sample_weights = None

        self.compressed_indices = None

        if self.is_pruned_weights:
            self.prepacked_weights, self.compressed_indices = get_pruned_weights_and_mapping(
                self.prepacked_weights)

        self.op_func = op_func
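The op_func hook is injected by the benchmark harness, so the concrete operator is not visible in this snippet. A plausible standalone forward pass for the byte-prepacked path, assuming torch.ops.quantized.embedding_bag_byte_rowwise_offsets as the target op:

    import torch

    # Hypothetical wiring that mirrors the state built in init() above.
    weights = torch.randn(80, 16, dtype=torch.float32)
    prepacked = torch.ops.quantized.embedding_bag_byte_prepack(weights)

    indices = torch.randint(0, 80, (30,), dtype=torch.int64)
    offsets = torch.tensor([0, 10, 20, 30], dtype=torch.int64)  # last offset included

    out = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
        prepacked, indices, offsets,
        mode=0,  # 0 == 'sum'
        pruned_weights=False,
        per_sample_weights=None,
        compressed_indices_mapping=None,
        include_last_offset=True)
    print(out.shape)  # torch.Size([3, 16])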
Example #3
    def test_embedding_bag_api(self, num_embeddings, embedding_dim,
                               num_offsets, set_qconfig):
        r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8
        """
        num_lengths = np.random.randint(1, 6)
        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        indices = torch.from_numpy(
            np.random.randint(low=0,
                              high=num_embeddings,
                              size=num_indices,
                              dtype=np.int64))

        offsets = lengths_to_offsets(lengths)
        # include the last offset
        offsets = torch.cat(
            (offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
        weights = torch.from_numpy((np.random.random_sample(
            (num_embeddings, embedding_dim)) + 1).astype(np.float32))

        obs = default_float_qparams_observer()
        obs(weights)
        # Get the scale and zero point for the weight tensor
        qparams = obs.calculate_qparams()
        # Quantize the weights to 8 bits
        qweight = torch.quantize_per_channel(weights,
                                             qparams[0],
                                             qparams[1],
                                             axis=0,
                                             dtype=torch.quint8)
        qemb = nnqd.EmbeddingBag(num_embeddings=num_embeddings,
                                 embedding_dim=embedding_dim,
                                 include_last_offset=True,
                                 mode='sum')
        qemb.set_weight(qweight)
        qemb(indices, offsets)

        # Ensure the module has the correct weights
        self.assertEqual(qweight, qemb.weight())

        w_packed = qemb._packed_params._packed_weight
        module_out = qemb(indices, offsets)

        # Call the qembedding_bag operator directly
        ref = torch.ops.quantized.embedding_bag_byte(w_packed,
                                                     indices,
                                                     offsets,
                                                     mode=0,
                                                     per_sample_weights=None,
                                                     include_last_offset=True)
        self.assertEqual(module_out, ref)

        # Test serialization of dynamic EmbeddingBag module using state_dict
        emb_dict = qemb.state_dict()
        b = io.BytesIO()
        torch.save(emb_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        embedding_unpack = torch.ops.quantized.embedding_bag_unpack
        # Check unpacked weight values explicitly
        for key in emb_dict:
            if isinstance(emb_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                emb_weight = embedding_unpack(emb_dict[key])
                loaded_weight = embedding_unpack(loaded_dict[key])
                self.assertEqual(emb_weight, loaded_weight)

        # Check state dict serialization and torch.save APIs
        loaded_qemb = nnqd.EmbeddingBag(num_embeddings=num_embeddings,
                                        embedding_dim=embedding_dim,
                                        include_last_offset=True,
                                        mode='sum')
        self.check_eager_serialization(qemb, loaded_qemb, [indices, offsets])

        loaded_qemb.load_state_dict(loaded_dict)
        self.assertEqual(
            embedding_unpack(qemb._packed_params._packed_weight),
            embedding_unpack(loaded_qemb._packed_params._packed_weight))

        # Test JIT serialization
        self.checkScriptable(qemb, [[indices, offsets]], check_save_load=True)

        # Test from_float call
        float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings,
                                                embedding_dim=embedding_dim,
                                                include_last_offset=True,
                                                scale_grad_by_freq=False,
                                                mode='sum')
        if set_qconfig:
            float_embedding.qconfig = float_qparams_dynamic_qconfig

        prepare_dynamic(float_embedding)

        float_embedding(indices, offsets)
        q_embeddingbag = nnqd.EmbeddingBag.from_float(float_embedding)

        q_embeddingbag(indices, offsets)

        self.assertTrue('DynamicQuantizedEmbeddingBag' in str(q_embeddingbag))
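As a self-contained follow-up (illustrative, not part of the test file), the quantized module's output can be checked against the float module's: 8-bit rowwise quantization of the weight introduces a small but nonzero error in the summed bags. The import paths below assume the same release these tests target; newer PyTorch exposes the same names under torch.ao.*.

    import torch
    import torch.nn.quantized as nnq
    from torch.quantization import default_float_qparams_observer

    torch.manual_seed(0)
    weights = torch.randn(10, 12)

    # Rowwise (per-channel, axis=0) 8-bit quantization with float qparams.
    obs = default_float_qparams_observer()
    obs(weights)
    scales, zero_points = obs.calculate_qparams()
    qweight = torch.quantize_per_channel(weights, scales, zero_points,
                                         axis=0, dtype=torch.quint8)

    qemb = nnq.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                            include_last_offset=True, mode='sum',
                            _weight=qweight)
    femb = torch.nn.EmbeddingBag(10, 12, include_last_offset=True,
                                 mode='sum', _weight=weights)

    indices = torch.tensor([0, 1, 2, 5, 9], dtype=torch.int64)
    offsets = torch.tensor([0, 2, 5], dtype=torch.int64)  # last offset included
    # Expect a small quantization error, not exact equality:
    print((qemb(indices, offsets) - femb(indices, offsets)).abs().max())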