def test_conv_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualConvLinearQATModel()

            model = prepare_qat(model)
            self.checkObservers(model)

            test_only_train_fn(model, self.img_data_2d_train)
            model = convert(model)

            def checkQuantized(model):
                self.assertEqual(type(model.conv), nnq.Conv2d)
                self.assertEqual(type(model.fc1), nnq.Linear)
                self.assertEqual(type(model.fc2), nnq.Linear)
                test_only_eval_fn(model, self.img_data_2d)
                self.checkScriptable(model, self.img_data_2d)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = ManualConvLinearQATModel()
            model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
            checkQuantized(model)
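# For context: test_only_train_fn / test_only_eval_fn come from
# torch.testing._internal.common_quantization. The sketch below is a
# hypothetical simplification of what the train helper does (run a few
# optimizer steps so the fake-quant observers see real activation
# statistics); it is not the actual helper.
def _sketch_train_fn(model, train_data):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    model.train()
    for data, target in train_data:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()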
def test_embedding_bag_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualEmbeddingBagLinear().train()
            model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
            self.checkObservers(model)

            test_only_train_fn(model, self.embed_linear_data_train)
            # make sure no activation_post_process is inserted for EmbeddingBag
            self.assertFalse(hasattr(model, "activation_post_process"))
            # make sure the FakeQuant zero_points have the correct dtype
            self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
            self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)
            model = convert(model, mapping=get_embedding_static_quant_module_mappings())

            def checkQuantized(model):
                # make sure EmbeddingBag is now a quantized EmbeddingBag
                self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                # also check that Linear has been quantized
                self.assertEqual(type(model.linear), nnq.Linear)
                test_only_eval_fn(model, self.embed_data)
                self.checkScriptable(model, self.embed_data)
                self.checkNoQconfig(model)

            checkQuantized(model)
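# Why the two zero_point dtypes above differ (a sketch for context; the
# qconfig names are real attributes of torch.ao.quantization.qconfig):
# embedding weights are quantized with per-channel *float* qparams, whose
# FakeQuantize keeps a float32 zero_point, while the default QAT weight
# fake-quant keeps an int32 zero_point.
from torch.ao.quantization.qconfig import (
    default_embedding_qat_qconfig,
    default_qat_qconfig,
)

emb_fq = default_embedding_qat_qconfig.weight()  # float-qparams fake-quant
lin_fq = default_qat_qconfig.weight()            # default weight fake-quant
assert emb_fq.zero_point.dtype == torch.float32
assert lin_fq.zero_point.dtype == torch.int32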
def test_defused_embedding_bag_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = DeFusedEmbeddingBagLinear().train()
            model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
            self.checkObservers(model)

            test_only_train_fn(model, self.embed_linear_data_train)
            # make sure activation_post_process is inserted after Linear
            self.assertEqual(type(model.linear.activation_post_process),
                             FusedMovingAvgObsFakeQuantize)
            # make sure Embedding has a noop for activation
            self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
            model = convert(model, mapping=get_embedding_static_quant_module_mappings())

            def checkQuantized(model):
                # make sure Embedding is now a QuantizedEmbedding
                self.assertEqual(type(model.emb), nn.quantized.Embedding)
                # make sure Linear is now a QuantizedLinear
                self.assertEqual(type(model.linear), nn.quantized.Linear)
                test_only_eval_fn(model, self.embed_data)
                self.checkScriptable(model, self.embed_data)
                self.checkNoQconfig(model)

            checkQuantized(model)
def test_embedding_bag_linear_default_mappings(self):
    # Same flow as test_embedding_bag_linear above, but exercising
    # prepare_qat/convert with the default module mappings and locally
    # generated data. (Renamed so it does not shadow the test above.)
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualEmbeddingBagLinear().train()
            model = prepare_qat(model)
            self.checkObservers(model)

            train_indices = [
                [torch.randint(0, 10, (12, 12)), torch.randn((12, 1))]
                for _ in range(2)
            ]
            eval_output = [[torch.randint(0, 10, (12, 1))]]

            test_only_train_fn(model, train_indices)
            # make sure no activation_post_process is inserted for EmbeddingBag
            self.assertFalse(hasattr(model, "activation_post_process"))
            model = convert(model)

            def checkQuantized(model):
                # make sure EmbeddingBag is now a quantized EmbeddingBag
                self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                # also check that Linear has been quantized
                self.assertEqual(type(model.linear), nnq.Linear)
                test_only_eval_fn(model, eval_output)
                self.checkScriptable(model, eval_output)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = ManualEmbeddingBagLinear()
            model = quantize_qat(model, test_only_train_fn, [train_indices])
            checkQuantized(model)
def test_dynamic_qat_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            # Dynamic QAT without memoryless observers should fail
            with self.assertRaisesRegex(
                ValueError,
                "Dynamic QAT requires a memoryless observer." +
                "This means a MovingAverage observer with averaging constant equal to 1"
            ):
                model = ManualLinearDynamicQATModel(default_qat_qconfig)
                model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})

            model = ManualLinearDynamicQATModel()
            model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
            self.assertEqual(type(model.fc1), nnqatd.Linear)
            self.assertEqual(type(model.fc2), nnqatd.Linear)
            self.checkObservers(model)
            test_only_train_fn(model, self.train_data)
            model = convert(model, mapping={nnqatd.Linear: nnqd.Linear})
            self.assertEqual(type(model.fc1), nnqd.Linear)
            self.assertEqual(type(model.fc2), nnqd.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)
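# A sketch of a qconfig that satisfies the check above (assumed construction,
# not taken from the test): a MovingAverage observer with averaging_constant=1
# keeps no history, i.e. it is "memoryless" and re-derives qparams from each
# batch alone, matching what dynamic quantization does at runtime.
from torch.ao.quantization import (
    QConfig,
    FakeQuantize,
    MovingAverageMinMaxObserver,
    default_weight_fake_quant,
)

memoryless_activation = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    averaging_constant=1,  # memoryless: no running average across batches
    dtype=torch.quint8,
)
dynamic_qat_qconfig = QConfig(
    activation=memoryless_activation,
    weight=default_weight_fake_quant,
)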
def test_fuse_module_train(self):
    model = ModelForFusion(default_qat_qconfig).train()
    # Test step by step fusion
    model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
    model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
    self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                     msg="Fused Conv + BN + Relu first layer")
    self.assertEqual(type(model.bn1), torch.nn.Identity,
                     msg="Fused Conv + BN + Relu (skipped BN)")
    self.assertEqual(type(model.relu1), torch.nn.Identity,
                     msg="Fused Conv + BN + Relu (skipped Relu)")

    self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                     msg="Fused submodule Conv + BN")
    self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                     msg="Fused submodule Conv + BN (skipped BN)")
    self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                     msg="Non-fused submodule Conv")
    self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                     msg="Non-fused submodule ReLU")

    model = prepare_qat(model)
    self.checkObservers(model)

    def checkQAT(model):
        self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)

    checkQAT(model)
    test_only_train_fn(model, self.img_data_1d_train)
    model = convert(model)

    def checkQuantized(model):
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)
        test_only_eval_fn(model, self.img_data_1d)
        self.checkNoQconfig(model)

    with self.assertRaisesRegex(
            RuntimeError,
            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
        checkQuantized(model)

    model = ModelForFusion(default_qat_qconfig).train()
    model = fuse_modules_qat(
        model,
        [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])
    model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
    with self.assertRaisesRegex(
            RuntimeError,
            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
        checkQuantized(model)
def test_train_save_load_eval(self):
    r"""Test the QAT flow of creating a model, doing QAT, and saving the quantized state_dict.
    During eval, we first call prepare_qat and convert on the model, then load
    the state_dict and compare results against the original model.
    """
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = TwoLayerLinearModel()
            model = torch.ao.quantization.QuantWrapper(model)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
            model = prepare_qat(model)

            fq_state_dict = model.state_dict()

            test_only_train_fn(model, self.train_data)
            model = convert(model)

            quant_state_dict = model.state_dict()

            x = torch.rand(2, 5, dtype=torch.float)
            ref = model(x)

            # Create the model again for eval. Check the result using the quantized state_dict
            model = TwoLayerLinearModel()
            model = torch.ao.quantization.QuantWrapper(model)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
            torch.ao.quantization.prepare_qat(model, inplace=True)
            new_state_dict = model.state_dict()

            # Check that the model after prepare_qat has the same state_dict keys as the original
            self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))

            torch.ao.quantization.convert(model, inplace=True)
            model.eval()
            model.load_state_dict(quant_state_dict)
            out = model(x)
            self.assertEqual(ref, out)

            # Check that a model created using prepare has the same state_dict keys as the quantized state_dict
            model = TwoLayerLinearModel()
            model.eval()
            model = torch.ao.quantization.QuantWrapper(model)
            model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            torch.ao.quantization.prepare(model, inplace=True)
            torch.ao.quantization.convert(model, inplace=True)
            self.assertEqual(set(model.state_dict().keys()),
                             set(quant_state_dict.keys()))
            model.eval()
            model.load_state_dict(quant_state_dict)
            out = model(x)
            self.assertEqual(ref, out)
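# The save/load pattern the test above exercises, distilled into a sketch
# (assumed setup: `build_model` is a hypothetical factory returning the same
# float architecture; only standard torch.ao.quantization calls are used).
def save_quantized(path, model, train_fn):
    model = torch.ao.quantization.QuantWrapper(model)
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig()
    model = torch.ao.quantization.prepare_qat(model.train())
    train_fn(model)                                  # fine-tune with fake-quant
    model = torch.ao.quantization.convert(model.eval())
    torch.save(model.state_dict(), path)

def load_quantized(path, build_model):
    # prepare_qat + convert on a fresh float model reproduces the quantized
    # module structure, so the saved state_dict keys line up for loading.
    model = torch.ao.quantization.QuantWrapper(build_model().train())
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig()
    torch.ao.quantization.prepare_qat(model, inplace=True)
    torch.ao.quantization.convert(model, inplace=True)
    model.eval()
    model.load_state_dict(torch.load(path))
    return model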
def test_dropout(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualDropoutQATModel(qengine)
            model = prepare_qat(model)
            self.checkObservers(model)
            test_only_train_fn(model, self.train_data)
            model = convert(model)

            def checkQuantized(model):
                self.assertEqual(type(model.fc1), nnq.Linear)
                self.assertEqual(type(model.dropout), nnq.Dropout)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = quantize_qat(ManualDropoutQATModel(qengine),
                                 test_only_train_fn, [self.train_data])
            checkQuantized(model)