def test_conv_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualConvLinearQATModel()

            model = prepare_qat(model)
            self.checkObservers(model)

            test_only_train_fn(model, self.img_data_2d_train)
            model = convert(model)

            def checkQuantized(model):
                self.assertEqual(type(model.conv), nnq.Conv2d)
                self.assertEqual(type(model.fc1), nnq.Linear)
                self.assertEqual(type(model.fc2), nnq.Linear)
                test_only_eval_fn(model, self.img_data_2d)
                self.checkScriptable(model, self.img_data_2d)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = ManualConvLinearQATModel()
            model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
            checkQuantized(model)
def test_qat_convbn_fused_syncbn_replacement(self):
    """
    Tests that SyncBatchNorm replacement works for fused ConvBN.
    """
    if 'fbgemm' not in torch.backends.quantized.supported_engines:
        return
    with override_quantized_engine('fbgemm'):
        # create conv-bn
        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv = nn.Conv2d(4, 1, 3, padding=1)
                self.bn = nn.BatchNorm2d(1)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = Model()
        # fuse it
        fused_model = torch.quantization.fuse_modules(
            model,
            [['conv', 'bn']],
        )
        # convert to QAT
        fused_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        torch.quantization.prepare_qat(fused_model, inplace=True)
        # replace the fused BN with SyncBatchNorm, as DDP training would
        fused_model = nn.SyncBatchNorm.convert_sync_batchnorm(fused_model)
        self.assertTrue(
            isinstance(fused_model.conv.bn, nn.SyncBatchNorm),
            "Expected BN to be converted to SyncBN")
def test_fusion_conv_with_bias(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ModelForFusionWithBias().train()
            # output with no fusion.
            out_ref = model(self.img_data_2d[0][0])

            model.qconfig = QConfig(activation=torch.nn.Identity,
                                    weight=torch.nn.Identity)
            model = fuse_modules(model, [["conv1", "bn1", "relu1"],
                                         ["conv2", "bn2"]])
            prep_model = prepare_qat(model, inplace=False)
            # output with fusion but no observers.
            out_fused = prep_model(self.img_data_2d[0][0])
            self.assertEqual(out_ref, out_fused)

            model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            prepare_qat(model, inplace=True)

            model(self.img_data_2d[0][0])

            def checkQAT(model):
                self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                self.assertEqual(type(model.bn1), nn.Identity)
                self.assertEqual(type(model.relu1), nn.Identity)
                self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                self.assertEqual(type(model.bn2), nn.Identity)

            checkQAT(model)
def test_defused_embedding_bag_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = DeFusedEmbeddingBagLinear().train()
            model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
            self.checkObservers(model)

            test_only_train_fn(model, self.embed_linear_data_train)
            # make sure activation_post_process is inserted after Linear.
            self.assertEqual(type(model.linear.activation_post_process),
                             FusedMovingAvgObsFakeQuantize)
            # make sure that Embedding has a noop for activation.
            self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
            model = convert(model, mapping=get_embedding_static_quant_module_mappings())

            def checkQuantized(model):
                # make sure Embedding is now a QuantizedEmbedding
                self.assertEqual(type(model.emb), nn.quantized.Embedding)
                # make sure Linear is now a QuantizedLinear
                self.assertEqual(type(model.linear), nn.quantized.Linear)

                test_only_eval_fn(model, self.embed_data)
                self.checkScriptable(model, self.embed_data)
                self.checkNoQconfig(model)

            checkQuantized(model)
def test_conv1d_api(
    self,
    batch_size,
    in_channels_per_group,
    L,
    out_channels_per_group,
    groups,
    kernel,
    stride,
    pad,
    dilation,
    X_scale,
    X_zero_point,
    W_scale,
    W_zero_point,
    Y_scale,
    Y_zero_point,
    use_bias,
    use_channelwise,
    qengine,
):
    # Tests the correctness of the conv1d function.
    if qengine not in torch.backends.quantized.supported_engines:
        return
    if qengine == 'qnnpack':
        if IS_PPC or TEST_WITH_UBSAN:
            return
        use_channelwise = False

    input_feature_map_size = (L, )
    kernel_size = (kernel, )
    stride = (stride, )
    padding = (pad, )
    dilation = (dilation, )

    with override_quantized_engine(qengine):
        qconv_fn = qF.conv1d
        conv_fn = F.conv1d
        self._test_conv_api_impl(
            qconv_fn, conv_fn, batch_size, in_channels_per_group,
            input_feature_map_size, out_channels_per_group, groups,
            kernel_size, stride, padding, dilation, X_scale, X_zero_point,
            W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
            use_channelwise)
def test_fusion_sequential_model_eval(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ModelWithSequentialFusion().eval()
            model.to(torch.float)
            fuse_modules(model,
                         [['conv1', 'relu1'],
                          ['features.0.0', 'features.0.1', 'features.0.2'],
                          ['features.1.0', 'features.1.1', 'features.1.2'],
                          ['features.2.0', 'features.2.1', 'features.2.2'],
                          ['classifier.0', 'classifier.1']],
                         inplace=True)
            self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                             msg="Fused Conv + Relu: nni.ConvReLU2d")
            self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                             msg="Fused Conv + Relu: Conv2d")
            self.assertEqual(type(model.conv1[1]), nn.ReLU,
                             msg="Fused Conv + Relu: Relu")
            self.assertEqual(type(model.relu1), nn.Identity,
                             msg="Fused Conv + Relu: Identity")
            for i in range(3):
                self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
                                 msg="Fused submodule Conv + folded BN")
                self.assertEqual(type(model.features[i][1]), nn.Identity,
                                 msg="Fused submodule (skipped BN)")
                self.assertEqual(type(model.features[i][2]), nn.Identity,
                                 msg="Non-fused submodule Conv")
            self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
            self.assertEqual(type(model.classifier[1]), nn.Identity)
            model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            prepare(model, inplace=True)
            self.checkObservers(model)
            model(self.img_data_2d[0][0])
            convert(model, inplace=True)
            model(self.img_data_2d[1][0])
            self.checkModelWithSequentialQuantized(model)
def test_embedding_bag_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualEmbeddingBagLinear().train()
            model = prepare_qat(model)
            self.checkObservers(model)

            train_indices = [[torch.randint(0, 10, (12, 12)),
                              torch.randn((12, 1))] for _ in range(2)]
            eval_output = [[torch.randint(0, 10, (12, 1))]]

            test_only_train_fn(model, train_indices)
            # make sure no activation_post_process is inserted for EmbeddingBag
            self.assertFalse(hasattr(model, "activation_post_process"))
            model = convert(model)

            def checkQuantized(model):
                # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                # Also test that Linear has been quantized.
                self.assertEqual(type(model.linear), nnq.Linear)

                test_only_eval_fn(model, eval_output)
                self.checkScriptable(model, eval_output)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = ManualEmbeddingBagLinear()
            model = quantize_qat(model, test_only_train_fn, [train_indices])
            checkQuantized(model)
def test_conv3d_api(
    self,
    batch_size,
    in_channels_per_group,
    D,
    H,
    W,
    out_channels_per_group,
    groups,
    kernel_d,
    kernel_h,
    kernel_w,
    stride_d,
    stride_h,
    stride_w,
    pad_d,
    pad_h,
    pad_w,
    dilation,
    X_scale,
    X_zero_point,
    W_scale,
    W_zero_point,
    Y_scale,
    Y_zero_point,
    use_bias,
    use_channelwise,
    qengine,
):
    # Tests the correctness of the conv3d function.
    # Currently conv3d only supports the FBGEMM engine.
    if qengine not in torch.backends.quantized.supported_engines:
        return

    input_feature_map_size = (D, H, W)
    kernel_size = (kernel_d, kernel_h, kernel_w)
    stride = (stride_d, stride_h, stride_w)
    padding = (pad_d, pad_h, pad_w)
    dilation = (dilation, dilation, dilation)

    with override_quantized_engine(qengine):
        qconv_fn = qF.conv3d
        conv_fn = F.conv3d
        self._test_conv_api_impl(
            qconv_fn, conv_fn, batch_size, in_channels_per_group,
            input_feature_map_size, out_channels_per_group, groups,
            kernel_size, stride, padding, dilation, X_scale, X_zero_point,
            W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
            use_channelwise)
def test_compare_model_stub(self):
    r"""Compare the output of a quantized conv layer and its float shadow module"""

    def compare_and_validate_results(float_model, q_model, module_swap_list, data):
        ob_dict = compare_model_stub(
            float_model, q_model, module_swap_list, data, ShadowLogger)
        self.assertEqual(len(ob_dict), 1)
        for k, v in ob_dict.items():
            self.assertTrue(v["float"].shape == v["quantized"].shape)

    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model_list = [
                AnnotatedConvModel(qengine),
                AnnotatedConvBnReLUModel(qengine),
            ]
            data = self.img_data[0][0]
            module_swap_list = [
                nn.Conv2d,
                nn.intrinsic.modules.fused.ConvReLU2d,
            ]
            for model in model_list:
                model.eval()
                if hasattr(model, "fuse_model"):
                    model.fuse_model()
                q_model = quantize(model, default_eval_fn, self.img_data)
                compare_and_validate_results(model, q_model, module_swap_list, data)

            # Test adding stub to sub module
            model = ModelWithSubModules().eval()
            q_model = quantize(model, default_eval_fn, self.img_data)
            module_swap_list = [SubModule]
            ob_dict = compare_model_stub(
                model, q_model, module_swap_list, data, ShadowLogger)
            self.assertTrue(isinstance(q_model.mod1, Shadow))
            self.assertFalse(isinstance(q_model.conv, Shadow))
            for k, v in ob_dict.items():
                torch.testing.assert_allclose(v["float"], v["quantized"].dequantize())

            # Test adding stub to functionals
            model = ModelWithFunctionals().eval()
            model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
            q_model = prepare(model, inplace=False)
            q_model(data)
            q_model = convert(q_model)
            module_swap_list = [nnq.FloatFunctional]
            ob_dict = compare_model_stub(
                model, q_model, module_swap_list, data, ShadowLogger)
            self.assertEqual(len(ob_dict), 6)
            self.assertTrue(isinstance(q_model.mycat, Shadow))
            self.assertTrue(isinstance(q_model.myadd, Shadow))
            self.assertTrue(isinstance(q_model.mymul, Shadow))
            self.assertTrue(isinstance(q_model.myadd_relu, Shadow))
            self.assertTrue(isinstance(q_model.my_scalar_add, Shadow))
            self.assertTrue(isinstance(q_model.my_scalar_mul, Shadow))
            for k, v in ob_dict.items():
                self.assertTrue(v["float"].shape == v["quantized"].shape)
def test_embedding_bag_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualEmbeddingBagLinear().train()
            model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
            self.checkObservers(model)

            test_only_train_fn(model, self.embed_linear_data_train)
            # make sure no activation_post_process is inserted for EmbeddingBag
            self.assertFalse(hasattr(model, "activation_post_process"))
            # make sure the FakeQuant zero_points have the correct dtype
            self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype,
                             torch.float32)
            self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype,
                             torch.int32)
            model = convert(model, mapping=get_embedding_static_quant_module_mappings())

            def checkQuantized(model):
                # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                # Also test that Linear has been quantized.
                self.assertEqual(type(model.linear), nnq.Linear)

                test_only_eval_fn(model, self.embed_data)
                self.checkScriptable(model, self.embed_data)
                self.checkNoQconfig(model)

            checkQuantized(model)
def test_dynamic_qat_linear(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            # Dynamic QAT without memoryless observers should fail
            with self.assertRaisesRegex(
                    ValueError,
                    "Dynamic QAT requires a memoryless observer." +
                    "This means a MovingAverage observer with averaging constant equal to 1"):
                model = ManualLinearDynamicQATModel(default_qat_qconfig)
                model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})

            model = ManualLinearDynamicQATModel()
            model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
            self.assertEqual(type(model.fc1), nnqatd.Linear)
            self.assertEqual(type(model.fc2), nnqatd.Linear)
            self.checkObservers(model)
            test_only_train_fn(model, self.train_data)
            model = convert(model, mapping={nnqatd.Linear: nnqd.Linear})
            self.assertEqual(type(model.fc1), nnqd.Linear)
            self.assertEqual(type(model.fc2), nnqd.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)
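# A minimal sketch (not part of the test suite) of what "memoryless" means in
# the error message above, assuming the torch.ao.quantization observer API: a
# MovingAverageMinMaxObserver with averaging_constant=1 overwrites its running
# min/max on every forward pass instead of blending with previous estimates,
# which is what dynamic QAT requires.
def _example_memoryless_observer():
    from torch.ao.quantization import MovingAverageMinMaxObserver
    # averaging_constant=1 makes the moving average degenerate to "keep the
    # latest value", i.e. the observer carries no memory across batches.
    return MovingAverageMinMaxObserver.with_args(averaging_constant=1)()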
def test_weight_only_activation_only_fakequant(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            torch.manual_seed(67)
            calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
            eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
            # use a list rather than a set so that the pairing with
            # SQNRTarget below is deterministic
            qconfigset = [
                torch.quantization.default_weight_only_qconfig,
                torch.quantization.default_activation_only_qconfig,
            ]
            SQNRTarget = [35, 45]
            for idx, qconfig in enumerate(qconfigset):
                my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
                my_model.eval()
                out_ref = my_model(eval_data)
                fq_model = torch.quantization.QuantWrapper(my_model)
                fq_model.train()
                fq_model.qconfig = qconfig
                torch.ao.quantization.fuse_modules(
                    fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
                torch.quantization.prepare_qat(fq_model)
                fq_model.eval()
                fq_model.apply(torch.quantization.disable_fake_quant)
                fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                fq_model(calib_data)
                fq_model.apply(torch.quantization.enable_fake_quant)
                fq_model.apply(torch.quantization.disable_observer)
                out_fq = fq_model(eval_data)
                SQNRdB = 20 * torch.log10(
                    torch.norm(out_ref) / torch.norm(out_ref - out_fq))
                self.assertGreater(
                    SQNRdB, SQNRTarget[idx],
                    msg='Quantized model numerics diverge from float')
def _compare_script_and_mobile(self, model: torch.nn.Module, input: torch.Tensor):
    qengine = "qnnpack"
    with override_quantized_engine(qengine):
        script_module = torch.jit.script(model)
        script_module_result = script_module(input)

        max_retry = 5
        for retry in range(1, max_retry + 1):
            # retries up to `max_retry` times; breaks if it succeeds,
            # otherwise re-raises the exception on the last attempt
            try:
                buffer = io.BytesIO(
                    script_module._save_to_buffer_for_lite_interpreter())
                buffer.seek(0)
                mobile_module = _load_for_lite_interpreter(buffer)

                mobile_module_result = mobile_module(input)
                torch.testing.assert_allclose(script_module_result,
                                              mobile_module_result)

                mobile_module_forward_result = mobile_module.forward(input)
                torch.testing.assert_allclose(script_module_result,
                                              mobile_module_forward_result)

                mobile_module_run_method_result = mobile_module.run_method(
                    "forward", input)
                torch.testing.assert_allclose(script_module_result,
                                              mobile_module_run_method_result)
            except AssertionError as e:
                if retry == max_retry:
                    raise e
                else:
                    continue
            break
def test_float_quant_compare_per_tensor(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            torch.manual_seed(42)
            my_model = ModelMultipleOps().to(torch.float32)
            my_model.eval()
            calib_data = torch.rand(1024, 3, 15, 15, dtype=torch.float32)
            eval_data = torch.rand(1, 3, 15, 15, dtype=torch.float32)
            out_ref = my_model(eval_data)
            qModel = torch.quantization.QuantWrapper(my_model)
            qModel.eval()
            qModel.qconfig = torch.quantization.default_qconfig
            torch.ao.quantization.fuse_modules(
                qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
            torch.quantization.prepare(qModel, inplace=True)
            qModel(calib_data)
            torch.quantization.convert(qModel, inplace=True)
            out_q = qModel(eval_data)
            SQNRdB = 20 * torch.log10(
                torch.norm(out_ref) / torch.norm(out_ref - out_q))
            # The quantized model's output should be numerically close to the
            # float model's output. A target SQNR of 30 dB keeps the error
            # power below 1e-3 of the signal power (about 3% in amplitude).
            self.assertGreater(
                SQNRdB, 30,
                msg='Quantized model numerics diverge from float, expect SQNR > 30 dB')
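# A minimal sketch (not part of the test suite) of the SQNR arithmetic used in
# the tests above: SQNR_dB = 20 * log10(||ref|| / ||ref - out||). At 30 dB the
# error amplitude is 10 ** (-30 / 20), roughly 3.2% of the signal, which is
# equivalent to an error power of 1e-3 of the signal power.
def _example_sqnr_db(signal_norm, error_norm):
    import math
    # e.g. _example_sqnr_db(1.0, 10 ** -1.5) == 30.0 (up to float rounding)
    return 20 * math.log10(signal_norm / error_norm)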
def test_linear_dynamic(self):
    for i, qengine in enumerate(supported_qengines):
        with override_quantized_engine(qengine):
            module_qint8 = nnqd.Linear(3, 1, bias_=True, dtype=torch.qint8)
            self._test_op(module_qint8, "qint8", input_size=[1, 3],
                          input_quantized=False, generate=False, iter=i)
            if qengine == 'fbgemm':
                module_float16 = nnqd.Linear(3, 1, bias_=True, dtype=torch.float16)
                self._test_op(module_float16, "float16", input_size=[1, 3],
                              input_quantized=False, generate=False)
def _create_quantized_model(self, model_class: Type[torch.nn.Module], **kwargs):
    qengine = "qnnpack"
    with override_quantized_engine(qengine):
        qconfig = torch.quantization.get_default_qconfig(qengine)
        model = model_class(**kwargs)
        model = quantize(model, test_only_eval_fn, [self.calib_data])
    return model
def test_train_save_load_eval(self):
    r"""Test the QAT flow of creating a model, doing QAT, and saving the quantized state_dict
    During eval, we first call prepare_qat and convert on the model, then load
    the state_dict, and compare the results against the original model.
    """
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = TwoLayerLinearModel()
            model = torch.ao.quantization.QuantWrapper(model)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
            model = prepare_qat(model)

            fq_state_dict = model.state_dict()

            test_only_train_fn(model, self.train_data)
            model = convert(model)

            quant_state_dict = model.state_dict()

            x = torch.rand(2, 5, dtype=torch.float)
            ref = model(x)

            # Create model again for eval. Check result using quantized state_dict
            model = TwoLayerLinearModel()
            model = torch.ao.quantization.QuantWrapper(model)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
            torch.ao.quantization.prepare_qat(model, inplace=True)
            new_state_dict = model.state_dict()

            # Check to make sure the model after prepare_qat has the same state_dict as original.
            self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))

            torch.ao.quantization.convert(model, inplace=True)
            model.eval()
            model.load_state_dict(quant_state_dict)
            out = model(x)
            self.assertEqual(ref, out)

            # Check that a model created using prepare has the same state_dict as the quantized state_dict
            model = TwoLayerLinearModel()
            model.eval()
            model = torch.ao.quantization.QuantWrapper(model)
            model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            torch.ao.quantization.prepare(model, inplace=True)
            torch.ao.quantization.convert(model, inplace=True)
            self.assertEqual(set(model.state_dict().keys()),
                             set(quant_state_dict.keys()))
            model.eval()
            model.load_state_dict(quant_state_dict)
            out = model(x)
            self.assertEqual(ref, out)
def test_compare_model_outputs(self):
    r"""Compare the output of a conv layer in the quantized model against the
    corresponding conv layer's output in the float model
    """

    def compare_and_validate_results(float_model, q_model, data):
        act_compare_dict = compare_model_outputs(float_model, q_model, data)
        self.assertEqual(len(act_compare_dict), 2)
        expected_act_compare_dict_keys = {"conv.stats", "quant.stats"}

        self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
        for k, v in act_compare_dict.items():
            self.assertTrue(v["float"].shape == v["quantized"].shape)

    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model_list = [
                AnnotatedConvModel(qengine),
                AnnotatedConvBnReLUModel(qengine),
            ]
            data = self.img_data[0][0]
            module_swap_list = [
                nn.Conv2d,
                nn.intrinsic.modules.fused.ConvReLU2d,
            ]
            for model in model_list:
                model.eval()
                if hasattr(model, "fuse_model"):
                    model.fuse_model()
                q_model = quantize(model, default_eval_fn, self.img_data)
                compare_and_validate_results(model, q_model, data)

            # Test functionals
            model = ModelWithFunctionals().eval()
            model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
            q_model = prepare(model, inplace=False)
            q_model(data)
            q_model = convert(q_model)
            act_compare_dict = compare_model_outputs(model, q_model, data)
            self.assertEqual(len(act_compare_dict), 7)
            expected_act_compare_dict_keys = {
                "mycat.stats",
                "myadd.stats",
                "mymul.stats",
                "myadd_relu.stats",
                "my_scalar_add.stats",
                "my_scalar_mul.stats",
                "quant.stats",
            }
            self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
            for k, v in act_compare_dict.items():
                self.assertTrue(v["float"].shape == v["quantized"].shape)
def test_module_with_shared_type_instances(self):
    class Child(nn.Module):
        def __init__(self):
            super(Child, self).__init__()
            self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)

        def forward(self, x):
            x = self.conv1(x)
            return x

    class Parent(nn.Module):
        def __init__(self):
            super(Parent, self).__init__()
            self.quant = torch.quantization.QuantStub()
            self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)
            self.child = Child()
            self.child2 = Child()
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv1(x)
            x = self.child(x)
            x = self.child2(x)
            x = self.dequant(x)
            return x

    def _static_quant(model):
        qModel = torch.quantization.QuantWrapper(model)
        qModel.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(qModel, inplace=True)
        qModel(torch.rand(4, 1, 4, 4, dtype=torch.float32))
        torch.quantization.convert(qModel, inplace=True)
        return model

    with override_quantized_engine('fbgemm'):
        data = torch.randn(4, 1, 4, 4, dtype=torch.float32)
        m = Parent().to(torch.float32)
        m = _static_quant(m)
        m = torch.jit.script(m)
        m.eval()
        torch._C._jit_pass_inline(m.graph)
        m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c))
        # Earlier bug resulted in _packed_params set to false.
        FileCheck().check_not('_packed_params = False').run(
            m_frozen._c.dump_to_str(True, True, False))

        m_res = m(data)
        # It used to segfault while running the frozen module.
        m_frozen_res = m_frozen(data)
        self.assertEqual(m_res, m_frozen_res)
def test_eval_only_fake_quant(self):
    r"""Using FakeQuant in evaluation-only mode; this is useful for estimating
    the accuracy loss when we quantize the network.
    """
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualLinearQATModel(qengine)

            model = prepare_qat(model)
            self.checkObservers(model)

            model.eval()
            test_only_eval_fn(model, self.calib_data)
def test_fake_quant_true_quant_compare(self):
    for qengine in ["fbgemm", "qnnpack"]:
        if qengine not in torch.backends.quantized.supported_engines:
            continue
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                continue
        with override_quantized_engine(qengine):
            torch.manual_seed(67)
            my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
            calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
            eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
            my_model.eval()
            out_ref = my_model(eval_data)
            fq_model = torch.quantization.QuantWrapper(my_model)
            fq_model.train()
            fq_model.qconfig = torch.quantization.default_qat_qconfig
            torch.quantization.fuse_modules(
                fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
            torch.quantization.prepare_qat(fq_model)
            fq_model.eval()
            fq_model.apply(torch.quantization.disable_fake_quant)
            fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            fq_model(calib_data)
            fq_model.apply(torch.quantization.enable_fake_quant)
            fq_model.apply(torch.quantization.disable_observer)
            out_fq = fq_model(eval_data)
            SQNRdB = 20 * torch.log10(
                torch.norm(out_ref) / torch.norm(out_ref - out_fq))
            # Quantized model output should be close to floating point model
            # output numerically. Setting target SQNR to be 35 dB.
            self.assertGreater(
                SQNRdB, 35,
                msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
            # convert is not in-place by default, so keep the returned model
            fq_model = torch.quantization.convert(fq_model)
            out_q = fq_model(eval_data)
            SQNRdB = 20 * torch.log10(
                torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
            self.assertGreater(
                SQNRdB, 60,
                msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')
def test_conv2d_api(
    self,
    batch_size,
    in_channels_per_group,
    H,
    W,
    out_channels_per_group,
    groups,
    kernel_h,
    kernel_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation,
    X_scale,
    X_zero_point,
    W_scale,
    W_zero_point,
    Y_scale,
    Y_zero_point,
    use_bias,
    use_fused,
    use_channelwise,
    qengine,
):
    # Tests the correctness of the conv2d module.
    if qengine not in torch.backends.quantized.supported_engines:
        return
    if qengine == 'qnnpack':
        if IS_PPC or TEST_WITH_UBSAN:
            return
        use_channelwise = False

    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups
    input_feature_map_size = (H, W)
    kernel_size = (kernel_h, kernel_w)
    stride = (stride_h, stride_w)
    padding = (pad_h, pad_w)
    dilation = (dilation, dilation)

    with override_quantized_engine(qengine):
        if use_fused:
            module_name = "QuantizedConvReLU2d"
            qconv_module = nnq_fused.ConvReLU2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode="zeros")
        else:
            module_name = "QuantizedConv2d"
            qconv_module = nnq.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode="zeros")

        conv_module = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, use_bias, padding_mode="zeros")
        if use_fused:
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU2d(conv_module, relu_module)
        conv_module = conv_module.float()

        self._test_conv_api_impl(
            module_name, qconv_module, conv_module, batch_size,
            in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, stride, padding,
            dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
            Y_zero_point, use_bias, use_fused, use_channelwise)
def test_fusion_conv_with_bias(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model_orig = ModelForFusionWithBias().train()

            # reference model
            model_ref = copy.deepcopy(model_orig)
            # output with no fusion.
            out_ref = model_ref(self.img_data_2d[0][0])

            # fused model
            model_orig.qconfig = QConfig(activation=torch.nn.Identity,
                                         weight=torch.nn.Identity)
            model = fuse_modules_qat(
                model_orig,
                [["conv1", "bn1", "relu1"],
                 ["conv2", "bn2"]])
            prep_model = prepare_qat(model, inplace=False)
            # output with fusion but no observers.
            out_fused = prep_model(self.img_data_2d[0][0])
            self.assertEqual(out_ref, out_fused)

            def checkBN(bn_ref, bn):
                self.assertEqual(bn_ref.weight, bn.weight)
                self.assertEqual(bn_ref.bias, bn.bias)
                self.assertEqual(bn_ref.running_mean, bn.running_mean)
                self.assertEqual(bn_ref.running_var, bn.running_var)

            checkBN(model_ref.bn1, prep_model.conv1.bn)
            checkBN(model_ref.bn2, prep_model.conv2.bn)

            model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
            prepare_qat(model, inplace=True)

            model(self.img_data_2d[0][0])

            def checkQAT(model):
                self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                self.assertEqual(type(model.bn1), nn.Identity)
                self.assertEqual(type(model.relu1), nn.Identity)
                self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                self.assertEqual(type(model.bn2), nn.Identity)

            checkQAT(model)
def test_record_observer(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = AnnotatedSingleLayerLinearModel()
            model.qconfig = default_debug_qconfig
            model = prepare(model)
            # run the evaluation and dump all tensors
            test_only_eval_fn(model, self.calib_data)
            test_only_eval_fn(model, self.calib_data)
            observer_dict = {}
            get_observer_dict(model, observer_dict)

            self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(),
                            'observer is not recorded in the dict')
            self.assertEqual(
                len(observer_dict['fc1.module.activation_post_process'].get_tensor_value()),
                2 * len(self.calib_data))
            self.assertEqual(
                observer_dict['fc1.module.activation_post_process'].get_tensor_value()[0],
                model(self.calib_data[0][0]))
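# A minimal sketch (an assumption about the eager-mode debug-observer API
# exercised above): default_debug_qconfig uses a recording observer as the
# activation observer, which stores a copy of every tensor it sees, so
# get_tensor_value() returns one entry per forward call.
def _example_recording_observer():
    import torch
    from torch.ao.quantization import default_debug_qconfig
    obs = default_debug_qconfig.activation()
    obs(torch.randn(2, 3))  # each call records the input tensor
    obs(torch.randn(2, 3))
    return len(obs.get_tensor_value())  # -> 2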
def test_qat_data_parallel(self):
    """
    Tests that doing QAT in nn.DataParallel does not crash.
    """
    if 'fbgemm' not in torch.backends.quantized.supported_engines:
        return
    with override_quantized_engine('fbgemm'):
        device = torch.device('cuda')

        model = nn.Sequential(
            torch.quantization.QuantStub(),
            nn.Conv2d(3, 1, 1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Conv2d(1, 2, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2),
            nn.AvgPool2d(14),
            nn.Sigmoid(),
            torch.quantization.DeQuantStub(),
        )

        torch.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)

        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        torch.quantization.prepare_qat(model, inplace=True)
        model = nn.DataParallel(model, device_ids=[0, 1])
        model.to(device)
        model.train()

        for epoch in range(3):
            inputs = torch.rand(2, 3, 28, 28).to(device)
            model(inputs)
            if epoch >= 1:
                model.apply(torch.quantization.disable_observer)
            if epoch >= 2:
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            quant_model = copy.deepcopy(model.module)
            quant_model = torch.quantization.convert(
                quant_model.eval().cpu(), inplace=False)
            with torch.no_grad():
                out = quant_model(torch.rand(1, 3, 28, 28))
def test_quantized_conv_no_asan_failures(self):
    # There were ASAN failures when fold_conv_bn was run on
    # already quantized conv modules. Verifying that this does
    # not happen again.
    if 'qnnpack' not in torch.backends.quantized.supported_engines:
        return

    class Child(nn.Module):
        def __init__(self):
            super(Child, self).__init__()
            self.conv2 = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            x = self.conv2(x)
            return x

    class Parent(nn.Module):
        def __init__(self):
            super(Parent, self).__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.conv1 = nn.Conv2d(1, 1, 1)
            self.child = Child()
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv1(x)
            x = self.child(x)
            x = self.dequant(x)
            return x

    with override_quantized_engine('qnnpack'):
        model = Parent()
        model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
        torch.ao.quantization.prepare(model, inplace=True)
        model(torch.randn(4, 1, 4, 4))
        torch.ao.quantization.convert(model, inplace=True)
        model = torch.jit.script(model)

        # this line should not have ASAN failures
        model_optim = optimize_for_mobile(model)
def test_conv3d_api(
    self,
    batch_size,
    in_channels_per_group,
    D,
    H,
    W,
    out_channels_per_group,
    groups,
    kernel_d,
    kernel_h,
    kernel_w,
    stride_d,
    stride_h,
    stride_w,
    pad_d,
    pad_h,
    pad_w,
    pad_mode,
    dilation,
    X_scale,
    X_zero_point,
    W_scale,
    W_zero_point,
    Y_scale,
    Y_zero_point,
    use_bias,
    use_channelwise,
    use_fused,
):
    # Tests the correctness of the conv3d module.
    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups
    input_feature_map_size = (D, H, W)
    kernel_size = (kernel_d, kernel_h, kernel_w)
    stride = (stride_d, stride_h, stride_w)
    padding = (pad_d, pad_h, pad_w)
    dilation = (dilation, dilation, dilation)

    with override_quantized_engine('fbgemm'):
        if use_fused:
            module_name = "QuantizedConvReLU3d"
            qconv_module = nnq_fused.ConvReLU3d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
        else:
            module_name = "QuantizedConv3d"
            qconv_module = nnq.Conv3d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)

        conv_module = nn.Conv3d(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, use_bias, padding_mode=pad_mode)
        if use_fused:
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU3d(conv_module, relu_module)
        conv_module = conv_module.float()

        self._test_conv_api_impl(
            module_name, qconv_module, conv_module, batch_size,
            in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, stride, padding,
            pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
            Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)
def test_dropout(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualDropoutQATModel(qengine)

            model = prepare_qat(model)
            self.checkObservers(model)
            test_only_train_fn(model, self.train_data)
            model = convert(model)

            def checkQuantized(model):
                self.assertEqual(type(model.fc1), nnq.Linear)
                self.assertEqual(type(model.dropout), nnq.Dropout)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = quantize_qat(ManualDropoutQATModel(qengine),
                                 test_only_train_fn, [self.train_data])
            checkQuantized(model)
def test_compare_weights(self):
    r"""Compare the weights of float and quantized conv layers"""

    def compare_and_validate_results(float_model, q_model):
        weight_dict = compare_weights(float_model.state_dict(),
                                      q_model.state_dict())
        self.assertEqual(len(weight_dict), 1)
        for k, v in weight_dict.items():
            self.assertTrue(v["float"].shape == v["quantized"].shape)

    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model_list = [
                AnnotatedConvModel(qengine),
                AnnotatedConvBnReLUModel(qengine),
            ]
            for model in model_list:
                model.eval()
                if hasattr(model, "fuse_model"):
                    model.fuse_model()
                q_model = quantize(model, default_eval_fn, self.img_data)
                compare_and_validate_results(model, q_model)
def test_hoist_conv_packed_params(self):
    if 'qnnpack' not in torch.backends.quantized.supported_engines:
        return

    class Standalone(nn.Module):
        def __init__(self):
            super(Standalone, self).__init__()
            self.quant = torch.quantization.QuantStub()
            self.conv1 = nn.Conv2d(1, 1, 1)
            self.conv2 = nn.Conv2d(1, 1, 1)
            self.relu = nn.ReLU()
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.relu(x)
            x = self.dequant(x)
            return x

        def fuse_model(self):
            torch.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)

    class Child(nn.Module):
        def __init__(self):
            super(Child, self).__init__()
            self.conv1 = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            x = self.conv1(x)
            return x

    class Parent(nn.Module):
        def __init__(self):
            super(Parent, self).__init__()
            self.quant = torch.quantization.QuantStub()
            self.conv1 = nn.Conv2d(1, 1, 1)
            self.child = Child()
            # TODO: test nn.Sequential after #42039 is fixed
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv1(x)
            x = self.child(x)
            x = self.dequant(x)
            return x

        def fuse_model(self):
            pass

    with override_quantized_engine('qnnpack'):
        def _quant_script_and_optimize(model):
            model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
            model.fuse_model()
            torch.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.quantization.convert(model, inplace=True)
            model = torch.jit.script(model)
            model_optim = optimize_for_mobile(model)
            return model, model_optim

        # basic case
        m, m_optim = _quant_script_and_optimize(Standalone())
        FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
                   .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                   .run(m_optim.graph)
        self.assertFalse(hasattr(m_optim, "conv1"))
        self.assertFalse(hasattr(m_optim, "conv2"))

        data = torch.randn(4, 1, 4, 4)
        m_res = m(data)
        m_optim_res = m_optim(data)
        torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

        # generic case
        m, m_optim = _quant_script_and_optimize(Parent())
        FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
                   .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                   .run(m_optim.graph)
        self.assertFalse(hasattr(m_optim, "conv1"))
        self.assertFalse(hasattr(m_optim, "child"))

        data = torch.randn(4, 1, 4, 4)
        m_res = m(data)
        m_optim_res = m_optim(data)
        torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3)