def test_s_prep_before_qat_prep(self):
    """Sparse prepare followed by eager prepare_qat must compose: the
    parametrizations survive QAT prep, observers are inserted and matched,
    and the converted model is both quantized and actually sparsified.
    """
    model, pruner, pruner_config = _get_model_and_sparsifier_and_sparse_config(
        tq.get_default_qat_qconfig("fbgemm")
    )
    pruner.prepare(model, config=pruner_config)
    tq.prepare_qat(model, inplace=True)

    # parametrizations were added and none were lost during QAT prepare
    for idx in (0, 5):
        self.assertTrue(hasattr(model[idx], "parametrizations"))

    # correct observers were inserted and matching occurred successfully
    self.assertTrue(hasattr(model[5], "activation_post_process"))
    self.assertTrue(isinstance(model[5], torch.nn.qat.Linear))

    _squash_mask_calibrate_and_convert(model, pruner, torch.randn(1, 4, 4, 4))

    # final module is the expected quantized module and the model runs
    self.assertTrue(isinstance(model[5], torch.nn.quantized.Linear))
    self.assertEqual(
        model(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4])
    )

    # the linear module was actually sparsified
    cur_sparsity = _calculate_sparsity(model[5]._weight_bias()[0])
    self.assertGreaterAlmostEqual(
        cur_sparsity, pruner_config[0]["sparsity_level"]
    )
def test_s_prep_before_qat_prep_fx(self):
    r"""
    Verify that sparse prepare -> prepare_qat_fx -> convert_fx compose
    cleanly without issue and that the final result is sparsified without
    having to call squash mask before convert_fx.
    """
    (
        model,
        pruner,
        pruner_config,
    ) = _get_model_and_sparsifier_and_sparse_config()
    pruner.prepare(model, config=pruner_config)

    sample = torch.randn(1, 4, 4, 4)
    qconfig = tq.get_default_qat_qconfig("fbgemm")
    qconfig_mapping = (
        tq.QConfigMapping()
        .set_module_name("4", qconfig)
        .set_module_name("5", qconfig)
    )
    model = prepare_qat_fx(model, qconfig_mapping, (sample,))

    # correct modules had parametrizations added and none were lost
    # during prepare
    self.assertTrue(hasattr(fqn_to_module(model, "0.0"), "parametrizations"))
    self.assertTrue(hasattr(fqn_to_module(model, "5"), "parametrizations"))
    self.assertTrue(
        isinstance(fqn_to_module(model, "5"), torch.nn.intrinsic.qat.LinearReLU)
    )

    # correct observers were inserted and matching occurred successfully
    self.assertTrue(_module_has_activation_post_process(model, "5"))

    pruner.step()
    sparsity_level = _calculate_sparsity(fqn_to_module(model, "5.weight"))
    model(sample)
    model = convert_fx(model)

    # final module is the expected quantized module and the model runs
    self.assertTrue(
        isinstance(
            fqn_to_module(model, "5"), torch.nn.intrinsic.quantized.LinearReLU
        )
    )
    self.assertEqual(model(sample).shape, torch.Size([1, 4, 4, 4]))

    # the module was actually sparsified, and convert_fx preserved the
    # sparsity level reached before conversion
    cur_sparsity = _calculate_sparsity(
        fqn_to_module(model, "5")._weight_bias()[0]
    )
    self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
    self.assertGreaterAlmostEqual(
        sparsity_level, pruner_config[0]["sparsity_level"]
    )
    self.assertGreaterAlmostEqual(
        cur_sparsity, pruner_config[0]["sparsity_level"]
    )
def test_qat_prep_before_s_prep(self):
    """QAT prepare followed by sparse prepare must compose: the
    parametrizations are added to the QAT-swapped modules, observers match,
    and the converted model is both quantized and actually sparsified.
    """
    # NOTE(fix): the helper utilities are module-level functions (as used by
    # the sibling tests test_s_prep_before_qat_prep / ..._fx), not methods on
    # the test case — calling them via ``self.`` raised AttributeError.
    mod, sparsifier, _ = _get_model_and_sparsifier_and_sparse_config(
        tq.get_default_qat_qconfig("fbgemm")
    )
    tq.prepare_qat(mod, inplace=True)

    # need to setup sparse_config on new modules: prepare_qat swapped the
    # original modules, so the config from the helper no longer applies
    sparse_config = [
        {
            "tensor_fqn": "5.weight",
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        },
        {"tensor_fqn": "0.weight"},
    ]
    sparsifier.prepare(mod, config=sparse_config)

    # check that correct modules had parametrizations added and
    # that none were lost during qat prepare
    self.assertTrue(hasattr(mod[0], "parametrizations"))
    self.assertTrue(hasattr(mod[5], "parametrizations"))

    # check that correct observers were inserted and that matching
    # occurred successfully
    self.assertTrue(hasattr(mod[5], "activation_post_process"))
    self.assertTrue(isinstance(mod[5], torch.nn.qat.Linear))

    _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

    # check that final module is the expected quantized module and that
    # the model runs
    self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
    self.assertEqual(
        mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4])
    )

    # check that module was actually sparsified
    cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
    self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
def test_qat_embedding_bag_errors(self):
    """Constructor and from_float of the QAT EmbeddingBag must reject
    missing or invalid qconfigs with informative assertion messages.
    """
    default_qat_qconfig = get_default_qat_qconfig(
        torch.backends.quantized.engine
    )
    # Same message is raised from both the constructor and from_float.
    qscheme_msg = (
        "Embedding Bag weights requires a qscheme of "
        + "torch.per_channel_affine_float_qparams"
    )

    # Constructor parameter checks.
    with self.assertRaisesRegex(
        AssertionError, "qconfig must be provided for QAT module"
    ):
        nnqat.EmbeddingBag(10, 5, qconfig=None)

    with self.assertRaisesRegex(AssertionError, qscheme_msg):
        nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)

    # from_float checks.
    embed = nn.Embedding(10, 5)
    with self.assertRaisesRegex(
        AssertionError,
        "qat.EmbeddingBag.from_float only works for EmbeddingBag",
    ):
        nnqat.EmbeddingBag.from_float(embed)

    embed_bag = nn.EmbeddingBag(10, 5)
    with self.assertRaisesRegex(
        AssertionError, "Input float module must have qconfig defined"
    ):
        nnqat.EmbeddingBag.from_float(embed_bag)

    embed_bag.qconfig = None
    with self.assertRaisesRegex(
        AssertionError, "Input float module must have a valid qconfig"
    ):
        nnqat.EmbeddingBag.from_float(embed_bag)

    embed_bag.qconfig = default_qat_qconfig
    with self.assertRaisesRegex(AssertionError, qscheme_msg):
        nnqat.EmbeddingBag.from_float(embed_bag)
def _get_model() -> torch.nn.Module: m = nn.Sequential(nn.Conv2d(2, 2, 1)) m.qconfig = get_default_qat_qconfig("fbgemm") m = prepare_qat(m) return m