def test_fusion_conv_with_bias(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ModelForFusionWithBias().train()

            # output with no fusion.
            out_ref = model(self.img_data_2d[0][0])
            model.qconfig = QConfig(activation=torch.nn.Identity,
                                    weight=torch.nn.Identity)
            # use the QAT variant of fuse_modules since the model is in train mode
            model = fuse_modules_qat(
                model,
                [["conv1", "bn1", "relu1"],
                 ["conv2", "bn2"]])
            prep_model = prepare_qat(model, inplace=False)
            # output with fusion but no observers.
            out_fused = prep_model(self.img_data_2d[0][0])
            self.assertEqual(out_ref, out_fused)

            model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            prepare_qat(model, inplace=True)
            model(self.img_data_2d[0][0])

            def checkQAT(model):
                self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                self.assertEqual(type(model.bn1), nn.Identity)
                self.assertEqual(type(model.relu1), nn.Identity)
                self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                self.assertEqual(type(model.bn2), nn.Identity)

            checkQAT(model)
def test_dynamic_qat_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            # Dynamic QAT without memoryless observers should fail
            with self.assertRaisesRegex(
                    ValueError,
                    "Dynamic QAT requires a memoryless observer." +
                    "This means a MovingAverage observer with averaging constant equal to 1"):
                model = ManualLinearDynamicQATModel(default_qat_qconfig)
                model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})

            model = ManualLinearDynamicQATModel()
            model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
            self.assertEqual(type(model.fc1), nnqatd.Linear)
            self.assertEqual(type(model.fc2), nnqatd.Linear)
            self.checkObservers(model)
            test_only_train_fn(model, self.train_data)
            model = convert(model, mapping={nnqatd.Linear: nnqd.Linear})
            self.assertEqual(type(model.fc1), nnqd.Linear)
            self.assertEqual(type(model.fc2), nnqd.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)
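# Hedged sketch (not part of the original tests): one way to build a qconfig
# that satisfies the "memoryless observer" requirement checked above. A
# MovingAverage observer with averaging_constant=1 keeps no history, so every
# batch recomputes qparams from scratch, which is what dynamic quantization
# needs. The exact qconfig used by ManualLinearDynamicQATModel may differ;
# _example_memoryless_qat_qconfig is illustrative only.
_example_memoryless_qat_qconfig = torch.ao.quantization.QConfig(
    activation=torch.ao.quantization.FakeQuantize.with_args(
        observer=torch.ao.quantization.MovingAverageMinMaxObserver,
        averaging_constant=1,  # memoryless: no exponential averaging across batches
        dtype=torch.quint8,
    ),
    weight=torch.ao.quantization.default_weight_fake_quant,
)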
def test_s_prep_before_qat_prep(self):
    (
        mod,
        sparsifier,
        sparse_config,
    ) = _get_model_and_sparsifier_and_sparse_config(
        tq.get_default_qat_qconfig("fbgemm")
    )
    sparsifier.prepare(mod, config=sparse_config)
    tq.prepare_qat(mod, inplace=True)
    self.assertTrue(hasattr(mod[0], "parametrizations"))
    self.assertTrue(hasattr(mod[5], "parametrizations"))

    # check that correct observers were inserted and that matching
    # occurred successfully
    self.assertTrue(hasattr(mod[5], "activation_post_process"))
    self.assertTrue(isinstance(mod[5], torch.nn.qat.Linear))
    _squash_mask_calibrate_and_convert(
        mod, sparsifier, torch.randn(1, 4, 4, 4)
    )

    # check that the final module is the expected quantized module and that the model runs
    self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
    self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # check that the module was actually sparsified
    cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
    self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
def test_fusion_sequential_model_train(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ModelWithSequentialFusion().train()
            model.to(torch.float)
            fuse_modules_qat(
                model,
                [['conv1', 'relu1'],
                 ['features.0.0', 'features.0.1', 'features.0.2'],
                 ['features.1.0', 'features.1.1', 'features.1.2'],
                 ['features.2.0', 'features.2.1', 'features.2.2'],
                 ['classifier.0', 'classifier.1']],
                inplace=True)
            self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                             msg="Fused Conv + Relu: nni.ConvReLU2d")
            self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                             msg="Fused Conv + Relu: Conv2d")
            self.assertEqual(type(model.conv1[1]), nn.ReLU,
                             msg="Fused Conv + Relu: Relu")
            self.assertEqual(type(model.relu1), nn.Identity,
                             msg="Fused Conv + Relu: Identity")
            for i in range(3):
                self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
                                 msg="Fused submodule Conv + folded BN")
                self.assertEqual(type(model.features[i][1]), nn.Identity,
                                 msg="Fused submodule (skipped BN)")
                self.assertEqual(type(model.features[i][2]), nn.Identity,
                                 msg="Fused submodule (skipped ReLU)")
            self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
            self.assertEqual(type(model.classifier[1]), nn.Identity)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
            prepare_qat(model, inplace=True)
            self.checkObservers(model)
            model(self.img_data_2d[0][0])

            def checkQAT(model):
                self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
                self.assertEqual(type(model.relu1), nn.Identity)
                for i in range(3):
                    self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
                                     msg="Fused submodule Conv + folded BN")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Fused submodule (skipped ReLU)")
                self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)

            checkQAT(model)
            model(self.img_data_2d[1][0])
            convert(model, inplace=True)
            model(self.img_data_2d[1][0])
            self.checkModelWithSequentialQuantized(model)
def test_embedding_bag_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualEmbeddingBagLinear().train()
            model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
            self.checkObservers(model)

            test_only_train_fn(model, self.embed_linear_data_train)
            # make sure no activation_post_process is inserted for EmbeddingBag
            self.assertFalse(hasattr(model, "activation_post_process"))
            # make sure the FakeQuant zero_points are the correct dtype
            self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype,
                             torch.float32)
            self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype,
                             torch.int32)
            model = convert(model, mapping=get_embedding_static_quant_module_mappings())

            def checkQuantized(model):
                # make sure EmbeddingBag is now a quantized EmbeddingBag
                self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                # also test that Linear has been quantized
                self.assertEqual(type(model.linear), nnq.Linear)

                test_only_eval_fn(model, self.embed_data)
                self.checkScriptable(model, self.embed_data)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = ManualEmbeddingBagLinear()
def test_defused_embedding_bag_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = DeFusedEmbeddingBagLinear().train()
            model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
            self.checkObservers(model)

            test_only_train_fn(model, self.embed_linear_data_train)
            # make sure activation_post_process is inserted after Linear.
            self.assertEqual(type(model.linear.activation_post_process),
                             FusedMovingAvgObsFakeQuantize)
            # make sure that Embedding has a noop for activation.
            self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
            model = convert(model, mapping=get_embedding_static_quant_module_mappings())

            def checkQuantized(model):
                # make sure Embedding is now a QuantizedEmbedding
                self.assertEqual(type(model.emb), nn.quantized.Embedding)
                # make sure Linear is now a QuantizedLinear
                self.assertEqual(type(model.linear), nn.quantized.Linear)

                test_only_eval_fn(model, self.embed_data)
                self.checkScriptable(model, self.embed_data)
                self.checkNoQconfig(model)

            checkQuantized(model)
def test_embedding_bag_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualEmbeddingBagLinear().train()
            model = prepare_qat(model)
            self.checkObservers(model)

            train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))]
                             for _ in range(2)]
            eval_output = [[torch.randint(0, 10, (12, 1))]]

            test_only_train_fn(model, train_indices)
            # make sure no activation_post_process is inserted for EmbeddingBag
            self.assertFalse(hasattr(model, "activation_post_process"))
            model = convert(model)

            def checkQuantized(model):
                # make sure EmbeddingBag is now a quantized EmbeddingBag
                self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                # also test that Linear has been quantized
                self.assertEqual(type(model.linear), nnq.Linear)

                test_only_eval_fn(model, eval_output)
                self.checkScriptable(model, eval_output)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = ManualEmbeddingBagLinear()
            model = quantize_qat(model, test_only_train_fn, [train_indices])
            checkQuantized(model)
def test_conv_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualConvLinearQATModel()

            model = prepare_qat(model)
            self.checkObservers(model)

            test_only_train_fn(model, self.img_data_2d_train)
            model = convert(model)

            def checkQuantized(model):
                self.assertEqual(type(model.conv), nnq.Conv2d)
                self.assertEqual(type(model.fc1), nnq.Linear)
                self.assertEqual(type(model.fc2), nnq.Linear)
                test_only_eval_fn(model, self.img_data_2d)
                self.checkScriptable(model, self.img_data_2d)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = ManualConvLinearQATModel()
            model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
            checkQuantized(model)
def test_fuse_module_train(self):
    model = ModelForFusion(default_qat_qconfig).train()
    # Test step by step fusion
    model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
    model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
    self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                     msg="Fused Conv + BN + Relu first layer")
    self.assertEqual(type(model.bn1), torch.nn.Identity,
                     msg="Fused Conv + BN + Relu (skipped BN)")
    self.assertEqual(type(model.relu1), torch.nn.Identity,
                     msg="Fused Conv + BN + Relu (skipped Relu)")

    self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                     msg="Fused submodule Conv + BN")
    self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                     msg="Fused submodule Conv + BN (skipped BN)")
    self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                     msg="Non-fused submodule Conv")
    self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                     msg="Non-fused submodule ReLU")

    model = prepare_qat(model)
    self.checkObservers(model)

    def checkQAT(model):
        self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)

    checkQAT(model)
    test_only_train_fn(model, self.img_data_1d_train)
    model = convert(model)

    def checkQuantized(model):
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)
        test_only_eval_fn(model, self.img_data_1d)
        self.checkNoQconfig(model)

    with self.assertRaisesRegex(
            RuntimeError,
            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
        checkQuantized(model)

    model = ModelForFusion(default_qat_qconfig).train()
    model = fuse_modules_qat(
        model,
        [['conv1', 'bn1', 'relu1'],
         ['sub1.conv', 'sub1.bn']])
    model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
    with self.assertRaisesRegex(
            RuntimeError,
            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
        checkQuantized(model)
def test_fusion_conv_with_bias(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model_orig = ModelForFusionWithBias().train()

            # reference model
            model_ref = copy.deepcopy(model_orig)
            # output with no fusion.
            out_ref = model_ref(self.img_data_2d[0][0])

            # fused model
            model_orig.qconfig = QConfig(activation=torch.nn.Identity,
                                         weight=torch.nn.Identity)
            model = fuse_modules_qat(
                model_orig,
                [["conv1", "bn1", "relu1"],
                 ["conv2", "bn2"]])
            prep_model = prepare_qat(model, inplace=False)
            # output with fusion but no observers.
            out_fused = prep_model(self.img_data_2d[0][0])
            self.assertEqual(out_ref, out_fused)

            def checkBN(bn_ref, bn):
                self.assertEqual(bn_ref.weight, bn.weight)
                self.assertEqual(bn_ref.bias, bn.bias)
                self.assertEqual(bn_ref.running_mean, bn.running_mean)
                self.assertEqual(bn_ref.running_var, bn.running_var)

            checkBN(model_ref.bn1, prep_model.conv1.bn)
            checkBN(model_ref.bn2, prep_model.conv2.bn)

            model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            prepare_qat(model, inplace=True)
            model(self.img_data_2d[0][0])

            def checkQAT(model):
                self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                self.assertEqual(type(model.bn1), nn.Identity)
                self.assertEqual(type(model.relu1), nn.Identity)
                self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                self.assertEqual(type(model.bn2), nn.Identity)

            checkQAT(model)
def test_embedding_qat_qconfig_equal(self):
    # Embedding QAT uses a NoopObserver class for activation and a
    # FakeQuant for weight; make sure that qconfig comparison works
    # for a mix of partial functions and classes in the qconfig.
    model = ManualEmbeddingBagLinear().train()
    model = prepare_qat(model)

    self.assertTrue(qconfig_equals(model.emb.qconfig,
                                   default_embedding_qat_qconfig))
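# Hedged sketch (illustrative, not the library's exact definition of
# default_embedding_qat_qconfig): the shape of the mixed qconfig being compared
# above. The activation slot holds a bare observer class, while the weight slot
# holds a functools.partial produced by with_args, so qconfig_equals must
# compare classes and partials uniformly. Field values here are assumptions.
_example_embedding_qat_qconfig = torch.ao.quantization.QConfig(
    # a bare class reference: embedding outputs stay float, so no real observation
    activation=torch.ao.quantization.NoopObserver,
    # a partial-function fake quant for the embedding weight
    weight=torch.ao.quantization.FakeQuantize.with_args(
        observer=torch.ao.quantization.MovingAveragePerChannelMinMaxObserver,
        qscheme=torch.per_channel_affine_float_qparams,
        dtype=torch.quint8,
        ch_axis=0,
    ),
)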
def test_train_save_load_eval(self):
    r"""Test the QAT flow of creating a model, doing QAT, and saving the quantized state_dict
    During eval, we first call prepare_qat and convert on the model and then load
    the state_dict and compare results against the original model.
    """
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = TwoLayerLinearModel()
            model = torch.ao.quantization.QuantWrapper(model)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
            model = prepare_qat(model)

            fq_state_dict = model.state_dict()

            test_only_train_fn(model, self.train_data)
            model = convert(model)

            quant_state_dict = model.state_dict()

            x = torch.rand(2, 5, dtype=torch.float)
            ref = model(x)

            # Create model again for eval. Check result using quantized state_dict
            model = TwoLayerLinearModel()
            model = torch.ao.quantization.QuantWrapper(model)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
            torch.ao.quantization.prepare_qat(model, inplace=True)
            new_state_dict = model.state_dict()

            # Check to make sure the model after prepare_qat has the same state_dict as the original.
            self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))

            torch.ao.quantization.convert(model, inplace=True)
            model.eval()
            model.load_state_dict(quant_state_dict)
            out = model(x)
            self.assertEqual(ref, out)

            # Check that a model created using prepare has the same state dict as the quantized state_dict
            model = TwoLayerLinearModel()
            model.eval()
            model = torch.ao.quantization.QuantWrapper(model)
            model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            torch.ao.quantization.prepare(model, inplace=True)
            torch.ao.quantization.convert(model, inplace=True)
            self.assertEqual(set(model.state_dict().keys()),
                             set(quant_state_dict.keys()))
            model.eval()
            model.load_state_dict(quant_state_dict)
            out = model(x)
            self.assertEqual(ref, out)
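# Note (sketch, not part of the original test): the round-trip above works
# because quantized modules serialize their packed weights through state_dict,
# so a plain torch.save / torch.load of the state_dict is enough to persist a
# converted model; no custom pickling is needed:
#
#   torch.save(model.state_dict(), "quantized_model.pt")
#   model.load_state_dict(torch.load("quantized_model.pt"))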
def test_qat_prep_before_s_prep(self):
    mod, sparsifier, _ = self._get_model_and_sparsifier_and_sparse_config(
        tq.get_default_qat_qconfig("fbgemm"))
    tq.prepare_qat(mod, inplace=True)

    # need to set up sparse_config on the new modules
    sparse_config = [
        {
            "tensor_fqn": "5.weight",
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        },
        {"tensor_fqn": "0.weight"},
    ]
    sparsifier.prepare(mod, config=sparse_config)

    # check that the correct modules had parametrizations added and
    # that none were lost during qat prepare
    self.assertTrue(hasattr(mod[0], "parametrizations"))
    self.assertTrue(hasattr(mod[5], "parametrizations"))

    # check that correct observers were inserted and that matching
    # occurred successfully
    self.assertTrue(hasattr(mod[5], "activation_post_process"))
    self.assertTrue(isinstance(mod[5], torch.nn.qat.Linear))
    self._squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

    # check that the final module is the expected quantized module and that the model runs
    self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
    self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # check that the module was actually sparsified
    cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
    self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
def test_fixed_qparam_ops(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sigmoid = torch.nn.Sigmoid()
            self.hardsigmoid = torch.nn.Hardsigmoid()
            self.tanh = torch.nn.Tanh()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.sigmoid(x)
            x = self.hardsigmoid(x)
            x = self.tanh(x)
            x = self.dequant(x)
            return x

    m = M().train()
    m.qconfig = default_qat_qconfig
    m = prepare_qat(m)
    for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
        self.assertEqual(type(getattr(m, attr).activation_post_process),
                         FixedQParamsFakeQuantize)
    data = torch.randn(1, 3, 2, 4)
    before_convert = m(data)
    m = convert(m)
    after_convert = m(data)
    self.assertEqual(before_convert, after_convert)
    # make sure activation post process is removed
    for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
        # verify the fake quant module is removed
        self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process'))
        # verify that hooks are removed
        self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

    # make sure no fake quantize module is inserted for eval mode
    def checkNoFQModule(m):
        for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
            self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
            self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

    m = M().eval()
    m.qconfig = default_qconfig
    m = prepare(m)
    checkNoFQModule(m)
    m = convert(m)
    checkNoFQModule(m)
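# Background sketch (facts about the quantized representation, not code from
# the test above): these ops have analytically known output ranges, so their
# quantization parameters are fixed rather than observed. For quint8, sigmoid's
# [0, 1] output maps to scale=1/256, zero_point=0, and tanh's [-1, 1] output
# maps to scale=2/256, zero_point=128. FixedQParamsFakeQuantize simply
# fake-quantizes with those constants during QAT instead of learning them,
# which is why convert can drop the module without changing numerics.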
def test_eval_only_fake_quant(self):
    r"""Using FakeQuant in evaluation-only mode; this is useful for
    estimating the accuracy loss when we quantize the network.
    """
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualLinearQATModel(qengine)

            model = prepare_qat(model)
            self.checkObservers(model)

            model.eval()
            test_only_eval_fn(model, self.calib_data)
def test_linear_bn_workflow(self):
    qengine = torch.backends.quantized.engine
    m = nn.Sequential(
        QuantStub(),
        nn.Linear(4, 4),
        nn.BatchNorm1d(4),
    )
    data = torch.randn(4, 4)
    m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
    m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
    mp = prepare_qat(m)
    mp(data)
    mq = convert(mp)
    self.assertTrue(type(mq[1]) == nnq.Linear)
    self.assertTrue(type(mq[2]) == nn.Identity)
def test_relu(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.relu(x)
            return x

    m = M().train()
    m.qconfig = default_qconfig
    m = prepare_qat(m)
    # make sure no activation_post_process is inserted for relu
    self.assertFalse(hasattr(m, "activation_post_process"))
    m = convert(m)
    # make sure the ReLU module is not changed
    self.assertEqual(type(m.relu), nn.ReLU)
def test_forward_hooks_preserved(self):
    r"""Test that QAT preserves the pre-forward and forward hooks of the
    original model.
    """
    qengine = torch.backends.quantized.engine
    model = QuantStubModel()
    counter = {
        'pre_forwards': 0,
        'forwards': 0,
    }

    def fw_pre_hook(h_module, input):
        counter['pre_forwards'] += 1

    def fw_hook(h_module, input, output):
        counter['forwards'] += 1

    model.fc.register_forward_pre_hook(fw_pre_hook)
    model.fc.register_forward_hook(fw_hook)

    model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
    model = prepare_qat(model)

    def checkHooksIsPresent(model, before_convert=True):
        forward_hooks = 1
        if before_convert:
            self.assertEqual(len(model.quant._forward_hooks.values()), 1,
                             "Quantization observer hook has disappeared")
            forward_hooks = 2
        self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
        self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
        self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
                         "Extra pre-forward hooks have appeared on a layer")
        self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks,
                         "Extra post-forward hooks have appeared on a layer")

    checkHooksIsPresent(model, True)
    x = torch.rand(2, 5, dtype=torch.float)
    model(x)
    torch.ao.quantization.convert(model, inplace=True)
    checkHooksIsPresent(model, False)
def test_dropout(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualDropoutQATModel(qengine)

            model = prepare_qat(model)
            self.checkObservers(model)
            test_only_train_fn(model, self.train_data)
            model = convert(model)

            def checkQuantized(model):
                self.assertEqual(type(model.fc1), nnq.Linear)
                self.assertEqual(type(model.dropout), nnq.Dropout)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = quantize_qat(ManualDropoutQATModel(qengine),
                                 test_only_train_fn, [self.train_data])
            checkQuantized(model)
def _test_activation_convert_numerics_impl(self, Act, data):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.act = Act()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.act(x)
            x = self.dequant(x)
            return x

    m = M().train()
    m.qconfig = default_qat_qconfig
    m = prepare_qat(m)
    before_convert = m(data)
    m = convert(m)
    after_convert = m(data)
    self.assertEqual(before_convert, after_convert)
def _get_model() -> torch.nn.Module:
    m = nn.Sequential(nn.Conv2d(2, 2, 1))
    m.qconfig = get_default_qat_qconfig("fbgemm")
    m = prepare_qat(m)
    return m
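# Minimal usage sketch for the helper above (_demo_qat_roundtrip is
# hypothetical; assumes an fbgemm-capable build): run a batch through the
# QAT-prepared model so the fake-quant modules observe data, then convert it
# to a real quantized model.
def _demo_qat_roundtrip() -> torch.nn.Module:
    m = _get_model()
    m(torch.randn(2, 2, 4, 4))  # one fake-quantized "training" batch
    m.eval()
    return torch.ao.quantization.convert(m)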