Esempio n. 1
0
    def test_embedding_bag_linear(self):
        """QAT round-trip for ``ManualEmbeddingBagLinear``.

        For every engine in ``supported_qengines``, the test runs these steps:

        * Prepare the model for QAT using the embedding QAT module mappings.
        * Run ``test_only_train_fn`` on it.
        * Check the weight fake-quant zero-point dtypes.
        * Convert it using the embedding static-quant module mappings.
        * Check that the converted ``emb``/``linear`` submodules have the
          expected quantized types, then run eval, scripting and
          ``checkNoQconfig`` checks.
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualEmbeddingBagLinear().train()
                model = prepare_qat(
                    model, mapping=get_embedding_qat_module_mappings())
                self.checkObservers(model)

                test_only_train_fn(model, self.embed_linear_data_train)
                # Make sure no activation_post_process is attached to the
                # top-level model.
                # NOTE(review): this only inspects `model` itself, not
                # `model.emb`; confirm that is the intended scope.
                self.assertFalse(hasattr(model, "activation_post_process"))
                # Make sure the FakeQuant zero_points have the expected dtype:
                # float32 for the EmbeddingBag weight, int32 for the Linear
                # weight.
                self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype,
                                 torch.float32)
                self.assertEqual(
                    model.linear.weight_fake_quant.zero_point.dtype,
                    torch.int32)
                model = convert(
                    model,
                    mapping=get_embedding_static_quant_module_mappings())

                def checkQuantized(model):
                    # Use assertEqual here. assertTrue(a, b) treats `b` as the
                    # failure message and passes for any truthy `a`, so a type
                    # check written with assertTrue can never fail.
                    # EmbeddingBag must now be a quantized EmbeddingBag.
                    self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                    # Linear must now be a quantized Linear.
                    self.assertEqual(type(model.linear), nnq.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)
Esempio n. 2
0
    def test_defused_embedding_bag_linear(self):
        """QAT round-trip for ``DeFusedEmbeddingBagLinear``.

        For every engine in ``supported_qengines``, the test runs these steps:

        * Prepare the model for QAT using the embedding QAT module mappings.
        * Run ``test_only_train_fn`` on it.
        * Check which activation_post_process modules were attached to
          ``linear`` and ``emb``.
        * Convert it using the embedding static-quant module mappings.
        * Verify the converted model with the nested checker.
        """
        def check_quantized(quantized_model):
            # Embedding must now be a quantized Embedding and Linear must now
            # be a quantized Linear.
            self.assertEqual(type(quantized_model.emb), nn.quantized.Embedding)
            self.assertEqual(type(quantized_model.linear), nn.quantized.Linear)

            test_only_eval_fn(quantized_model, self.embed_data)
            self.checkScriptable(quantized_model, self.embed_data)
            self.checkNoQconfig(quantized_model)

        for engine in supported_qengines:
            with override_quantized_engine(engine):
                qat_model = prepare_qat(
                    DeFusedEmbeddingBagLinear().train(),
                    mapping=get_embedding_qat_module_mappings(),
                )
                self.checkObservers(qat_model)

                test_only_train_fn(qat_model, self.embed_linear_data_train)
                # The Linear output gets a FusedMovingAvgObsFakeQuantize...
                self.assertEqual(
                    type(qat_model.linear.activation_post_process),
                    FusedMovingAvgObsFakeQuantize,
                )
                # ...while the Embedding output only gets a NoopObserver.
                self.assertEqual(
                    type(qat_model.emb.activation_post_process),
                    NoopObserver,
                )

                check_quantized(
                    convert(
                        qat_model,
                        mapping=get_embedding_static_quant_module_mappings(),
                    )
                )