示例#1
0
    def test_s_prep_before_qat_prep(self):
        r"""
        Check that sparse prepare followed by eager-mode ``prepare_qat`` compose
        cleanly: parametrizations survive QAT prepare, observers are inserted,
        and after squash/calibrate/convert the quantized model runs and is sparse.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare_qat(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during qat prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertIsInstance(mod[5], torch.nn.qat.Linear)
        _squash_mask_calibrate_and_convert(
            mod, sparsifier, torch.randn(1, 4, 4, 4)
        )

        # check that final module is the expected quantized module and that the model runs
        self.assertIsInstance(mod[5], torch.nn.quantized.Linear)
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        # NOTE(review): assumes sparse_config[0] is the entry targeting mod[5]
        # (as in the QAT-first variant of this test) -- confirm against the helper.
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
示例#2
0
    def test_s_prep_before_qat_prep_fx(self):
        r"""
        This test checks that sparse prepare -> prepare_qat_fx -> convert_fx,
        applied in that order, compose cleanly without issue and that the final
        result is sparsified without having to call squash mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qat_qconfig("fbgemm")
        # QAT is only requested for modules "4" and "5"; the assertions below
        # expect FX to fuse them into a single LinearReLU at fqn "5".
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_qat_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
        self.assertIsInstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.qat.LinearReLU)

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        # sparsity of the (parametrized) weight right after the sparsifier step,
        # before conversion
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.weight"))
        mod(example)  # one forward pass so the inserted observers see data
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertIsInstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU)
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified: conversion must not lose
        # sparsity, and the sparsifier must have reached the configured level
        # NOTE(review): assumes sparse_config[0] targets module "5" -- confirm.
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
示例#3
0
    def test_qat_prep_before_s_prep(self):
        r"""
        Check that eager-mode ``prepare_qat`` followed by sparse prepare compose
        cleanly: the sparsifier can attach to the QAT-swapped modules, observers
        remain in place, and the converted quantized model runs and is sparse.
        """
        mod, sparsifier, _ = self._get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm"))
        tq.prepare_qat(mod, inplace=True)

        # need to setup sparse_config on new modules
        sparse_config = [
            {
                "tensor_fqn": "5.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {
                "tensor_fqn": "0.weight"
            },
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that sparse prepare added parametrizations to the
        # QAT-prepared modules named in sparse_config
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertIsInstance(mod[5], torch.nn.qat.Linear)

        self._squash_mask_calibrate_and_convert(mod, sparsifier,
                                                torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertIsInstance(mod[5], torch.nn.quantized.Linear)
        self.assertEqual(
            mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified; sparse_config[0] is the
        # "5.weight" entry, so compare against its sparsity_level
        cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity,
                                      sparse_config[0]["sparsity_level"])
示例#4
0
    def test_qat_embedding_bag_errors(self):
        r"""
        Verify that the QAT ``EmbeddingBag`` rejects bad inputs, both in its
        constructor and in ``from_float``, by raising ``AssertionError`` with a
        message matching the expected pattern.
        """
        qat_qconfig = get_default_qat_qconfig(
            torch.backends.quantized.engine)
        # The test expects the default QAT qconfig to fail this qscheme check,
        # both at construction time and in from_float.
        qscheme_pattern = ("Embedding Bag weights requires a qscheme of " +
                           "torch.per_channel_affine_float_qparams")

        # Constructor argument checks.
        with self.assertRaisesRegex(AssertionError,
                                    "qconfig must be provided for QAT module"):
            nnqat.EmbeddingBag(10, 5, qconfig=None)
        with self.assertRaisesRegex(AssertionError, qscheme_pattern):
            nnqat.EmbeddingBag(10, 5, qconfig=qat_qconfig)

        # from_float checks: wrong float module type first.
        plain_embedding = nn.Embedding(10, 5)
        with self.assertRaisesRegex(
                AssertionError,
                "qat.EmbeddingBag.from_float only works for EmbeddingBag"):
            nnqat.EmbeddingBag.from_float(plain_embedding)

        # A fresh float EmbeddingBag has no qconfig attribute at all yet.
        float_bag = nn.EmbeddingBag(10, 5)
        with self.assertRaisesRegex(
                AssertionError,
                "Input float module must have qconfig defined"):
            nnqat.EmbeddingBag.from_float(float_bag)

        # Then a qconfig that is present but None, then one with the wrong
        # weight qscheme; order matters because float_bag is mutated.
        for candidate_qconfig, pattern in (
            (None, "Input float module must have a valid qconfig"),
            (qat_qconfig, qscheme_pattern),
        ):
            float_bag.qconfig = candidate_qconfig
            with self.assertRaisesRegex(AssertionError, pattern):
                nnqat.EmbeddingBag.from_float(float_bag)
示例#5
0
 def _get_model(backend: str = "fbgemm") -> torch.nn.Module:
     """Build a single ``Conv2d(2, 2, 1)`` model prepared for QAT.

     Args:
         backend: quantized engine name passed to ``get_default_qat_qconfig``.
             Defaults to ``"fbgemm"``, the previously hard-coded value.

     Returns:
         The module returned by ``prepare_qat`` for the configured model.
     """
     model = nn.Sequential(nn.Conv2d(2, 2, 1))
     # qconfig on the root container is what prepare_qat propagates to children
     model.qconfig = get_default_qat_qconfig(backend)
     return prepare_qat(model)