Example #1
    def test_input_weight_equalization_results(self):
        """ Tests that for small models, the results of quantized models that
        have been equalized are very close to models that have not been equalized.
        """

        tests = [
            SingleLayerLinearModel, TwoLayerLinearModel, LinearAddModel,
            SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel
        ]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()

            # No equalization
            prepared = prepare_fx(copy.deepcopy(m),
                                  specific_qconfig_dict,
                                  equalization_qconfig_dict={})
            prepared(x)
            quantized = convert_fx(prepared)  # Check if compile
            quantized_output = quantized(x)

            # With equalization
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                equalization_qconfig_dict=default_equalization_qconfig_dict)
            prepared(x)
            equalized_and_quantized = convert_fx(prepared)  # Check if compile
            equalized_and_quantized_output = equalized_and_quantized(x)
            self.assertEqual(quantized_output,
                             equalized_and_quantized_output,
                             rtol=1e-5,
                             atol=0.1)
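Every example on this page follows the same prepare -> calibrate -> convert flow that the test above exercises. Below is a minimal, self-contained sketch of that flow against the current torch.ao.quantization API (PyTorch 1.13+, where prepare_fx requires example_inputs); the SmallLinear model is an illustrative assumption, not part of the test above. Note also that the equalization_qconfig_dict keyword shown above appears as _equalization_config in newer releases (compare Example #10).

import copy
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class SmallLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x)

float_model = SmallLinear().eval()
example_inputs = (torch.rand(5, 5),)

# insert observers, run calibration data through them, then convert
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
prepared = prepare_fx(copy.deepcopy(float_model), qconfig_mapping, example_inputs)
prepared(*example_inputs)            # calibration pass
quantized = convert_fx(prepared)     # quantized torch.fx.GraphModule
print(quantized(*example_inputs).shape)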
Example #2
    def _test_quantize_model(self, model_config):
        if get_torch_version() >= [1, 11]:
            import torch.ao.quantization as tq
            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
        else:
            import torch.quantization as tq
            from torch.quantization.quantize_fx import convert_fx, prepare_fx

        # quantize model
        model = build_model(model_config)
        model.eval()

        input = torch.ones([1, 3, 32, 32])

        heads = model.get_heads()
        # since prepare changes the code of ClassyBlock, we need to clear the heads first
        # and reattach them later to avoid caching
        model.clear_heads()

        prepare_custom_config_dict = {}
        head_path_from_blocks = [
            _find_block_full_path(model.features, block_name)
            for block_name in heads.keys()
        ]
        # we need to keep the modules used in the heads standalone since
        # they will be accessed directly by path name during execution
        prepare_custom_config_dict["standalone_module_name"] = [(
            head,
            {
                "": tq.default_qconfig
            },
            {
                "input_quantized_idxs": [0],
                "output_quantized_idxs": []
            },
            None,
        ) for head in head_path_from_blocks]
        model.initial_block = prepare_fx(model.initial_block,
                                         {"": tq.default_qconfig})
        model.features = prepare_fx(
            model.features,
            {"": tq.default_qconfig},
            prepare_custom_config_dict,
        )
        model.set_heads(heads)

        # calibration
        model(input)

        heads = model.get_heads()
        model.clear_heads()
        model.initial_block = convert_fx(model.initial_block)
        model.features = convert_fx(model.features)
        model.set_heads(heads)

        output = model(input)
        self.assertEqual(output.size(), (1, 1000))
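The test above quantizes individual sub-modules (model.initial_block and model.features) rather than the whole model, so that the head logic stays in eager mode. Below is a minimal sketch of that per-sub-module pattern on PyTorch 1.13+; the Wrapper module and its layer names are illustrative assumptions, not ClassyVision code.

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.features = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
        self.head = torch.nn.Linear(8, 10)    # stays in fp32 / eager mode

    def forward(self, x):
        x = self.features(x)
        return self.head(x.mean(dim=(2, 3)))

model = Wrapper().eval()
example_inputs = (torch.ones(1, 3, 32, 32),)

# prepare and convert only the traceable sub-module, as the test does for
# model.initial_block and model.features
model.features = prepare_fx(model.features, get_default_qconfig_mapping("fbgemm"), example_inputs)
model(*example_inputs)                        # calibration through the full model
model.features = convert_fx(model.features)
print(model(*example_inputs).shape)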
Example #3
def default_rcnn_prepare_for_quant_convert(self, cfg):
    if cfg.QUANTIZATION.EAGER_MODE:
        convert(self, inplace=True)
    else:
        assert not isinstance(self.backbone,
                              FPN), "FPN is not supported in FX mode"
        self.backbone = convert_fx(
            self.backbone,
            convert_custom_config_dict={
                "preserved_attributes":
                ["size_divisibility", "padding_constraints"]
            },
        )
        self.proposal_generator.rpn_head.rpn_feature = convert_fx(
            self.proposal_generator.rpn_head.rpn_feature)
        self.proposal_generator.rpn_head.rpn_regressor.cls_logits = convert_fx(
            self.proposal_generator.rpn_head.rpn_regressor.cls_logits)
        self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = convert_fx(
            self.proposal_generator.rpn_head.rpn_regressor.bbox_pred)
        self.roi_heads.box_head.roi_box_conv = convert_fx(
            self.roi_heads.box_head.roi_box_conv)
        self.roi_heads.box_head.avgpool = convert_fx(
            self.roi_heads.box_head.avgpool)
        self.roi_heads.box_predictor.cls_score = convert_fx(
            self.roi_heads.box_predictor.cls_score)
        self.roi_heads.box_predictor.bbox_pred = convert_fx(
            self.roi_heads.box_predictor.bbox_pred)
    return self
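The example above passes "preserved_attributes" in convert_custom_config_dict so that non-tensor attributes such as size_divisibility survive conversion to a GraphModule. Below is a minimal sketch of the same option using the config classes from newer releases (PyTorch 1.13+); the Backbone module is an illustrative assumption, not the detectron2 backbone.

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.fx.custom_config import ConvertCustomConfig, PrepareCustomConfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class Backbone(torch.nn.Module):
    size_divisibility = 32                     # non-tensor attribute to keep

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)

m = Backbone().eval()
example_inputs = (torch.randn(1, 3, 32, 32),)
prepared = prepare_fx(
    m, get_default_qconfig_mapping("fbgemm"), example_inputs,
    prepare_custom_config=PrepareCustomConfig().set_preserved_attributes(["size_divisibility"]))
prepared(*example_inputs)                      # calibration
converted = convert_fx(
    prepared,
    convert_custom_config=ConvertCustomConfig().set_preserved_attributes(["size_divisibility"]))
print(converted.size_divisibility)             # attribute is carried over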
Example #4
    def test_embedding(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=10,
                                              embedding_dim=12)

            def forward(self, indices):
                return self.emb(indices)

        model = M().eval()
        indices = torch.randint(low=0, high=10, size=(20, ))

        quantized_node = ns.call_module(nnq.Embedding)
        configs = [
            (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
            (None, ns.call_module(nn.Embedding)),
            (default_qconfig, ns.call_module(nn.Embedding)),
        ]

        for qconfig, node in configs:
            qconfig_dict = {"": qconfig}
            m = prepare_fx(model, qconfig_dict)
            m = convert_fx(m)
            self._compare_script_and_mobile(m, input=indices)
Example #5
    def test_linear(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
            weight=torch.quantization.default_weight_observer)
        prepared = prepare_fx(
            m, {"": qconfig},
            backend_config_dict=get_tensorrt_backend_config_dict())
        # calibration
        prepared(linear_module_input)
        quantized = convert_fx(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
        # lower to trt
        trt_mod = lower_to_trt(quantized, linear_module_input,
                               [((1, *linear_module_input.shape[1:]),
                                 (5, *linear_module_input.shape[1:]),
                                 (10, *linear_module_input.shape[1:]))])
        # make sure it runs
        trt_mod(linear_module_input.cuda())
Example #6
    def test_conv(self):
        class Conv2d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv2d(*args)

            def forward(self, x):
                return self.conv(x)

        conv2d_input = torch.rand(1, 3, 224, 224)
        conv2d_module_args = (3, 3, 3)

        m = Conv2d(*conv2d_module_args).eval()
        qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
            weight=torch.quantization.default_weight_observer)
        prepared = prepare_fx(
            m, {"": qconfig},
            backend_config_dict=get_tensorrt_backend_config_dict())
        # calibration
        prepared(conv2d_input)
        quantized = convert_fx(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
        # lower to trt
        trt_mod = lower_to_trt(quantized, conv2d_input,
                               [((1, 3, 224, 224), (5, 3, 224, 224),
                                 (10, 3, 224, 224))])
        # make sure it runs
        trt_mod(conv2d_input.cuda())
Example #7
def build_int8_trt_implicit_quant(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True),
        weight=torch.ao.quantization.default_per_channel_weight_observer)
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(traced_rn18,
                            InputTensorSpec.from_tensors([data]),
                            logger_level=trt.Logger.VERBOSE)
    engine, input_names, output_names = interp.run(
        fp16_mode=False, int8_mode=True, strict_type_constraints=True)
    trt_mod = TRTModule(engine, input_names, output_names)
    trt_res = trt_mod(data.cuda())
    print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
Example #8
def build_int8_trt(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # data = torch.randn(1, 32)
    # data = torch.randn(1, 64, 10, 10)
    # TensorRT only supports symmetric quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        # weight=torch.ao.quantization.default_weight_observer
        # uncomment to check per channel quant works
        weight=torch.quantization.default_per_channel_weight_observer
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared, is_reference=True)
    ref_res = quantized_rn18(data)
    print("quantized model:", quantized_rn18)

    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[assignment]
    interp = TRTInterpreter(
        quantized_rn18,
        [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
                         shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], has_batch_dim=True)],
        explicit_batch_dimension=True, explicit_precision=True, logger_level=trt.Logger.VERBOSE)
    interpreter_result = interp.run(fp16_mode=False, int8_mode=True)
    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
    trt_res = trt_mod(data.cuda())
    print("explicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
Example #9
    def test_q_prep_fx_before_s_prep(self):
        r"""
        This test checks that the ordering of prepare_fx -> sparse prepare -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask between sparse prepare and convert_fx. This also tests
        the automatic fusion that occurs during prepare_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)


        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # it's absolutely broken by auto fusion in fx,
        # but will still work if you put the correct fqn in
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
Example #10
    def test_input_weight_equalization_convert(self):
        """ Tests that the modified model for equalization (before quantization)
        returns the same output as the original model
        """

        tests = [(SingleLayerLinearModel, 2), (LinearAddModel, 2),
                 (TwoLayerLinearModel, 2),
                 (SingleLayerFunctionalLinearModel, 2),
                 (FunctionalLinearAddModel, 2),
                 (TwoLayerFunctionalLinearModel, 2), (LinearReluModel, 2),
                 (LinearReluLinearModel, 2), (LinearReluAddModel, 2),
                 (FunctionalLinearReluModel, 2),
                 (FunctionalLinearReluLinearModel, 2), (ConvModel, 4),
                 (TwoLayerConvModel, 4), (SingleLayerFunctionalConvModel, 4),
                 (TwoLayerFunctionalConvModel, 4), (ConvReluModel, 4),
                 (ConvReluConvModel, 4), (ConvReluAddModel, 4),
                 (FunctionalConvReluModel, 4),
                 (FunctionalConvReluConvModel, 4)]

        for (M, ndim) in tests:
            m = M().eval()

            if ndim == 2:
                x = torch.rand((5, 5))
            elif ndim == 4:
                x = torch.rand((16, 3, 224, 224))

            example_inputs = (x, )
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            output = prepared(x)

            convert_ref = _convert_equalization_ref(prepared)
            convert_ref_output = convert_ref(x)

            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_fx(prepared)  # Check if compile
            self.assertEqual(output, convert_ref_output)
Example #11
 def _do_quant_transforms(
     m: torch.nn.Module,
     input_tensor: torch.Tensor,
 ) -> torch.nn.Module:
     # do the quantization transforms and return the result
     qconfig = torch.quantization.get_default_qconfig('fbgemm')
     mp = quantize_fx.prepare_fx(m, {'': qconfig})
     mp(input_tensor)
     mq = quantize_fx.convert_fx(mp)
     return mq
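A possible usage sketch for the helper above, assuming the helper (and its quantize_fx import) is in scope; the model and input are illustrative assumptions. The helper targets the pre-1.13 API; on PyTorch 1.13+ prepare_fx additionally requires an example_inputs argument.

import torch

m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
x = torch.randn(1, 3, 32, 32)
mq = _do_quant_transforms(m, x)   # prepare -> calibrate -> convert
print(type(mq))                   # a quantized torch.fx.GraphModule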
Example #12
    def test_s_prep_before_qat_prep_fx(self):
        r"""
        This test checks that the ordering of sparse prepare -> prepare_qat_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qat_qconfig("fbgemm")
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)
        mod = prepare_qat_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.qat.LinearReLU))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
Example #13
    def test_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        qconfig_dict = {"": default_qconfig, "module_name": [("conv1", None)]}
        m = prepare_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        m = convert_fx(m)
        # conv1 is not quantized (None qconfig); conv2 is quantized with the global default_qconfig
        self._compare_script_and_mobile(m, input=data)
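The test above disables quantization for conv1 by mapping it to a None qconfig under the "module_name" key. Below is a minimal sketch of the same idea with the QConfigMapping API from newer releases (PyTorch 1.13+); the standalone-script form and input shape are assumptions.

import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 1, 1)
        self.conv2 = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv2(self.conv1(x))

m = M().eval()
example_inputs = (torch.randn(1, 1, 4, 4),)
qconfig_mapping = (QConfigMapping()
                   .set_global(get_default_qconfig("fbgemm"))
                   .set_module_name("conv1", None))    # leave conv1 in fp32
prepared = prepare_fx(m, qconfig_mapping, example_inputs)
prepared(*example_inputs)                              # calibration
quantized = convert_fx(prepared)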
Example #14
    def test_quantize_model(self, config):
        """
        Test that the model builds from a config using either model_params or
        model_name and calls the FX graph mode quantization APIs
        """
        if get_torch_version() < [1, 8]:
            self.skipTest(
                "FX Graph Modee Quantization is only availablee from 1.8")
        if get_torch_version() >= [1, 11]:
            import torch.ao.quantization as tq
            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
        else:
            import torch.quantization as tq
            from torch.quantization.quantize_fx import convert_fx, prepare_fx

        model = build_model(config)
        assert isinstance(model, RegNet)
        model.eval()
        model.stem = prepare_fx(model.stem, {"": tq.default_qconfig})
        model.stem = convert_fx(model.stem)
Example #15
    def test_submodule(self):
        # test quantizing complete module, submodule and linear layer
        configs = [
            {},
            {
                "module_name": [("subm", None)]
            },
            {
                "module_name": [("fc", None)]
            },
        ]
        for config in configs:
            model = LinearModelWithSubmodule().eval()
            qconfig_dict = {
                "": torch.ao.quantization.get_default_qconfig("qnnpack"),
                **config,
            }
            model = prepare_fx(model, qconfig_dict)
            quant = convert_fx(model)

            x = torch.randn(5, 5)
            self._compare_script_and_mobile(quant, input=x)
Example #16
def convert_predictor(
    cfg,
    pytorch_model,
    predictor_type,
    data_loader,
):
    if "int8" in predictor_type:
        if not cfg.QUANTIZATION.QAT.ENABLED:
            logger.info(
                "The model is not quantized during training, running post"
                " training quantization ...")

            pytorch_model = post_training_quantize(cfg, pytorch_model,
                                                   data_loader)
            # only check that bn exists in PTQ since QAT still has bn inside fused ops
            if fuse_utils.check_bn_exist(pytorch_model):
                logger.warn(
                    "Post training quantized model has bn inside fused ops")
        logger.info(
            f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")

        if hasattr(pytorch_model, "prepare_for_quant_convert"):
            pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
        else:
            # TODO(T93870381): move this to a default function
            if cfg.QUANTIZATION.EAGER_MODE:
                pytorch_model = convert(pytorch_model, inplace=False)
            else:  # FX graph mode quantization
                pytorch_model = convert_fx(pytorch_model)

        logger.info("Quantized Model:\n{}".format(pytorch_model))
    else:
        pytorch_model = fuse_utils.fuse_model(pytorch_model)
        logger.info("Fused Model:\n{}".format(pytorch_model))
        if fuse_utils.count_bn_exist(pytorch_model) > 0:
            logger.warning("BN existed in pytorch model after fusing.")
    return pytorch_model
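convert_predictor branches between eager mode (convert) and FX graph mode (convert_fx). Below is a minimal sketch of the eager-mode counterpart it falls back to, with an illustrative QuantStub/DeQuantStub model that is an assumption, not d2go code.

import torch
from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert

class EagerModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

m = EagerModel().eval()
m.qconfig = get_default_qconfig("fbgemm")
prepared = prepare(m)                         # eager-mode observer insertion
prepared(torch.randn(1, 3, 32, 32))           # calibration
quantized = convert(prepared, inplace=False)  # eager-mode analogue of convert_fx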
Example #17
 def prepare_for_quant_convert(self, cfg):
     self.avgpool = convert_fx(
         self.avgpool,
         convert_custom_config_dict=self.custom_config_dict)
     return self
Example #18
 def prepare_for_quant_convert(self, cfg):
     self.avgpool = convert_fx(self.avgpool)
     return self
Example #19
    def test_selective_equalization(self):
        """ Tests that we are able to run numeric suite on the equalized model
        and construct a valid equalization_qconfig_dict that equalizes only the
        layers with the highest quantization errors.
        """

        torch.manual_seed(1)

        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
                self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

            def forward(self, x):
                x = self.bot(x)
                x = torch.add(x, 5)
                x = self.top(x)
                return x

        float_model = M().eval()
        # Hard coded so that the top layer has a higher quantization error
        x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                          [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                          [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                          [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                          [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

        # Quantize the float model
        prepared_model = prepare_fx(copy.deepcopy(float_model),
                                    specific_qconfig_dict)
        prepared_model(x)
        quantized_model = convert_fx(copy.deepcopy(prepared_model))

        # Get the SQNR between the float and quantized model
        layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model),
                                                 quantized_model, x)

        # Construct the equalization_qconfig_dict equalizing layers with the highest
        # quantization errors
        selective_equalization_qconfig_dict = get_equalization_qconfig_dict(
            layer_to_sqnr_dict, 1)

        # Create the selectively equalized model
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            equalization_qconfig_dict=selective_equalization_qconfig_dict,
        )
        prepared_model(x)
        equalized_model = convert_fx(prepared_model)

        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_model,
                                   expected_node_list=node_list)
Example #20
    def test_input_weight_equalization_graphs(self):
        """ Tests that the modified model for equalization has the same graph
        structure as the model without equalization (before and after
        quantization).
        """

        linear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        linearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_method('dequantize')
        ]

        linearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_method('dequantize')
        ]

        functionalLinearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        conv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        conv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        convRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_method('dequantize')
        ]

        convReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConvRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_method('dequantize')
        ]

        functionalConvReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        tests = [
            (SingleLayerLinearModel, linear_node_list),
            (LinearAddModel, linearAdd_node_list),
            (TwoLayerLinearModel, linear2_node_list),
            (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
            (FunctionalLinearAddModel, functionalLinearAdd_node_list),
            (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
            (LinearReluModel, linearRelu_node_list),
            (LinearReluLinearModel, linearReluLinear_node_list),
            (FunctionalLinearReluModel, functionalLinearRelu_node_list),
            (FunctionalLinearReluLinearModel,
             functionalLinearReluLinear_node_list),
            (ConvModel, conv_node_list), (TwoLayerConvModel, conv2_node_list),
            (SingleLayerFunctionalConvModel, functionalConv_node_list),
            (TwoLayerFunctionalConvModel, functionalConv2_node_list),
            (ConvReluModel, convRelu_node_list),
            (ConvReluConvModel, convReluConv_node_list),
            (FunctionalConvReluModel, functionalConvRelu_node_list),
            (FunctionalConvReluConvModel, functionalConvReluConv_node_list)
        ]

        for (M, node_list) in tests:
            m = M().eval()
            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                equalization_qconfig_dict=default_equalization_qconfig_dict)
            equalized_quantized_model = convert_fx(prepared)

            # Check the order of nodes in the graph
            self.checkGraphModuleNodes(equalized_quantized_model,
                                       expected_node_list=node_list)
Example #21
def default_prepare_for_quant_convert(cfg, model):
    return convert_fx(model)
Example #22
 def convert(self, model, submodules, attrs):
     model.another_layer = convert_fx(model.another_layer)
     return model