def test_conv_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualConvLinearQATModel()

                model = prepare_qat(model)
                self.checkObservers(model)

                test_only_train_fn(model, self.img_data_2d_train)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.conv), nnq.Conv2d)
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkScriptable(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualConvLinearQATModel()
                model = quantize_qat(model, test_only_train_fn,
                                     [self.img_data_2d_train])
                checkQuantized(model)
    def test_qat_convbn_fused_syncbn_replacement(self):
        """
        Tests that SyncBatchNorm replacement works for fused ConvBN.
        """
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            return
        with override_quantized_engine('fbgemm'):
            # create conv-bn
            class Model(nn.Module):
                def __init__(self):
                    super(Model, self).__init__()
                    self.conv = nn.Conv2d(4, 1, 3, padding=1)
                    self.bn = nn.BatchNorm2d(1)

                def forward(self, x):
                    x = self.conv(x)
                    x = self.bn(x)
                    return x

            model = Model()
            # fuse it
            fused_model = torch.quantization.fuse_modules(
                model,
                [['conv', 'bn']],
            )
            # convert to QAT
            fused_model.qconfig = torch.quantization.get_default_qconfig(
                'fbgemm')
            torch.quantization.prepare_qat(fused_model, inplace=True)
            # replace BatchNorm with SyncBatchNorm, as would be done for DDP training
            fused_model = nn.SyncBatchNorm.convert_sync_batchnorm(fused_model)
            self.assertTrue(isinstance(fused_model.conv.bn, nn.SyncBatchNorm),
                            "Expected BN to be converted to SyncBN")
    def test_fusion_conv_with_bias(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelForFusionWithBias().train()
                # output with no fusion.
                out_ref = model(self.img_data_2d[0][0])

                model.qconfig = QConfig(activation=torch.nn.Identity,
                                        weight=torch.nn.Identity)
                model = fuse_modules(model, [["conv1", "bn1", "relu1"],
                                             ["conv2", "bn2"]])
                prep_model = prepare_qat(model, inplace=False)
                # output with fusion but no observers.
                out_fused = prep_model(self.img_data_2d[0][0])
                self.assertEqual(out_ref, out_fused)

                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare_qat(model, inplace=True)

                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                    self.assertEqual(type(model.bn1), nn.Identity)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                    self.assertEqual(type(model.bn2), nn.Identity)

                checkQAT(model)
    def test_defused_embedding_bag_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = DeFusedEmbeddingBagLinear().train()
                model = prepare_qat(
                    model, mapping=get_embedding_qat_module_mappings())
                self.checkObservers(model)

                test_only_train_fn(model, self.embed_linear_data_train)
                # make sure activation_post_process is inserted after Linear.
                self.assertEqual(type(model.linear.activation_post_process),
                                 FusedMovingAvgObsFakeQuantize)
                # make sure that Embedding has a no-op observer for activation.
                self.assertEqual(type(model.emb.activation_post_process),
                                 NoopObserver)

                model = convert(
                    model,
                    mapping=get_embedding_static_quant_module_mappings())

                def checkQuantized(model):
                    # make sure Embedding is now a QuantizedEmbedding
                    self.assertEqual(type(model.emb), nn.quantized.Embedding)
                    # make sure Linear is now a QuantizedLinear
                    self.assertEqual(type(model.linear), nn.quantized.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)
    def test_conv1d_api(
        self, batch_size, in_channels_per_group, L, out_channels_per_group,
        groups, kernel, stride, pad, dilation,
        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
        use_bias, use_channelwise, qengine,
    ):
        # Tests the correctness of the conv1d function.
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            use_channelwise = False

        input_feature_map_size = (L, )
        kernel_size = (kernel, )
        stride = (stride, )
        padding = (pad, )
        dilation = (dilation, )

        with override_quantized_engine(qengine):
            qconv_fn = qF.conv1d
            conv_fn = F.conv1d
            self._test_conv_api_impl(
                qconv_fn, conv_fn, batch_size, in_channels_per_group,
                input_feature_map_size, out_channels_per_group, groups,
                kernel_size, stride, padding, dilation, X_scale, X_zero_point,
                W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
                use_channelwise)
 def test_fusion_sequential_model_eval(self):
     for qengine in supported_qengines:
         with override_quantized_engine(qengine):
             model = ModelWithSequentialFusion().eval()
             model.to(torch.float)
             fuse_modules(model, [['conv1', 'relu1'] ,
                                  ['features.0.0', 'features.0.1', 'features.0.2'],
                                  ['features.1.0', 'features.1.1', 'features.1.2'],
                                  ['features.2.0', 'features.2.1', 'features.2.2'],
                                  ['classifier.0', 'classifier.1']], inplace=True)
             self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                              msg="Fused Conv + Relu: nni.ConvReLU2d")
             self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                              msg="Fused Conv + Relu: Conv2d")
             self.assertEqual(type(model.conv1[1]), nn.ReLU,
                              msg="Fused Conv + Relu: Relu")
             self.assertEqual(type(model.relu1), nn.Identity,
                              msg="Fused Conv + Relu: Identity")
             for i in range(3):
                 self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
                                  msg="Fused submodule Conv + folded BN")
                 self.assertEqual(type(model.features[i][1]), nn.Identity,
                                  msg="Fused submodule (skipped BN)")
                 self.assertEqual(type(model.features[i][2]), nn.Identity,
                                  msg="Non-fused submodule Conv")
             self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
             self.assertEqual(type(model.classifier[1]), nn.Identity)
             model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
             prepare(model, inplace=True)
             self.checkObservers(model)
             model(self.img_data_2d[0][0])
             convert(model, inplace=True)
             model(self.img_data_2d[1][0])
             self.checkModelWithSequentialQuantized(model)
    def test_embedding_bag_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualEmbeddingBagLinear().train()
                model = prepare_qat(model)
                self.checkObservers(model)

                train_indices = [[
                    torch.randint(0, 10, (12, 12)),
                    torch.randn((12, 1))
                ] for _ in range(2)]
                eval_output = [[torch.randint(0, 10, (12, 1))]]

                test_only_train_fn(model, train_indices)
                # make sure no activation_post_process is inserted for EmbeddingBag
                self.assertFalse(hasattr(model, "activation_post_process"))
                model = convert(model)

                def checkQuantized(model):
                    # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                    self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                    # Also test that Linear has been quantized.
                    self.assertEqual(type(model.linear), nnq.Linear)

                    test_only_eval_fn(model, eval_output)
                    self.checkScriptable(model, eval_output)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualEmbeddingBagLinear()
                model = quantize_qat(model, test_only_train_fn,
                                     [train_indices])
                checkQuantized(model)
    def test_conv3d_api(
        self, batch_size, in_channels_per_group, D, H, W,
        out_channels_per_group, groups, kernel_d, kernel_h, kernel_w,
        stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, dilation, X_scale,
        X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
        use_channelwise, qengine,
    ):
        # Tests the correctness of the conv3d function.
        # Currently conv3d is only supported by the FBGEMM engine

        if qengine not in torch.backends.quantized.supported_engines:
            return

        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)

        with override_quantized_engine(qengine):
            qconv_fn = qF.conv3d
            conv_fn = F.conv3d
            self._test_conv_api_impl(
                qconv_fn, conv_fn, batch_size, in_channels_per_group,
                input_feature_map_size, out_channels_per_group, groups,
                kernel_size, stride, padding, dilation, X_scale, X_zero_point,
                W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
                use_channelwise)
    def test_compare_model_stub(self):
        r"""Compare the output of quantized conv layer and its float shadow module
        """
        def compare_and_validate_results(float_model, q_model,
                                         module_swap_list, data):
            ob_dict = compare_model_stub(float_model, q_model,
                                         module_swap_list, data, ShadowLogger)
            self.assertEqual(len(ob_dict), 1)
            for k, v in ob_dict.items():
                self.assertTrue(v["float"].shape == v["quantized"].shape)

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model_list = [
                    AnnotatedConvModel(qengine),
                    AnnotatedConvBnReLUModel(qengine),
                ]
                data = self.img_data[0][0]
                module_swap_list = [
                    nn.Conv2d, nn.intrinsic.modules.fused.ConvReLU2d
                ]
                for model in model_list:
                    model.eval()
                    if hasattr(model, "fuse_model"):
                        model.fuse_model()
                    q_model = quantize(model, default_eval_fn, self.img_data)
                    compare_and_validate_results(model, q_model,
                                                 module_swap_list, data)

                # Test adding stub to sub module
                model = ModelWithSubModules().eval()
                q_model = quantize(model, default_eval_fn, self.img_data)
                module_swap_list = [SubModule]
                ob_dict = compare_model_stub(model, q_model, module_swap_list,
                                             data, ShadowLogger)
                self.assertTrue(isinstance(q_model.mod1, Shadow))
                self.assertFalse(isinstance(q_model.conv, Shadow))
                for k, v in ob_dict.items():
                    torch.testing.assert_allclose(v["float"],
                                                  v["quantized"].dequantize())

                # Test adding stub to functionals
                model = ModelWithFunctionals().eval()
                model.qconfig = torch.quantization.get_default_qconfig(
                    "fbgemm")
                q_model = prepare(model, inplace=False)
                q_model(data)
                q_model = convert(q_model)
                module_swap_list = [nnq.FloatFunctional]
                ob_dict = compare_model_stub(model, q_model, module_swap_list,
                                             data, ShadowLogger)
                self.assertEqual(len(ob_dict), 6)
                self.assertTrue(isinstance(q_model.mycat, Shadow))
                self.assertTrue(isinstance(q_model.myadd, Shadow))
                self.assertTrue(isinstance(q_model.mymul, Shadow))
                self.assertTrue(isinstance(q_model.myadd_relu, Shadow))
                self.assertTrue(isinstance(q_model.my_scalar_add, Shadow))
                self.assertTrue(isinstance(q_model.my_scalar_mul, Shadow))
                for k, v in ob_dict.items():
                    self.assertTrue(v["float"].shape == v["quantized"].shape)
    def test_embedding_bag_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualEmbeddingBagLinear().train()
                model = prepare_qat(
                    model, mapping=get_embedding_qat_module_mappings())
                self.checkObservers(model)

                test_only_train_fn(model, self.embed_linear_data_train)
                # make sure no activation_post_process is inserted for EmbeddingBag
                self.assertFalse(hasattr(model, "activation_post_process"))
                # make sure that FakeQuant zero_points are correct dtype
                self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype,
                                 torch.float32)
                self.assertEqual(
                    model.linear.weight_fake_quant.zero_point.dtype,
                    torch.int32)
                model = convert(
                    model,
                    mapping=get_embedding_static_quant_module_mappings())

                def checkQuantized(model):
                    # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                    self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                    # Also test that Linear has been quantized.
                    self.assertEqual(type(model.linear), nnq.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualEmbeddingBagLinear()
    def test_dynamic_qat_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # Dynamic QAT without memoryless observers should fail
                with self.assertRaisesRegex(
                        ValueError,
                        "Dynamic QAT requires a memoryless observer." +
                        "This means a MovingAverage observer with averaging constant equal to 1"
                ):
                    model = ManualLinearDynamicQATModel(default_qat_qconfig)
                    model = prepare_qat(
                        model, mapping={torch.nn.Linear: nnqatd.Linear})

                model = ManualLinearDynamicQATModel()
                model = prepare_qat(model,
                                    mapping={torch.nn.Linear: nnqatd.Linear})
                self.assertEqual(type(model.fc1), nnqatd.Linear)
                self.assertEqual(type(model.fc2), nnqatd.Linear)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model, mapping={nnqatd.Linear: nnqd.Linear})
                self.assertEqual(type(model.fc1), nnqd.Linear)
                self.assertEqual(type(model.fc2), nnqd.Linear)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data)
                self.checkNoQconfig(model)
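For comparison, the non-QAT counterpart of the dynamic QAT modules above is ordinary post-training dynamic quantization, which is a single call. A minimal sketch, assuming the eager-mode torch.ao.quantization API and a throwaway two-layer model defined here purely for illustration:

import torch
import torch.nn as nn

float_model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
# Weights are quantized ahead of time; activations are quantized dynamically at runtime.
dq_model = torch.ao.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8)
out = dq_model(torch.randn(2, 5))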
 def test_weight_only_activation_only_fakequant(self):
     for qengine in supported_qengines:
         with override_quantized_engine(qengine):
             torch.manual_seed(67)
             calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
             eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
             # use a list (not a set) so the indices line up with SQNRTarget below
             qconfig_list = [
                 torch.quantization.default_weight_only_qconfig,
                 torch.quantization.default_activation_only_qconfig
             ]
             SQNRTarget = [35, 45]
             for idx, qconfig in enumerate(qconfig_list):
                 my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
                 my_model.eval()
                 out_ref = my_model(eval_data)
                 fq_model = torch.quantization.QuantWrapper(my_model)
                 fq_model.train()
                 fq_model.qconfig = qconfig
                 torch.ao.quantization.fuse_modules(
                     fq_model.module, [['conv1', 'bn1', 'relu1']],
                     inplace=True)
                 torch.quantization.prepare_qat(fq_model)
                 fq_model.eval()
                 fq_model.apply(torch.quantization.disable_fake_quant)
                 fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                 fq_model(calib_data)
                 fq_model.apply(torch.quantization.enable_fake_quant)
                 fq_model.apply(torch.quantization.disable_observer)
                 out_fq = fq_model(eval_data)
                 SQNRdB = 20 * torch.log10(
                     torch.norm(out_ref) / torch.norm(out_ref - out_fq))
                 self.assertGreater(
                     SQNRdB,
                     SQNRTarget[idx],
                     msg='Quantized model numerics diverge from float')
    def _compare_script_and_mobile(self, model: torch.nn.Module,
                                   input: torch.Tensor):
        qengine = "qnnpack"
        with override_quantized_engine(qengine):
            script_module = torch.jit.script(model)
            script_module_result = script_module(input)

            max_retry = 5
            for retry in range(1, max_retry + 1):
                # retries up to `max_retry` times; breaks on success, re-raises the exception on the final attempt
                try:
                    buffer = io.BytesIO(
                        script_module._save_to_buffer_for_lite_interpreter())
                    buffer.seek(0)
                    mobile_module = _load_for_lite_interpreter(buffer)

                    mobile_module_result = mobile_module(input)

                    torch.testing.assert_allclose(script_module_result,
                                                  mobile_module_result)
                    mobile_module_forward_result = mobile_module.forward(input)
                    torch.testing.assert_allclose(
                        script_module_result, mobile_module_forward_result)

                    mobile_module_run_method_result = mobile_module.run_method(
                        "forward", input)
                    torch.testing.assert_allclose(
                        script_module_result, mobile_module_run_method_result)
                except AssertionError as e:
                    if retry == max_retry:
                        raise e
                    else:
                        continue
                break
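The helper above round-trips the lite-interpreter model through an in-memory buffer; saving to disk works the same way. A minimal sketch, assuming a trivially scriptable module and a hypothetical file name "model.ptl":

import torch
import torch.nn as nn
from torch.jit.mobile import _load_for_lite_interpreter

model = nn.Sequential(nn.Linear(4, 2)).eval()
script_module = torch.jit.script(model)
# Save in the lite-interpreter (mobile) format, then load and run it.
script_module._save_for_lite_interpreter("model.ptl")
mobile_module = _load_for_lite_interpreter("model.ptl")
result = mobile_module(torch.randn(1, 4))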
 def test_float_quant_compare_per_tensor(self):
     for qengine in supported_qengines:
         with override_quantized_engine(qengine):
             torch.manual_seed(42)
             my_model = ModelMultipleOps().to(torch.float32)
             my_model.eval()
             calib_data = torch.rand(1024, 3, 15, 15, dtype=torch.float32)
             eval_data = torch.rand(1, 3, 15, 15, dtype=torch.float32)
             out_ref = my_model(eval_data)
             qModel = torch.quantization.QuantWrapper(my_model)
             qModel.eval()
             qModel.qconfig = torch.quantization.default_qconfig
             torch.ao.quantization.fuse_modules(qModel.module,
                                                [['conv1', 'bn1', 'relu1']],
                                                inplace=True)
             torch.quantization.prepare(qModel, inplace=True)
             qModel(calib_data)
             torch.quantization.convert(qModel, inplace=True)
             out_q = qModel(eval_data)
             SQNRdB = 20 * torch.log10(
                 torch.norm(out_ref) / torch.norm(out_ref - out_q))
             # Quantized model output should be numerically close to the floating point model output.
             # Target SQNR of 30 dB: the quantization error power is at most 1e-3 of the signal power.
             self.assertGreater(
                 SQNRdB,
                 30,
                 msg=
                 'Quantized model numerics diverge from float, expect SQNR > 30 dB'
             )
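As a quick check on the threshold used above: with SQNR defined as 20 * log10(||ref|| / ||ref - out||), 30 dB corresponds to an error norm of roughly 3% of the signal norm, i.e. an error-to-signal power ratio of 1e-3. A minimal sketch of that arithmetic:

sqnr_db = 30.0
amplitude_ratio = 10 ** (-sqnr_db / 20)  # ||err|| / ||ref|| ~= 0.0316
power_ratio = amplitude_ratio ** 2       # error power / signal power = 1e-3
print(amplitude_ratio, power_ratio)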
 def test_linear_dynamic(self):
     for i, qengine in enumerate(supported_qengines):
         with override_quantized_engine(qengine):
             module_qint8 = nnqd.Linear(3, 1, bias_=True, dtype=torch.qint8)
             self._test_op(module_qint8, "qint8", input_size=[1, 3], input_quantized=False, generate=False, iter=i)
             if qengine == 'fbgemm':
                 module_float16 = nnqd.Linear(3, 1, bias_=True, dtype=torch.float16)
                 self._test_op(module_float16, "float16", input_size=[1, 3], input_quantized=False, generate=False)
    def _create_quantized_model(self, model_class: Type[torch.nn.Module],
                                **kwargs):
        qengine = "qnnpack"
        with override_quantized_engine(qengine):
            qconfig = torch.quantization.get_default_qconfig(qengine)
            model = model_class(**kwargs)
            model = quantize(model, test_only_eval_fn, [self.calib_data])

        return model
    def test_train_save_load_eval(self):
        r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict
        During eval, we first call prepare_qat and conver on the model and then load the state_dict
        and compare results against original model
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(
                    qengine)
                model = prepare_qat(model)

                fq_state_dict = model.state_dict()

                test_only_train_fn(model, self.train_data)
                model = convert(model)

                quant_state_dict = model.state_dict()

                x = torch.rand(2, 5, dtype=torch.float)
                ref = model(x)

                # Create model again for eval. Check result using quantized state_dict
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(
                    qengine)
                torch.ao.quantization.prepare_qat(model, inplace=True)
                new_state_dict = model.state_dict()

                # Check to make sure the model after prepare_qat has the same state_dict as original.
                self.assertEqual(set(fq_state_dict.keys()),
                                 set(new_state_dict.keys()))

                torch.ao.quantization.convert(model, inplace=True)
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

                # Check model created using prepare has same state dict as quantized state_dict
                model = TwoLayerLinearModel()
                model.eval()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(
                    qengine)
                torch.ao.quantization.prepare(model, inplace=True)
                torch.ao.quantization.convert(model, inplace=True)
                self.assertEqual(set(model.state_dict().keys()),
                                 set(quant_state_dict.keys()))
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)
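In practice, the quantized state_dict exercised above would be persisted between the training and eval processes. A minimal sketch of that hand-off, reusing TwoLayerLinearModel and the converted model from the test above, with 'fbgemm' as an example engine and a hypothetical checkpoint path:

import torch

# Training process: after prepare_qat + training + convert, save the quantized weights.
torch.save(model.state_dict(), "qat_quantized.pth")

# Eval process: rebuild the same quantized module structure, then load the weights.
eval_model = torch.ao.quantization.QuantWrapper(TwoLayerLinearModel())
eval_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
torch.ao.quantization.prepare_qat(eval_model, inplace=True)
torch.ao.quantization.convert(eval_model, inplace=True)
eval_model.eval()
eval_model.load_state_dict(torch.load("qat_quantized.pth"))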
    def test_compare_model_outputs(self):
        r"""Compare the output of conv layer in quantized model and corresponding
        output of conv layer in float model
        """
        def compare_and_validate_results(float_model, q_model, data):
            act_compare_dict = compare_model_outputs(float_model, q_model,
                                                     data)
            self.assertEqual(len(act_compare_dict), 2)
            expected_act_compare_dict_keys = {"conv.stats", "quant.stats"}
            self.assertTrue(
                act_compare_dict.keys() == expected_act_compare_dict_keys)
            for k, v in act_compare_dict.items():
                self.assertTrue(v["float"].shape == v["quantized"].shape)

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model_list = [
                    AnnotatedConvModel(qengine),
                    AnnotatedConvBnReLUModel(qengine),
                ]
                data = self.img_data[0][0]
                module_swap_list = [
                    nn.Conv2d, nn.intrinsic.modules.fused.ConvReLU2d
                ]
                for model in model_list:
                    model.eval()
                    if hasattr(model, "fuse_model"):
                        model.fuse_model()
                    q_model = quantize(model, default_eval_fn, self.img_data)
                    compare_and_validate_results(model, q_model, data)

                # Test functionals
                model = ModelWithFunctionals().eval()
                model.qconfig = torch.quantization.get_default_qconfig(
                    "fbgemm")
                q_model = prepare(model, inplace=False)
                q_model(data)
                q_model = convert(q_model)
                act_compare_dict = compare_model_outputs(model, q_model, data)
                self.assertEqual(len(act_compare_dict), 7)
                expected_act_compare_dict_keys = {
                    "mycat.stats",
                    "myadd.stats",
                    "mymul.stats",
                    "myadd_relu.stats",
                    "my_scalar_add.stats",
                    "my_scalar_mul.stats",
                    "quant.stats",
                }
                self.assertTrue(
                    act_compare_dict.keys() == expected_act_compare_dict_keys)
                for k, v in act_compare_dict.items():
                    self.assertTrue(v["float"].shape == v["quantized"].shape)
    def test_module_with_shared_type_instances(self):
        class Child(nn.Module):
            def __init__(self):
                super(Child, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self):
                super(Parent, self).__init__()
                self.quant = torch.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)
                self.child = Child()
                self.child2 = Child()
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.child2(x)
                x = self.dequant(x)
                return x

        def _static_quant(model):
            qModel = torch.quantization.QuantWrapper(model)
            qModel.qconfig = torch.quantization.default_qconfig
            torch.quantization.prepare(qModel, inplace=True)
            qModel(torch.rand(4, 1, 4, 4, dtype=torch.float32))
            torch.quantization.convert(qModel, inplace=True)
            return model

        with override_quantized_engine('fbgemm'):
            data = torch.randn(4, 1, 4, 4, dtype=torch.float32)
            m = Parent().to(torch.float32)
            m = _static_quant(m)
            m = torch.jit.script(m)
            m.eval()
            torch._C._jit_pass_inline(m.graph)
            m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c))
            # Earlier bug resulted in _packed_params set to false.
            FileCheck().check_not('_packed_params = False').run(
                m_frozen._c.dump_to_str(True, True, False))

            m_res = m(data)
            # It used to segfault while running frozen module.
            m_frozen_res = m_frozen(data)
            self.assertEqual(m_res, m_frozen_res)
    def test_eval_only_fake_quant(self):
        r"""Using FakeQuant in evaluation only mode,
        this is useful for estimating accuracy loss when we quantize the
        network
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualLinearQATModel(qengine)

                model = prepare_qat(model)
                self.checkObservers(model)

                model.eval()
                test_only_eval_fn(model, self.calib_data)
 def test_fake_quant_true_quant_compare(self):
     for qengine in ["fbgemm", "qnnpack"]:
         if qengine not in torch.backends.quantized.supported_engines:
             continue
         if qengine == 'qnnpack':
             if IS_PPC or TEST_WITH_UBSAN:
                 continue
         with override_quantized_engine(qengine):
             torch.manual_seed(67)
             my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
             calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
             eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
             my_model.eval()
             out_ref = my_model(eval_data)
             fq_model = torch.quantization.QuantWrapper(my_model)
             fq_model.train()
             fq_model.qconfig = torch.quantization.default_qat_qconfig
             torch.quantization.fuse_modules(fq_model.module,
                                             [['conv1', 'bn1', 'relu1']],
                                             inplace=True)
             torch.quantization.prepare_qat(fq_model)
             fq_model.eval()
             fq_model.apply(torch.quantization.disable_fake_quant)
             fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
             fq_model(calib_data)
             fq_model.apply(torch.quantization.enable_fake_quant)
             fq_model.apply(torch.quantization.disable_observer)
             out_fq = fq_model(eval_data)
             SQNRdB = 20 * torch.log10(
                 torch.norm(out_ref) / torch.norm(out_ref - out_fq))
             # Quantized model output should be numerically close to the floating point model output.
             # Target SQNR is 35 dB.
             self.assertGreater(
                 SQNRdB,
                 35,
                 msg=
                 'Quantized model numerics diverge from float, expect SQNR > 35 dB'
             )
             torch.quantization.convert(fq_model)
             out_q = fq_model(eval_data)
             SQNRdB = 20 * torch.log10(
                 torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
             self.assertGreater(
                 SQNRdB,
                 60,
                 msg=
                 'Fake quant and true quant numerics diverge, expect SQNR > 60 dB'
             )
    def test_conv2d_api(
        self, batch_size, in_channels_per_group, H, W, out_channels_per_group,
        groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation,
        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
        use_bias, use_fused, use_channelwise, qengine,
    ):
        # Tests the correctness of the conv2d module.
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            use_channelwise = False

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (H, W)
        kernel_size = (kernel_h, kernel_w)
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilation = (dilation, dilation)

        with override_quantized_engine(qengine):
            if use_fused:
                module_name = "QuantizedConvReLU2d"
                qconv_module = nnq_fused.ConvReLU2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode="zeros")
            else:
                module_name = "QuantizedConv2d"
                qconv_module = nnq.Conv2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode="zeros")

            conv_module = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode="zeros")
            if use_fused:
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU2d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
                Y_zero_point, use_bias, use_fused, use_channelwise)
    def test_fusion_conv_with_bias(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model_orig = ModelForFusionWithBias().train()

                # reference model
                model_ref = copy.deepcopy(model_orig)
                # output with no fusion.
                out_ref = model_ref(self.img_data_2d[0][0])

                # fused model
                model_orig.qconfig = QConfig(activation=torch.nn.Identity,
                                             weight=torch.nn.Identity)
                model = fuse_modules_qat(
                    model_orig,
                    [["conv1", "bn1", "relu1"],
                     ["conv2", "bn2"]])
                prep_model = prepare_qat(model, inplace=False)
                # output with fusion but no observers.
                out_fused = prep_model(self.img_data_2d[0][0])

                self.assertEqual(out_ref, out_fused)

                def checkBN(bn_ref, bn):
                    self.assertEqual(bn_ref.weight, bn.weight)
                    self.assertEqual(bn_ref.bias, bn.bias)
                    self.assertEqual(bn_ref.running_mean, bn.running_mean)
                    self.assertEqual(bn_ref.running_var, bn.running_var)

                checkBN(model_ref.bn1, prep_model.conv1.bn)
                checkBN(model_ref.bn2, prep_model.conv2.bn)

                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare_qat(model, inplace=True)

                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                    self.assertEqual(type(model.bn1), nn.Identity)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                    self.assertEqual(type(model.bn2), nn.Identity)

                checkQAT(model)
    def test_record_observer(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = AnnotatedSingleLayerLinearModel()
                model.qconfig = default_debug_qconfig
                model = prepare(model)
                # run the evaluation and dump all tensors
                test_only_eval_fn(model, self.calib_data)
                test_only_eval_fn(model, self.calib_data)
                observer_dict = {}
                get_observer_dict(model, observer_dict)

                self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(),
                                'observer is not recorded in the dict')
                self.assertEqual(len(observer_dict['fc1.module.activation_post_process'].get_tensor_value()),
                                 2 * len(self.calib_data))
                self.assertEqual(observer_dict['fc1.module.activation_post_process'].get_tensor_value()[0],
                                 model(self.calib_data[0][0]))
    def test_qat_data_parallel(self):
        """
        Tests that doing QAT in nn.DataParallel does not crash.
        """
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            return
        with override_quantized_engine('fbgemm'):
            device = torch.device('cuda')

            model = nn.Sequential(
                torch.quantization.QuantStub(),
                nn.Conv2d(3, 1, 1, bias=False),
                nn.BatchNorm2d(1),
                nn.ReLU(),
                nn.Conv2d(1, 2, 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(2),
                nn.AvgPool2d(14),
                nn.Sigmoid(),
                torch.quantization.DeQuantStub(),
            )

            torch.quantization.fuse_modules(model,
                                            [['1', '2', '3'], ['4', '5']],
                                            inplace=True)

            model.qconfig = torch.quantization.get_default_qat_qconfig(
                'fbgemm')
            torch.quantization.prepare_qat(model, inplace=True)
            model = nn.DataParallel(model, device_ids=[0, 1])
            model.to(device)
            model.train()

            for epoch in range(3):
                inputs = torch.rand(2, 3, 28, 28).to(device)
                model(inputs)
                if epoch >= 1:
                    model.apply(torch.quantization.disable_observer)
                if epoch >= 2:
                    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                quant_model = copy.deepcopy(model.module)
                quant_model = torch.quantization.convert(
                    quant_model.eval().cpu(), inplace=False)
                with torch.no_grad():
                    out = quant_model(torch.rand(1, 3, 28, 28))
    def test_quantized_conv_no_asan_failures(self):
        # There were ASAN failures when fold_conv_bn was run on
        # already quantized conv modules. Verifying that this does
        # not happen again.

        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Child(nn.Module):
            def __init__(self):
                super(Child, self).__init__()
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv2(x)
                return x

        class Parent(nn.Module):
            def __init__(self):
                super(Parent, self).__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
            model = Parent()
            model.qconfig = torch.ao.quantization.get_default_qconfig(
                'qnnpack')
            torch.ao.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)
            model = torch.jit.script(model)
            # this line should not have ASAN failures
            model_optim = optimize_for_mobile(model)
    def test_conv3d_api(
        self, batch_size, in_channels_per_group, D, H, W,
        out_channels_per_group, groups, kernel_d, kernel_h, kernel_w,
        stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, pad_mode, dilation,
        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
        use_bias, use_channelwise, use_fused,
    ):
        # Tests the correctness of the conv3d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)
        with override_quantized_engine('fbgemm'):
            if use_fused:
                module_name = "QuantizedConvReLU3d"
                qconv_module = nnq_fused.ConvReLU3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
            else:
                module_name = "QuantizedConv3d"
                qconv_module = nnq.Conv3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)

            conv_module = nn.Conv3d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
            if use_fused:
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU3d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                pad_mode, dilation, X_scale, X_zero_point, W_scale,
                W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
                use_channelwise)
    def test_dropout(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualDropoutQATModel(qengine)
                model = prepare_qat(model)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.dropout), nnq.Dropout)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = quantize_qat(ManualDropoutQATModel(qengine),
                                     test_only_train_fn, [self.train_data])
                checkQuantized(model)
    def test_compare_weights(self):
        r"""Compare the weights of float and quantized conv layer
        """
        def compare_and_validate_results(float_model, q_model):
            weight_dict = compare_weights(float_model.state_dict(),
                                          q_model.state_dict())
            self.assertEqual(len(weight_dict), 1)
            for k, v in weight_dict.items():
                self.assertTrue(v["float"].shape == v["quantized"].shape)

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model_list = [
                    AnnotatedConvModel(qengine),
                    AnnotatedConvBnReLUModel(qengine),
                ]
                for model in model_list:
                    model.eval()
                    if hasattr(model, "fuse_model"):
                        model.fuse_model()
                    q_model = quantize(model, default_eval_fn, self.img_data)
                    compare_and_validate_results(model, q_model)
    def test_hoist_conv_packed_params(self):

        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Standalone(nn.Module):
            def __init__(self):
                super(Standalone, self).__init__()
                self.quant = torch.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)
                self.relu = nn.ReLU()
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                torch.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)
                pass

        class Child(nn.Module):
            def __init__(self):
                super(Child, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self):
                super(Parent, self).__init__()
                self.quant = torch.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                # TODO: test nn.Sequential after #42039 is fixed
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                pass

        with override_quantized_engine('qnnpack'):
            def _quant_script_and_optimize(model):
                model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
                model.fuse_model()
                torch.quantization.prepare(model, inplace=True)
                model(torch.randn(4, 1, 4, 4))
                torch.quantization.convert(model, inplace=True)
                model = torch.jit.script(model)
                model_optim = optimize_for_mobile(model)
                return model, model_optim

            # basic case

            m, m_optim = _quant_script_and_optimize(Standalone())
            FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "conv2"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

            # generic case

            m, m_optim = _quant_script_and_optimize(Parent())
            FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "child"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3)