Example #1
0
    def test_convert_without_squash_mask(self):
        """Sparsify and quantize without squashing masks; sparsity must survive convert.

        The sparsifier is prepared before quantization prepare, stepped once,
        the model is calibrated on one random batch and then converted. The
        converted ``mod[5]`` must be a quantized Linear whose weight (taken from
        ``_weight_bias()``) is at least as sparse as the float weight was right
        after ``sparsifier.step()``, and at least as sparse as configured.
        """
        qconfig = tq.get_default_qconfig("fbgemm")
        mod, sparsifier, sparse_config = _get_model_and_sparsifier_and_sparse_config(qconfig)

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # Parametrizations must exist on the sparsified modules and must not
        # have been dropped by the quantization prepare step.
        for idx in (0, 5):
            self.assertTrue(hasattr(mod[idx], "parametrizations"))

        # An observer on mod[5] means quantization matching succeeded.
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        sparsifier.step()
        expected_sparsity = _calculate_sparsity(mod[5].weight)
        mod(torch.randn(1, 4, 4, 4))  # calibration pass
        tq.convert(mod, inplace=True)

        # The converted layer should be quantized and the model should still run.
        self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
        out_shape = mod(torch.randn(1, 4, 4, 4)).shape
        self.assertEqual(out_shape, torch.Size([1, 4, 4, 4]))

        # Sparsity must be preserved through conversion and meet the target level.
        converted_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        target_sparsity = sparse_config[0]["sparsity_level"]
        self.assertGreaterAlmostEqual(converted_sparsity, expected_sparsity)
        self.assertGreaterAlmostEqual(expected_sparsity, target_sparsity)
        self.assertGreaterAlmostEqual(converted_sparsity, target_sparsity)
Example #2
0
    def test_s_prep_before_q_prep(self):
        """Run sparsifier.prepare before tq.prepare and check the model still converts.

        After both prepare steps the sparsity parametrizations and the observer
        must coexist. The helper ``_squash_mask_calibrate_and_convert`` (defined
        elsewhere) is then applied, and the resulting ``mod[5]`` must be a
        quantized Linear and the model must still run.
        """
        qconfig = tq.get_default_qconfig("fbgemm")
        mod, sparsifier, sparse_config = _get_model_and_sparsifier_and_sparse_config(qconfig)

        # Sparsity prepare first, quantization prepare second.
        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # Quantization prepare must not have stripped the sparsity parametrizations...
        for idx in (0, 5):
            self.assertTrue(hasattr(mod[idx], "parametrizations"))
        # ...and it must still have matched mod[5] and attached an observer.
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        calibration_input = torch.randn(1, 4, 4, 4)
        _squash_mask_calibrate_and_convert(mod, sparsifier, calibration_input)

        # The final module is quantized and the model runs end to end.
        self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
        result = mod(torch.randn(1, 4, 4, 4))
        self.assertEqual(result.shape, torch.Size([1, 4, 4, 4]))
