Exemplo n.º 1
0
    def test_qat_embedding_bag_errors(self):
        default_qat_qconfig = get_default_qat_qconfig(
            torch.backends.quantized.engine)

        # Test constructor parameters checks here.
        with self.assertRaisesRegex(AssertionError,
                                    "qconfig must be provided for QAT module"):
            nnqat.EmbeddingBag(10, 5, qconfig=None)

        with self.assertRaisesRegex(
                AssertionError,
                "Embedding Bag weights requires a qscheme of " +
                "torch.per_channel_affine_float_qparams"):
            nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)

        # Test from_float checks here.
        embed = nn.Embedding(10, 5)
        with self.assertRaisesRegex(
                AssertionError,
                "qat.EmbeddingBag.from_float only works for EmbeddingBag"):
            nnqat.EmbeddingBag.from_float(embed)
        embed_bag = nn.EmbeddingBag(10, 5)
        with self.assertRaisesRegex(
                AssertionError,
                "Input float module must have qconfig defined"):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = None
        with self.assertRaisesRegex(
                AssertionError,
                "Input float module must have a valid qconfig"):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = default_qat_qconfig
        with self.assertRaisesRegex(
                AssertionError,
                "Embedding Bag weights requires a qscheme of " +
                "torch.per_channel_affine_float_qparams"):
            nnqat.EmbeddingBag.from_float(embed_bag)
 def init(self, embeddingbags, dim, mode, input_size, offset, sparse, include_last_offset, device):
     """Build a QAT ``EmbeddingBag`` and its benchmark inputs on ``device``.

     Args:
         embeddingbags: number of rows in the embedding table; indices are
             drawn uniformly from ``[0, embeddingbags)``.
         dim: embedding dimension.
         mode: forwarded to ``nnqat.EmbeddingBag``.
         input_size: number of indices to draw.
         offset: first entry of the offsets tensor.
         sparse: forwarded to ``nnqat.EmbeddingBag``.
         include_last_offset: forwarded to ``nnqat.EmbeddingBag``.
         device: device on which the module and every input tensor are created.

     Side effects:
         - Sets ``self.embedding``.
         - Sets ``self.inputs`` to a dict with keys ``"input"`` (int64 indices)
           and ``"offset"`` (int64 ``[offset, len(indices)]``).
         - Calls ``self.set_module_name``, presumably provided by the benchmark
           base class (not visible here).
         - Reseeds numpy's global RNG.
     """
     qconfig = default_embedding_qat_qconfig
     self.embedding = nnqat.EmbeddingBag(
         num_embeddings=embeddingbags,
         embedding_dim=dim,
         mode=mode,
         include_last_offset=include_last_offset,
         sparse=sparse, device=device, qconfig=qconfig)
     # Fixed seed (2**32 - 1, the largest numpy.random.seed accepts) keeps the
     # drawn indices reproducible across runs.
     numpy.random.seed((1 << 32) - 1)
     indices = torch.tensor(
         numpy.random.randint(0, embeddingbags, input_size),
         dtype=torch.long, device=device)
     # Build the offsets directly on ``device``. The legacy
     # ``torch.LongTensor(..., device=...)`` constructor only accepts CPU, and
     # concatenating with a CPU-only tail tensor would mix devices.
     offsets = torch.tensor([offset, indices.size(0)],
                            dtype=torch.long, device=device)
     self.inputs = {
         "input": indices,
         "offset": offsets,
     }
     self.set_module_name('qatEmbeddingBag')