Example #3
0
def prepare_model_outputs(
    float_module: nn.Module,
    q_module: nn.Module,
    logger_cls=OutputLogger,
    allow_list=None
) -> None:
    r"""Attach ``logger_cls`` to the float module and to the quantized module.

    Each module is given a debug ``QConfig`` whose activation entry is
    ``logger_cls`` and whose weight entry is ``None``. Each module is then
    prepared in place, restricted to submodule types in ``allow_list``. The
    quantized module is additionally prepared with
    ``observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST``.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger; when ``None``,
            ``get_default_compare_output_module_list()`` is used

    Both modules are modified in place (including their ``qconfig``
    attribute); nothing is returned.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
    module_types = (
        allow_list if allow_list is not None
        else get_default_compare_output_module_list()
    )
    debug_qconfig = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)

    # Only the quantized side is prepared with the non-leaf module allow list.
    targets = (
        (float_module, {}),
        (q_module, {"observer_non_leaf_module_list": NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST}),
    )
    for module, extra_kwargs in targets:
        module.qconfig = debug_qconfig  # type: ignore[assignment]
        prepare(module, inplace=True, allow_list=module_types, **extra_kwargs)
Example #4
0
 def test_fusion_sequential_model_eval(self):
     """Fuse modules nested inside ``nn.Sequential`` containers, then quantize.

     For every supported engine, the test fuses:

     * ``conv1`` + ``relu1``,
     * each three-module group inside ``features``,
     * the two-module group in ``classifier``.

     It checks the resulting fused/Identity module types, then prepares the
     model, calibrates it on one batch, converts it, runs a second batch and
     checks the quantized structure.
     """
     for qengine in supported_qengines:
         with override_quantized_engine(qengine):
             model = ModelWithSequentialFusion().eval()
             model.to(torch.float)
             fuse_modules(model, [['conv1', 'relu1'],
                                  ['features.0.0', 'features.0.1', 'features.0.2'],
                                  ['features.1.0', 'features.1.1', 'features.1.2'],
                                  ['features.2.0', 'features.2.1', 'features.2.2'],
                                  ['classifier.0', 'classifier.1']], inplace=True)
             self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                              msg="Fused Conv + Relu: nni.ConvReLU2d")
             self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                              msg="Fused Conv + Relu: Conv2d")
             self.assertEqual(type(model.conv1[1]), nn.ReLU,
                              msg="Fused Conv + Relu: Relu")
             self.assertEqual(type(model.relu1), nn.Identity,
                              msg="Fused Conv + Relu: Identity")
             # Each features[i] group is assumed to be (Conv2d, BatchNorm2d, ReLU)
             # -- TODO confirm against ModelWithSequentialFusion. The fused
             # module lands in slot 0 and slots 1 and 2 become Identity.
             for i in range(3):
                 self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
                                  msg="Fused submodule Conv + BN + Relu (BN is folded)")
                 self.assertEqual(type(model.features[i][1]), nn.Identity,
                                  msg="Fused submodule (skipped BN)")
                 self.assertEqual(type(model.features[i][2]), nn.Identity,
                                  msg="Fused submodule (skipped Relu)")
             self.assertEqual(type(model.classifier[0]), nni.LinearReLU,
                              msg="Fused Linear + Relu: nni.LinearReLU")
             self.assertEqual(type(model.classifier[1]), nn.Identity,
                              msg="Fused Linear + Relu: Identity")
             model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
             prepare(model, inplace=True)
             self.checkObservers(model)
             model(self.img_data_2d[0][0])  # calibration
             convert(model, inplace=True)
             model(self.img_data_2d[1][0])  # quantized model must still run
             self.checkModelWithSequentialQuantized(model)
    def test_batchnorm_relu_basic(self):
        """
        Basic test of the PyTorch 3D batchnorm RELU Node on Glow.

        The model is a BatchNorm3d + ReLU pair between quant/dequant stubs.
        It is fused, prepared with ``my_qconfig``, calibrated on one batch and
        converted. It is then checked with ``utils.compare_tracing_methods``,
        expecting the fusible op ``quantized::batch_norm3d_relu``.
        """

        class SimpleQuantizedBatchNormRelu(nn.Module):
            def __init__(self, w, b, m, v):
                super().__init__()
                self.bn = torch.nn.BatchNorm3d(4)
                self.relu = torch.nn.ReLU()
                self.bn.weight = torch.nn.Parameter(w)
                self.bn.bias = torch.nn.Parameter(b)
                self.bn.running_mean = m
                self.bn.running_var = v
                self.q = QuantStub()
                self.dq = DeQuantStub()

            def forward(self, x):
                qx = self.q(x)
                qy = self.bn(qx)
                qy_relu = self.relu(qy)
                y = self.dq(qy_relu)
                return y

        C = 4
        # Near-identity affine params so the quantized op stays well-conditioned.
        weight = torch.ones(C) + torch.rand(C) * 0.001
        bias = torch.rand(C) * 0.0001
        running_mean = torch.zeros(C)
        running_var = torch.ones(C)

        inputs = torch.randn((10, C, 2, 3, 4), requires_grad=False)
        model = SimpleQuantizedBatchNormRelu(weight, bias, running_mean, running_var)
        model.eval()
        model.qconfig = my_qconfig
        modules_to_fuse = [["bn", "relu"]]
        fuse_modules(model, modules_to_fuse, inplace=True)
        prepare(model, inplace=True)
        # Calibrate through the module call rather than `.forward` so that any
        # registered forward hooks on the root module are run as well.
        model(inputs)
        convert(model, inplace=True)

        # Quantization differs between PyTorch and Glow, so the tolerances are
        # loose. Batchnorm introduces large accuracy differences, up to ~1e-2 in
        # some rare cases. To keep the test from being flaky, atol is set to 0.1
        # and rtol to 1e-5.
        utils.compare_tracing_methods(
            model,
            inputs,
            fusible_ops={"quantized::batch_norm3d_relu"},
            atol=1e-1,
            rtol=1e-5,
            skip_to_glow=True,
        )
Example #6
0
    def test_fusion_before_s_prep(self):
        """Fuse ``mod[5]``/``mod[6]`` before sparsifier.prepare; sparsity must survive convert.

        After fusion the weight to sparsify lives at ``mod[5][0]``, so the sparse
        config addresses it as ``5.0.weight``. The flow is: sparsifier prepare,
        quantization prepare, one sparsifier step, calibration, then convert.
        Afterwards ``mod[5]`` must be a quantized ``LinearReLU``, and its weight
        must keep at least the sparsity reached before conversion and at least
        the configured level.
        """
        mod, sparsifier, _ = self._get_model_and_sparsifier_and_sparse_config()
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)

        # Fusion breaks the default fqn for this layer, but it still works if
        # the fused fqn is spelled out explicitly.
        linear_config = {
            "tensor_fqn": "5.0.weight",
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        }
        sparse_config = [linear_config, {"tensor_fqn": "0.weight"}]

        sparsifier.prepare(mod, config=sparse_config)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # Parametrizations: directly on mod[0], and on the module nested in fused mod[5].
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))
        # An observer on the fused module means quantization matching succeeded.
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        sparsifier.step()
        pre_convert_sparsity = self._calculate_sparsity(mod[5][0].weight)
        mod(torch.randn(1, 4, 4, 4))  # calibration pass
        tq.convert(mod, inplace=True)

        # The fused layer must be quantized and the model must still run.
        self.assertTrue(
            isinstance(mod[5], torch.nn.intrinsic.quantized.LinearReLU))
        output = mod(torch.randn(1, 4, 4, 4))
        self.assertEqual(output.shape, torch.Size([1, 4, 4, 4]))

        # Sparsity must be preserved through conversion and meet the target level.
        post_convert_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
        target = linear_config["sparsity_level"]
        self.assertGreaterAlmostEqual(post_convert_sparsity, pre_convert_sparsity)
        self.assertGreaterAlmostEqual(pre_convert_sparsity, target)
        self.assertGreaterAlmostEqual(post_convert_sparsity, target)
Example #7
0
    def test_compare_model_outputs_functional_static(self):
        r"""Compare the outputs of the functional layers in a static quantized model
        with the corresponding outputs of the float model.

        ``ModelWithFunctionals`` is quantized with the fbgemm default qconfig,
        calibrated on one batch and converted. ``compare_model_outputs`` must
        then report exactly the seven expected ``*.stats`` entries. For each
        entry, the float and quantized recordings must agree in count and in
        per-item shape.
        """
        model = ModelWithFunctionals().eval()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        q_model = prepare(model, inplace=False)
        q_model(self.img_data_2d[0][0])  # calibration
        q_model = convert(q_model)
        act_compare_dict = compare_model_outputs(model, q_model,
                                                 self.img_data_2d[0][0])
        self.assertEqual(len(act_compare_dict), 7)
        expected_act_compare_dict_keys = {
            "mycat.stats",
            "myadd.stats",
            "mymul.stats",
            "myadd_relu.stats",
            "my_scalar_add.stats",
            "my_scalar_mul.stats",
            "quant.stats",
        }
        # assertSetEqual reports the differing keys instead of a bare False.
        self.assertSetEqual(set(act_compare_dict), expected_act_compare_dict_keys)
        for stats in act_compare_dict.values():
            self.assertEqual(len(stats["float"]), len(stats["quantized"]))
            for float_out, quant_out in zip(stats["float"], stats["quantized"]):
                self.assertEqual(float_out.shape, quant_out.shape)
    def test_compare_model_stub_functional_static(self):
        r"""Compare the output of static quantized functional layer and its float shadow module

        After converting ``ModelWithFunctionals`` with the fbgemm default
        qconfig, ``compare_model_stub`` swaps ``nnq.FloatFunctional`` modules for
        shadows. It must report six entries, and each of the six functional
        attributes checked below must have been replaced by a ``Shadow``. For
        each entry, float and quantized recordings must agree in count and
        per-item shape.
        """
        model = ModelWithFunctionals().eval()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        q_model = prepare(model, inplace=False)
        q_model(self.img_data_2d[0][0])  # calibration
        q_model = convert(q_model)
        module_swap_list = [nnq.FloatFunctional]
        ob_dict = compare_model_stub(
            model, q_model, module_swap_list, self.img_data_2d[0][0]
        )
        self.assertEqual(len(ob_dict), 6)
        for name in ("mycat", "myadd", "mymul", "myadd_relu",
                     "my_scalar_add", "my_scalar_mul"):
            self.assertIsInstance(getattr(q_model, name), Shadow)
        for stats in ob_dict.values():
            self.assertEqual(len(stats["float"]), len(stats["quantized"]))
            for float_out, quant_out in zip(stats["float"], stats["quantized"]):
                self.assertEqual(float_out.shape, quant_out.shape)
    def test_fixed_qparam_ops(self):
        """Fixed-qparams ops get ``FixedQParamsFakeQuantize`` under QAT and nothing in eval.

        In QAT mode, each of sigmoid/hardsigmoid/tanh must be given a
        ``FixedQParamsFakeQuantize`` activation post-process. The converted
        model must produce the same output as before conversion, and
        conversion must remove both the post-process attribute and every
        forward hook. With ``default_qconfig`` in eval mode, no post-process may
        be present either after prepare or after convert.
        """
        op_names = ['sigmoid', 'hardsigmoid', 'tanh']

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.tanh = torch.nn.Tanh()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                return self.dequant(
                    self.tanh(self.hardsigmoid(self.sigmoid(self.quant(x)))))

        def checkNoFQModule(m):
            # No activation_post_process and no leftover forward hook on any op.
            for name in op_names:
                submodule = getattr(m, name)
                self.assertFalse(
                    hasattr(submodule, "activation_post_process"))
                self.assertTrue(
                    len(submodule._forward_hooks.items()) == 0)

        # --- QAT flow: fake-quant is inserted, then removed by convert. ---
        qat_model = M().train()
        qat_model.qconfig = default_qat_qconfig
        qat_model = prepare_qat(qat_model)
        for name in op_names:
            self.assertEqual(type(getattr(qat_model, name).activation_post_process),
                             FixedQParamsFakeQuantize)
        sample = torch.randn(1, 3, 2, 4)
        before_convert = qat_model(sample)
        qat_model = convert(qat_model)
        after_convert = qat_model(sample)
        self.assertEqual(before_convert, after_convert)
        checkNoFQModule(qat_model)

        # --- Eval flow: no fake-quant module is ever inserted. ---
        eval_model = M().eval()
        eval_model.qconfig = default_qconfig
        eval_model = prepare(eval_model)
        checkNoFQModule(eval_model)
        eval_model = convert(eval_model)
        checkNoFQModule(eval_model)
Example #10
0
    def test_q_prep_before_s_prep(self):
        """Run tq.prepare before sparsifier.prepare and check the model still converts.

        With quantization prepared first, the sparsifier must still add its
        parametrizations while the observer stays in place. After
        ``self._squash_mask_calibrate_and_convert``, ``mod[5]`` must be a
        quantized Linear and the model must run.
        """
        mod, sparsifier, sparse_config = self._get_model_and_sparsifier_and_sparse_config()

        # Quantization prepare first, sparsity prepare second.
        tq.prepare(mod, inplace=True)
        sparsifier.prepare(mod, config=sparse_config)

        # Parametrizations added by the sparsifier...
        for idx in (0, 5):
            self.assertTrue(hasattr(mod[idx], "parametrizations"))
        # ...alongside the observer inserted by quantization prepare.
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        calibration_input = torch.randn(1, 4, 4, 4)
        self._squash_mask_calibrate_and_convert(mod, sparsifier, calibration_input)

        # The final module is quantized and the model runs end to end.
        self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
        result = mod(torch.randn(1, 4, 4, 4))
        self.assertEqual(result.shape, torch.Size([1, 4, 4, 4]))
Example #11
0
    def test_record_observer(self):
        """The debug observer on fc1 records one tensor per evaluated sample.

        For every supported engine, the model is prepared with
        ``default_debug_qconfig`` and evaluated twice over ``self.calib_data``.
        The fc1 observer must then hold ``2 * len(self.calib_data)`` tensors.
        The first of them must equal the model output on the first calibration
        sample.
        """
        observer_key = 'fc1.module.activation_post_process'
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = AnnotatedSingleLayerLinearModel()
                model.qconfig = default_debug_qconfig
                model = prepare(model)

                # Run the evaluation twice and dump all tensors.
                for _ in range(2):
                    test_only_eval_fn(model, self.calib_data)

                observer_dict = {}
                get_observer_dict(model, observer_dict)

                self.assertTrue(observer_key in observer_dict,
                                'observer is not recorded in the dict')
                recorded = observer_dict[observer_key].get_tensor_value()
                self.assertEqual(len(recorded), 2 * len(self.calib_data))
                self.assertEqual(recorded[0], model(self.calib_data[0][0]))
Example #12
0
    def test_fuse_module_eval(self):
        """Fuse conv/bn/relu patterns in eval mode and verify the quantized result.

        The fusion groups are:

        * conv1 + bn1 + relu1 -> ``nni.ConvReLU2d``,
        * conv2 + relu2 -> ``nni.ConvReLU3d``,
        * bn2 + relu3 -> ``nni.BNReLU3d``,
        * conv3 + bn3 + relu4 -> ``nni.ConvReLU1d``,
        * sub1.conv + sub1.bn -> ``nn.Conv2d``.

        sub2 is left unfused. The fused model is checked structurally, then
        prepared, calibrated, converted and checked again. A second model, fused
        with the same groups listed in a different order, is quantized via
        ``quantize`` and must pass the same quantized checks.
        """
        model = ModelForFusion(default_qconfig)
        model.eval()
        model = fuse_modules(
            model,
            [['conv3', 'bn3', 'relu4'], ['conv1', 'bn1', 'relu1'],
             ['conv2', 'relu2'], ['bn2', 'relu3'], ['sub1.conv', 'sub1.bn']])

        # conv1 + bn1 + relu1: BN folded into the conv, result in conv1.
        self.assertEqual(
            type(model.conv1),
            nni.ConvReLU2d,
            msg="Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.conv1[0]),
                         nn.Conv2d,
                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv1[1]),
                         nn.ReLU,
                         msg="Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(
            type(model.bn1),
            nn.Identity,
            msg="Fused Conv + BN + Relu second layer (Skipped BN)")
        self.assertEqual(
            type(model.relu1),
            nn.Identity,
            msg="Fused Conv + BN + Relu second layer (Skipped Relu)")

        # conv2 + relu2: no BN in this group.
        self.assertEqual(
            type(model.conv2),
            nni.ConvReLU3d,
            msg="Fused Conv + Relu first layer")
        self.assertEqual(type(model.conv2[0]),
                         nn.Conv3d,
                         msg="Fused Conv + Relu (Conv only)")
        self.assertEqual(type(model.conv2[1]),
                         nn.ReLU,
                         msg="Fused Conv + Relu second layer (Relu only)")
        self.assertEqual(
            type(model.relu2),
            nn.Identity,
            msg="Fused Conv + Relu second layer (Skipped Relu)")

        # bn2 + relu3.
        self.assertEqual(type(model.bn2),
                         nni.BNReLU3d,
                         msg="Fused BN + Relu first layer (Relu is folded)")
        self.assertEqual(type(model.relu3),
                         nn.Identity,
                         msg="Fused BN + Relu second layer (Skipped Relu)")

        # conv3 + bn3 + relu4: BN folded into the Conv1d.
        self.assertEqual(type(model.conv3),
                         nni.ConvReLU1d,
                         msg="Fused Conv + BN + Relu for Conv1d (BN is folded)")
        self.assertEqual(type(model.conv3[0]),
                         nn.Conv1d,
                         msg="Fused Conv + BN + Relu for Conv1d (Conv + folded BN only)")
        self.assertEqual(type(model.conv3[1]),
                         nn.ReLU,
                         msg="Fused Conv + BN + Relu for Conv1d (Relu only)")
        self.assertEqual(type(model.bn3),
                         nn.Identity,
                         msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")
        self.assertEqual(type(model.relu4),
                         nn.Identity,
                         msg="Fused Conv + BN + Relu for Conv1d (Skipped Relu)")

        # sub1 is fused (conv + folded BN); sub2 is untouched.
        self.assertEqual(type(model.sub1.conv),
                         nn.Conv2d,
                         msg="Fused submodule Conv + folded BN")
        self.assertEqual(type(model.sub1.bn),
                         nn.Identity,
                         msg="Fused submodule (skipped BN)")
        self.assertEqual(type(model.sub2.conv),
                         nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu),
                         torch.nn.ReLU,
                         msg="Non-fused submodule ReLU")

        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data_1d)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            self.assertEqual(type(model.bn2), nniq.BNReLU3d)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        checkQuantized(model)

        # Same fusion groups in a different order, quantized in one call.
        model = ModelForFusion(default_qconfig).eval()
        model = fuse_modules(
            model,
            [['conv1', 'bn1', 'relu1'], ['conv2', 'relu2'], ['bn2', 'relu3'],
             ['sub1.conv', 'sub1.bn'], ['conv3', 'bn3', 'relu4']])
        model = quantize(model, test_only_eval_fn, [self.img_data_1d])
        checkQuantized(model)
Example #13
0
def _sparse_layer_test_helper(
    model_class,
    sparse_mapping,
    ref_mapping,
    qconfig_dict,
    fqn_to_check,
    test_class,
    test_scripting,
):
    """Check that a sparse-quantized model matches a reference quantized model.

    Steps:

    1. Build ``model_class(4, 7)`` and install a randomly masked weight on
       ``model.linear``. Each element is kept or zeroed, then round-tripped
       through int8 quantization with scale 1e-2 and zero point 0.
    2. Deep-copy the model into a reference model (``qmodel``) and a sparse
       model (``sqmodel``).
    3. Propagate ``qconfig_dict`` to both copies, prepare them, and calibrate
       both on the same input.
    4. Convert ``qmodel`` with ``ref_mapping`` and ``sqmodel`` with
       ``sparse_mapping``.

    It then asserts that:

    * weight qparams computed from both qconfigs agree,
    * the module at ``fqn_to_check`` was converted to the class each mapping
      prescribes for its pre-convert class,
    * the sparse model's packed linear params report a (1, 4) block shape,
    * both models produce equal outputs. The quantized input is used for
      static quantization and the float input for dynamic quantization.

    Args:
        model_class: module class constructed as ``model_class(in, out)``; the
            instance must expose a ``linear`` submodule.
        sparse_mapping: convert mapping used for the sparse model.
        ref_mapping: convert mapping used for the reference model.
        qconfig_dict: qconfig mapping passed to ``tq.propagate_qconfig_``.
        fqn_to_check: fully-qualified name of the submodule whose conversion
            is checked.
        test_class: test case instance whose ``assertEqual`` is used.
        test_scripting: when truthy, ``sqmodel`` is scripted, saved to an
            in-memory buffer and reloaded before the numerics check.
    """
    # SET UP TEST PARAMETERS, INPUTS AND WEIGHTS
    # ------------------------------------------
    batch_size = 12
    input_channels = 4
    output_channels = 7
    model = model_class(input_channels, output_channels)

    # NOTE(review): this comment used to say that for sparse kernels both the
    # activation and the weight zero point are 0, but X_zp below is 2; only
    # the weight zero point is 0. Confirm which is intended.
    X_scale = 0.2
    X_zp = 2
    W_scale = 1e-2
    W_zp = 0

    X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32)
    # NOTE(review): float_bias is not used below; removing it would shift the
    # subsequent random draws, so it is left in place.
    float_bias = torch.randn(output_channels, dtype=torch.float32)

    # generate a weight which we'll insert into the model
    W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32)
    # Zero out a random subset of weight elements (mask entries are 0 or 1).
    mask = torch.randint(0, 2, W_fp32.shape)
    W_fp32 *= mask
    with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
        X_q = torch.quantize_per_tensor(X_fp32,
                                        scale=X_scale,
                                        zero_point=X_zp,
                                        dtype=torch.quint8)
        # Use the quantize/dequantize round-tripped input so both the float and
        # quantized paths see exactly representable values.
        X_fp32 = X_q.dequantize()

        W_q = torch.quantize_per_tensor(W_fp32, W_scale, W_zp, torch.qint8)

        # PREPARE MODELS FOR QUANTIZATION
        # -------------------------------
        model.linear.weight = nn.Parameter(W_q.dequantize())
        model.eval()

        # Add `sparse_params` to the model. The test for correct
        # sparse_param addition is in the sparsifier tests
        model.linear.sparse_params = {"sparse_block_shape": (1, 4)}

        # generate model versions
        qmodel = copy.deepcopy(model)
        sqmodel = copy.deepcopy(model)

        # generate model versions and apply qconfigs
        tq.propagate_qconfig_(qmodel, qconfig_dict)
        tq.propagate_qconfig_(sqmodel, qconfig_dict)

        tq.prepare(qmodel, inplace=True)
        tq.prepare(sqmodel, inplace=True)

        # calibrate
        with torch.no_grad():
            qmodel(X_fp32)
            sqmodel(X_fp32)

        # ACTUAL TESTING BEGINS HERE
        # --------------------------

        # Make sure the quantization parameters are computed the same way
        qparams = qmodel.linear.qconfig.weight().calculate_qparams()
        sqparams = sqmodel.linear.qconfig.weight().calculate_qparams()
        test_class.assertEqual(qparams, sqparams)

        # Look up the class each mapping should convert the checked module to.
        sqmodule_to_check = fqn_to_module(sqmodel, fqn_to_check)
        sqmodule_start_class = sqmodule_to_check.__class__
        sqmodule_expected_converted_class = sparse_mapping[
            sqmodule_start_class]

        qmodule_to_check = fqn_to_module(qmodel, fqn_to_check)
        qmodule_start_class = qmodule_to_check.__class__
        qmodule_expected_converted_class = ref_mapping[qmodule_start_class]

        # need to determine whether dynamic quantization is being performed since
        # input dtype will be different at the end
        is_dynamic = isinstance(qmodule_to_check.activation_post_process,
                                tq.PlaceholderObserver)

        tq.convert(sqmodel, inplace=True, mapping=sparse_mapping)
        tq.convert(qmodel, inplace=True, mapping=ref_mapping)

        # this code is a duplicate of above since the references do not
        # update to the post-convert modules
        sqmodule_to_check = fqn_to_module(sqmodel, fqn_to_check)
        qmodule_to_check = fqn_to_module(qmodel, fqn_to_check)

        # check that the modules were converted as expected
        assert isinstance(sqmodule_to_check,
                          sqmodule_expected_converted_class), "Convert failed"
        assert isinstance(qmodule_to_check,
                          qmodule_expected_converted_class), "Mapping failed"

        # Elements [2:] of the packed weight/bias tuple are the block sizes.
        row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias(
        )[2:]
        assert row_block_size == 1 and col_block_size == 4

        # only run during serialization/deserialization tests
        # makes sure script/save/load doesn't malform the sqmodel
        if test_scripting:
            scripted_sqmodel = torch.jit.script(sqmodel)
            scripted_sqmodel.eval()
            buffer = io.BytesIO()
            torch.jit.save(scripted_sqmodel, buffer)
            buffer.seek(0)
            sqmodel = torch.jit.load(buffer)

        # use correct input dtype
        if is_dynamic:
            Y_ref = qmodel(X_fp32)
            Y_hat = sqmodel(X_fp32)
            test_class.assertEqual(Y_ref, Y_hat)
        else:
            Y_ref = qmodel(X_q)
            Y_hat = sqmodel(X_q)
            test_class.assertEqual(Y_ref.dequantize(), Y_hat.dequantize())
Example #14
0
    def test_sparse_qlinear_serdes(self):
        """Sparse quantized Linear must match dense quantized Linear, including after TorchScript save/load.

        The active engine selects the flow:

        * fbgemm: the model is statically quantized (prepare, calibrate,
          convert with ``ao_nn_sq.Linear`` mapped in for ``nn.Linear``).
        * qnnpack: the model is dynamically quantized (convert with
          ``ao_nn_sq.dynamic.Linear`` inside ``LinearBlockSparsePattern(1, 4)``).

        In both flows the sparse model is scripted, saved to an in-memory
        buffer and reloaded before its output is compared with the dense
        reference. For any other engine neither branch runs.
        """
        batch_size = 12
        input_channels = 4
        output_channels = 7
        model = self.SparseQuantizedModel(input_channels, output_channels)

        # For sparse kernels both the activation and weight ZP = 0
        X_scale = 0.2
        X_zp = 0
        W_scale = 1e-2
        W_zp = 0

        with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
            X_fp32 = torch.randn(batch_size,
                                 input_channels,
                                 dtype=torch.float32)
            # NOTE(review): float_bias is not used below; removing it would
            # shift the subsequent random draws, so it is left in place.
            float_bias = torch.randn(output_channels, dtype=torch.float32)

            X_q = torch.quantize_per_tensor(X_fp32,
                                            scale=X_scale,
                                            zero_point=X_zp,
                                            dtype=torch.quint8)
            # Use the quantize/dequantize round-tripped input for float paths.
            X_fp32 = X_q.dequantize()

            W_fp32 = torch.randn(output_channels,
                                 input_channels,
                                 dtype=torch.float32)
            # Zero out a random subset of weight elements (mask entries are 0 or 1).
            mask = torch.randint(0, 2, W_fp32.shape)
            W_fp32 *= mask
            W_q = torch.quantize_per_tensor(W_fp32, W_scale, W_zp, torch.qint8)

            model.linear.weight = nn.Parameter(W_q.dequantize())
            model.linear.sparse_params = {'sparse_block_shape': (1, 4)}
            model.eval()

            # Note: At the moment, for sparse kernels
            # fbgemm supports only static quantized sparse linear
            # qnnpack supports only dynamically quantized sparse linear
            # Hence we have two different tests.
            # fbgemm tests static flow, qnnpack tests dynamic.
            # Should be unified later on and tests should be fixed
            # appropriately.
            if qengine_is_fbgemm():
                model.qconfig = tq.get_default_qconfig('fbgemm')
                qmodel = copy.deepcopy(model)
                sqmodel = copy.deepcopy(model)

                tq.prepare(qmodel, inplace=True)
                tq.prepare(sqmodel, inplace=True)

                # Calibrate both copies on the same input.
                with torch.no_grad():
                    qmodel(X_fp32)
                    sqmodel(X_fp32)

                # Make sure the quantization parameters are computed the same way
                qparams = qmodel.linear.qconfig.weight().calculate_qparams()
                sqparams = sqmodel.linear.qconfig.weight().calculate_qparams()
                self.assertEqual(qparams, sqparams)

                # Make sure mapping of sparse kernels does not affect the non-sparse
                sparse_mapping = tq.get_default_static_quant_module_mappings()
                sparse_mapping[nn.Linear] = ao_nn_sq.Linear
                tq.convert(sqmodel, inplace=True, mapping=sparse_mapping)
                tq.convert(qmodel, inplace=True)

                assert isinstance(sqmodel.linear,
                                  ao_nn_sq.Linear), "Convert failed"
                assert isinstance(qmodel.linear,
                                  nn.quantized.Linear), "Mapping failed"

                # Round-trip the sparse model through TorchScript save/load.
                scripted_sqmodel = torch.jit.script(sqmodel)
                scripted_sqmodel.eval()
                buffer = io.BytesIO()
                torch.jit.save(scripted_sqmodel, buffer)
                buffer.seek(0)
                sqmodel = torch.jit.load(buffer)

                # Make sure numerics are right
                Y_ref = qmodel(X_q)
                Y_hat = sqmodel(X_q)
                self.assertEqual(Y_ref.dequantize(), Y_hat.dequantize())

            elif qengine_is_qnnpack():
                qconfig = {nn.Linear: tq.qconfig.default_dynamic_qconfig}
                dqmodel = copy.deepcopy(model)
                sdqmodel = copy.deepcopy(model)

                tq.propagate_qconfig_(dqmodel, qconfig)
                tq.propagate_qconfig_(sdqmodel, qconfig)

                # Make sure the quantization parameters are computed the same way
                qparams = dqmodel.linear.qconfig.weight().calculate_qparams()
                sqparams = sdqmodel.linear.qconfig.weight().calculate_qparams()
                self.assertEqual(qparams, sqparams)

                # Make sure mapping of sparse kernels does not affect the non-sparse
                # (the default mapping is deep-copied before it is modified).
                sparse_mapping = copy.deepcopy(
                    tq.get_default_dynamic_quant_module_mappings())
                sparse_mapping[nn.Linear] = ao_nn_sq.dynamic.Linear
                with LinearBlockSparsePattern(1, 4):
                    tq.convert(sdqmodel, inplace=True, mapping=sparse_mapping)
                tq.convert(
                    dqmodel,
                    mapping=tq.get_default_dynamic_quant_module_mappings(),
                    inplace=True)

                assert isinstance(sdqmodel.linear,
                                  ao_nn_sq.dynamic.Linear), "Convert failed"
                assert isinstance(
                    dqmodel.linear,
                    nn.quantized.dynamic.Linear), "Mapping failed"

                # Round-trip the sparse model through TorchScript save/load.
                scripted_sdqmodel = torch.jit.script(sdqmodel)
                scripted_sdqmodel.eval()
                buffer = io.BytesIO()
                torch.jit.save(scripted_sdqmodel, buffer)
                buffer.seek(0)
                sdqmodel = torch.jit.load(buffer)

                # Make sure numerics are right
                Y_ref = dqmodel(X_fp32)
                Y_hat = sdqmodel(X_fp32)
                self.assertEqual(Y_ref, Y_hat)