Example #1
    def test_input_weight_equalization_results(self):
        """ Tests that for small models, the results of quantized models that
        have been equalized are very close to models that have not been equalized.
        """

        tests = [
            SingleLayerLinearModel, TwoLayerLinearModel, LinearAddModel,
            SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel
        ]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()

            # No equalization
            prepared = prepare_fx(copy.deepcopy(m),
                                  specific_qconfig_dict,
                                  equalization_qconfig_dict={})
            prepared(x)
            quantized = convert_fx(prepared)  # check that conversion succeeds
            quantized_output = quantized(x)

            # With equalization
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                equalization_qconfig_dict=default_equalization_qconfig_dict)
            prepared(x)
            equalized_and_quantized = convert_fx(prepared)  # check that conversion succeeds
            equalized_and_quantized_output = equalized_and_quantized(x)
            self.assertEqual(quantized_output,
                             equalized_and_quantized_output,
                             rtol=1e-5,
                             atol=0.1)
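
The equalization tests in this listing reference specific_qconfig_dict and default_equalization_qconfig_dict, which are defined elsewhere in the test module. A minimal sketch of plausible definitions follows; the names come from the tests, but the exact contents and the import path of default_equalization_qconfig are assumptions (in recent PyTorch it lives in torch.ao.quantization.fx._equalize).

import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.fx._equalize import default_equalization_qconfig

# quantize nn/functional linear and conv layers with the default qconfig
specific_qconfig_dict = {
    "": None,
    "object_type": [
        (nn.Linear, default_qconfig),
        (F.linear, default_qconfig),
        (nn.Conv2d, default_qconfig),
        (F.conv2d, default_qconfig),
    ],
}

# apply input-weight equalization to the same layer types
default_equalization_qconfig_dict = {
    "": None,
    "object_type": [
        (nn.Linear, default_equalization_qconfig),
        (F.linear, default_equalization_qconfig),
        (nn.Conv2d, default_equalization_qconfig),
        (F.conv2d, default_equalization_qconfig),
    ],
}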
Example #2
    def test_input_weight_equalization_branching(self):
        """ Tests that graphs containing branches are prepared correctly.
        Specifically, equalization observers should not be inserted in front of
        branches in which both initial layers in the branches plan to be
        quantized.
        """

        # Tests that we do not add an equalization observer due to both initial
        # nodes in the branch containing layers that need to be equalized.
        # Note that this should print out 2 warning messages for not being able
        # to equalize layers linear1 and linear1 because it is part of a branch
        class TestBranchingWithoutEqualizationModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(5, 5)
                self.linear2 = nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear1(x)
                z = self.linear2(x)
                return torch.add(y, z)

        no_eq_branching_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 0,
            ns.call_module(MinMaxObserver): 3,
        }

        m = TestBranchingWithoutEqualizationModel().eval()
        prepared = prepare_fx(
            m,
            specific_qconfig_dict,
            equalization_qconfig_dict=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(
            prepared, expected_node_occurrence=no_eq_branching_node_occurrence)

        # Tests that we will add an equalization observer because there is only
        # one initial node in the branch that needs to be equalized
        class TestBranchingWithEqualizationModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear1(x)
                z = torch.add(x, 5)
                return torch.add(y, z)

        eq_branching_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 2,
        }

        m = TestBranchingWithEqualizationModel().eval()
        prepared = prepare_fx(
            m,
            specific_qconfig_dict,
            equalization_qconfig_dict=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(
            prepared, expected_node_occurrence=eq_branching_node_occurrence)
    def _test_quantize_model(self, model_config):
        if get_torch_version() >= [1, 11]:
            import torch.ao.quantization as tq
            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
        else:
            import torch.quantization as tq
            from torch.quantization.quantize_fx import convert_fx, prepare_fx

        # quantize model
        model = build_model(model_config)
        model.eval()

        input = torch.ones([1, 3, 32, 32])

        heads = model.get_heads()
        # since prepare changes the code of ClassyBlock, we need to clear the heads
        # first and reattach them later to avoid caching
        model.clear_heads()

        prepare_custom_config_dict = {}
        head_path_from_blocks = [
            _find_block_full_path(model.features, block_name)
            for block_name in heads.keys()
        ]
        # we need to keep the modules used in the heads standalone since
        # they will be accessed directly by path name during execution
        prepare_custom_config_dict["standalone_module_name"] = [(
            head,
            {
                "": tq.default_qconfig
            },
            {
                "input_quantized_idxs": [0],
                "output_quantized_idxs": []
            },
            None,
        ) for head in head_path_from_blocks]
        model.initial_block = prepare_fx(model.initial_block,
                                         {"": tq.default_qconfig})
        model.features = prepare_fx(
            model.features,
            {"": tq.default_qconfig},
            prepare_custom_config_dict,
        )
        model.set_heads(heads)

        # calibration
        model(input)

        heads = model.get_heads()
        model.clear_heads()
        model.initial_block = convert_fx(model.initial_block)
        model.features = convert_fx(model.features)
        model.set_heads(heads)

        output = model(input)
        self.assertEqual(output.size(), (1, 1000))
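
The test above depends on a _find_block_full_path helper defined in the same test file. A plausible sketch, assuming it returns the fully qualified path of the named block inside a module (the real helper may differ):

def _find_block_full_path(module, block_name):
    # return the fully qualified name of the submodule whose last path
    # component matches block_name, or None if no such submodule exists
    for name, _ in module.named_modules():
        if name.split(".")[-1] == block_name:
            return name
    return None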
Example #4
    def test_linear(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
            weight=torch.quantization.default_weight_observer)
        prepared = prepare_fx(
            m, {"": qconfig},
            backend_config_dict=get_tensorrt_backend_config_dict())
        # calibration
        prepared(linear_module_input)
        quantized = convert_fx(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
        # lower to trt
        trt_mod = lower_to_trt(quantized, linear_module_input,
                               [((1, *linear_module_input.shape[1:]),
                                 (5, *linear_module_input.shape[1:]),
                                 (10, *linear_module_input.shape[1:]))])
        # make sure it runs
        trt_mod(linear_module_input.cuda())
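
The TensorRT tests in this listing call a lower_to_trt helper. A sketch of what it could look like, assembled from the explicit-quantization flow shown in build_int8_trt further below; acc_tracer, TRTInterpreter, InputTensorSpec, and TRTModule are assumed to be imported as in that example, and the input normalization is an assumption (the real helper may differ):

def lower_to_trt(model, inputs, shape_ranges):
    """Lower a reference-quantized FX model to a TensorRT module."""
    # some callers pass a single tensor, others a list; normalize to one sample
    sample = inputs[0] if isinstance(inputs, (list, tuple)) else inputs
    model = acc_tracer.trace(model, [sample])
    interp = TRTInterpreter(
        model,
        [InputTensorSpec(
            torch.Size([-1, *sample.shape[1:]]), torch.float,
            shape_ranges=shape_ranges, has_batch_dim=True)],
        explicit_batch_dimension=True,
        explicit_precision=True)
    result = interp.run(fp16_mode=False, int8_mode=True)
    return TRTModule(result.engine, result.input_names, result.output_names)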
Example #5
    def test_addmm(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(5, 5)
                self.bias = torch.randn(5)

            def forward(self, x):
                return torch.addmm(self.bias, x, self.weight)

        m = M().eval()
        prepared = prepare_fx(
            m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # weight
            ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
            # activation
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
        quantized = _convert_fx_do_not_use(
            prepared, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # input activation, output activation and weight
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_function(torch.addmm): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
Example #6
    def test_ops(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.linear = torch.nn.Linear(5, 5)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.linear(x)
                x = x + 3
                x = self.relu(x)
                x = x + 6
                return x

        m = M().eval()
        m = prepare_fx(m, {"": default_qconfig})
        m = _convert_fx_do_not_use(m, is_reference=True)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 5,
            ns.call_method("dequantize"): 5,
            ns.call_module(torch.nn.quantized._reference.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
        }
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=expected_occurrence)
def build_int8_trt(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # data = torch.randn(1, 32)
    # data = torch.randn(1, 64, 10, 10)
    # TensorRT only supports symmetric quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        # weight=torch.ao.quantization.default_weight_observer
        # uncomment to check per channel quant works
        weight=torch.quantization.default_per_channel_weight_observer
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared, is_reference=True)
    ref_res = quantized_rn18(data)
    print("quantized model:", quantized_rn18)

    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[assignment]
    interp = TRTInterpreter(
        quantized_rn18,
        [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
                         shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], has_batch_dim=True)],
        explicit_batch_dimension=True, explicit_precision=True, logger_level=trt.Logger.VERBOSE)
    interpreter_result = interp.run(fp16_mode=False, int8_mode=True)
    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
    trt_res = trt_mod(data.cuda())
    print("explicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
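
A usage sketch for build_int8_trt, assuming torchvision and a CUDA device are available (the pretrained flag and input shape are illustrative):

import torchvision

rn18 = torchvision.models.resnet18(pretrained=True).eval()
trt_rn18 = build_int8_trt(rn18)
out = trt_rn18(torch.randn(1, 3, 224, 224).cuda())
print(out.shape)  # torch.Size([1, 1000])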
Example #8
    def test_unsupported_qconfig(self):
        """ Check that we won't quantize the model if the qconfig is not supported
        """
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        trt_unsupported_qconfig = default_qconfig
        prepared = prepare_fx(m, {"": trt_unsupported_qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        # calibration
        prepared(linear_module_input)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 0,
            ns.call_method("dequantize"): 0,
            ns.call_module(torch.nn.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Linear): 0,
        }
        # check model is not quantized
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
Example #9
    def test_input_weight_equalization_equalization_scales(self):
        """ After applying the equalization functions, check if the equalization
        scales are the expected values
        """

        tests = [
            SingleLayerLinearModel, TwoLayerLinearModel,
            SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel
        ]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()
            exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())

            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                equalization_qconfig_dict=default_equalization_qconfig_dict)
            prepared(x)
            convert_ref = _convert_equalization_ref(prepared)
            convert_ref(x)

            counter = 0
            for node in convert_ref.graph.nodes:
                if 'equalization_scale' in node.name and node.op == 'get_attr':
                    self.assertEqual(
                        convert_ref.get_buffer(str(node.target)).reshape(-1),
                        exp_eq_scales[counter])
                    counter += 1
Example #10
    def test_conv(self):
        class Conv2d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv2d(*args)

            def forward(self, x):
                return self.conv(x)

        conv2d_input = torch.rand(1, 3, 224, 224)
        conv2d_module_args = (3, 3, 3)

        m = Conv2d(*conv2d_module_args).eval()
        qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
            weight=torch.quantization.default_weight_observer)
        prepared = prepare_fx(
            m, {"": qconfig},
            backend_config_dict=get_tensorrt_backend_config_dict())
        # calibration
        prepared(conv2d_input)
        quantized = convert_fx(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
        # lower to trt
        trt_mod = lower_to_trt(quantized, conv2d_input,
                               [((1, 3, 224, 224), (5, 3, 224, 224),
                                 (10, 3, 224, 224))])
        # make sure it runs
        trt_mod(conv2d_input.cuda())
    def test_embedding(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=10,
                                              embedding_dim=12)

            def forward(self, indices):
                return self.emb(indices)

        model = M().eval()
        indices = torch.randint(low=0, high=10, size=(20, ))

        quantized_node = ns.call_module(nnq.Embedding)
        configs = [
            (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
            (None, ns.call_module(nn.Embedding)),
            (default_qconfig, ns.call_module(nn.Embedding)),
        ]

        for qconfig, node in configs:
            qconfig_dict = {"": qconfig}
            m = prepare_fx(model, qconfig_dict)
            m = convert_fx(m)
            self._compare_script_and_mobile(m, input=indices)
Example #12
    def test_input_weight_equalization_activation_values(self):
        """ After applying the equalization functions check if the input
        observer's min/max values are as expected
        """

        tests = [
            SingleLayerLinearModel, TwoLayerLinearModel,
            SingleLayerFunctionalLinearModel
        ]

        x = torch.rand((5, 5))
        torch.manual_seed(0)
        for M in tests:
            m = M().eval()
            exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
            exp_weights, exp_bias = self.get_expected_weights_bias(
                m,
                x.detach().numpy(), exp_eq_scales)
            exp_inp_act_vals = self.get_expected_inp_act_vals(
                m, x, exp_eq_scales, exp_weights, exp_bias)
            exp_weight_act_vals = self.get_expected_weight_act_vals(
                exp_weights)

            example_inputs = (x, )
            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_ref = _convert_equalization_ref(prepared)
            convert_ref(x)

            modules = dict(convert_ref.named_modules(remove_duplicate=False))
            inp_counter = 0
            weight_counter = 0
            for node in convert_ref.graph.nodes:
                users = list(node.users)
                if node.op == 'call_module' and isinstance(
                        modules[str(node.target)], MinMaxObserver):
                    if len(users) == 1 and users[
                            0].target == torch.nn.functional.linear and users[
                                0].args[1] == node:
                        # Check min/max values of weight activation layers
                        exp_min_val, exp_max_val = exp_weight_act_vals[
                            weight_counter]
                        self.assertEqual(modules[str(node.target)].min_val,
                                         exp_min_val)
                        self.assertEqual(modules[str(node.target)].max_val,
                                         exp_max_val)
                        weight_counter += 1
                    else:
                        # Check min/max values of input activation layers
                        exp_min_val, exp_max_val = exp_inp_act_vals[
                            inp_counter]
                        self.assertEqual(modules[str(node.target)].min_val,
                                         exp_min_val)
                        self.assertEqual(modules[str(node.target)].max_val,
                                         exp_max_val)
                        inp_counter += 1
Example #13
def build_int8_trt_implicit_quant(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True),
        weight=torch.ao.quantization.default_per_channel_weight_observer)
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(traced_rn18,
                            InputTensorSpec.from_tensors([data]),
                            logger_level=trt.Logger.VERBOSE)
    engine, input_names, output_names = interp.run(
        fp16_mode=False, int8_mode=True, strict_type_constraints=True)
    trt_mod = TRTModule(engine, input_names, output_names)
    trt_res = trt_mod(data.cuda())
    print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
Example #14
    def _test_module(self,
                     m,
                     inputs,
                     shape_ranges,
                     no_prepare=None,
                     no_convert=None):
        """
        Args:
          m: the float module we want to test
          inputs: list of inputs for the module
          shape_ranges: a list of shape_range, where every shape_range is a tuple of
          three tuples
          ((min_input_shape), (optimized_input_shape), (max_input_shape)).
          Each shape_range is used to populate a TensorRT optimization profile.
          e.g. If the input shape varies from (1, 224) to (100, 224) and we want to optimize
          for (25, 224) because it's the most common input shape, then we set shape_ranges to
          ((1, 224), (25, 224), (100, 224))
          no_prepare: node occurrence after prepare
          no_convert: node occurrence after convert
        """
        m = m.eval()
        prepared = prepare_fx(m, {"": self.qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        self.checkGraphModuleNodes(prepared,
                                   expected_node_occurrence=no_prepare)
        # calibration
        prepared(*inputs)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=no_convert)
        # lower to trt
        trt_mod = lower_to_trt(quantized, inputs, shape_ranges)
        inputs_cuda = [i.cuda() for i in inputs]
        # make sure it runs
        trt_mod(*inputs_cuda)
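
A usage sketch (hypothetical test name and shapes) showing how _test_module could be invoked, mirroring test_linear above; node occurrence checks are skipped by leaving no_prepare/no_convert as None:

    def test_linear_via_helper(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_input = torch.rand(8, 5)
        # one optimization profile: (min, optimized, max) input shapes
        shape_ranges = [((1, 5), (5, 5), (10, 5))]
        self._test_module(LinearModule(), [linear_input], shape_ranges)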
Example #15
    def test_input_weight_equalization_prepare(self):
        """ Tests that graphs created after prepare_fx is as expected
        """

        single_nn_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 2,
        }

        two_nn_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 3,
        }

        single_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 1,
            ns.call_module(_WeightEqualizationObserver): 1,
            ns.call_module(MinMaxObserver): 3,
        }

        two_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 5,
        }

        fp_F_layer_node_occurrence = {
            ns.call_module(_InputEqualizationObserver): 2,
            ns.call_module(_WeightEqualizationObserver): 2,
            ns.call_module(MinMaxObserver): 6,
        }

        tests = [(SingleLayerLinearModel, single_nn_layer_node_occurrence),
                 (TwoLayerLinearModel, two_nn_layer_node_occurrence),
                 (TwoLayerFunctionalLinearModel, two_F_layer_node_occurrence),
                 (FunctionalLinearAddModel, fp_F_layer_node_occurrence),
                 (LinearReluModel, single_nn_layer_node_occurrence),
                 (LinearReluLinearModel, two_nn_layer_node_occurrence),
                 (FunctionalLinearReluModel, single_F_layer_node_occurrence),
                 (FunctionalLinearReluLinearModel,
                  two_F_layer_node_occurrence),
                 (ConvModel, single_nn_layer_node_occurrence),
                 (TwoLayerConvModel, two_nn_layer_node_occurrence),
                 (TwoLayerFunctionalConvModel, two_F_layer_node_occurrence),
                 (ConvReluModel, single_nn_layer_node_occurrence),
                 (ConvReluConvModel, two_nn_layer_node_occurrence),
                 (FunctionalConvReluModel, single_F_layer_node_occurrence),
                 (FunctionalConvReluConvModel, two_F_layer_node_occurrence)]

        for (M, node_occurrence) in tests:
            m = M().eval()
            example_inputs = m.get_example_inputs()
            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            self.checkGraphModuleNodes(
                prepared, expected_node_occurrence=node_occurrence)
Example #16
    def test_q_prep_fx_before_s_prep(self):
        r"""
        This test checks that the ordering prepare_fx -> sparse prepare -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask between sparse prepare and convert_fx. This also tests the
        automatic fusion that occurs during prepare_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)


        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # it's broken by the automatic fusion in fx,
        # but will still work if you put the correct fqn in
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
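
The sparsity checks in this example and in test_s_prep_q_prep_fx_ref below use a _calculate_sparsity helper from the surrounding test file. A minimal sketch, assuming it returns the fraction of zero elements in a tensor (the original implementation may differ):

def _calculate_sparsity(tensor):
    # fraction of elements that the sparsifier has zeroed out
    return (tensor == 0).float().mean().item()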
Example #17
    def test_input_weight_equalization_convert(self):
        """ Tests that the modified model for equalization (before quantization)
        returns the same output as the original model
        """

        tests = [(SingleLayerLinearModel, 2), (LinearAddModel, 2),
                 (TwoLayerLinearModel, 2),
                 (SingleLayerFunctionalLinearModel, 2),
                 (FunctionalLinearAddModel, 2),
                 (TwoLayerFunctionalLinearModel, 2), (LinearReluModel, 2),
                 (LinearReluLinearModel, 2), (LinearReluAddModel, 2),
                 (FunctionalLinearReluModel, 2),
                 (FunctionalLinearReluLinearModel, 2), (ConvModel, 4),
                 (TwoLayerConvModel, 4), (SingleLayerFunctionalConvModel, 4),
                 (TwoLayerFunctionalConvModel, 4), (ConvReluModel, 4),
                 (ConvReluConvModel, 4), (ConvReluAddModel, 4),
                 (FunctionalConvReluModel, 4),
                 (FunctionalConvReluConvModel, 4)]

        for (M, ndim) in tests:
            m = M().eval()

            if ndim == 2:
                x = torch.rand((5, 5))
            elif ndim == 4:
                x = torch.rand((16, 3, 224, 224))

            example_inputs = (x, )
            prepared = prepare_fx(
                copy.deepcopy(m),
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            output = prepared(x)

            convert_ref = _convert_equalization_ref(prepared)
            convert_ref_output = convert_ref(x)

            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_fx(prepared)  # check that conversion succeeds
            self.assertEqual(output, convert_ref_output)
Example #18
def _do_quant_transforms(
    m: torch.nn.Module,
    input_tensor: torch.Tensor,
) -> torch.nn.Module:
    # do the quantization transforms and save the result
    qconfig = torch.quantization.get_default_qconfig('fbgemm')
    mp = quantize_fx.prepare_fx(m, {'': qconfig})
    mp(input_tensor)
    mq = quantize_fx.convert_fx(mp)
    return mq
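
A usage sketch for _do_quant_transforms with an illustrative small model (assumes quantize_fx was imported from torch.quantization, as in the function body):

m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
mq = _do_quant_transforms(m, torch.randn(1, 3, 16, 16))
print(mq(torch.randn(1, 3, 16, 16)).shape)  # torch.Size([1, 8, 14, 14])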
Example #19
    def test_conv_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x, y):
                return self.conv(x) + y

        weighted_op_qint8_dtype_config = {
            # optional, input activation dtype
            "input_dtype": torch.qint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.qint8
        }

        conv_add_config = {
            "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode),
            "observation_type":
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_qint8_dtype_config,
            ],
            "root_module":
            torch.nn.Conv2d,
            "reference_quantized_module_for_root":
            torch.nn.quantized._reference.Conv2d,
        }

        m = M().eval()
        modified_backend_config_dict = copy.deepcopy(
            self.trt_backend_config_dict)
        modified_backend_config_dict["configs"].insert(0, conv_add_config)
        m = prepare_fx(m, {"": self.qconfig},
                       backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        m = _convert_fx_do_not_use(
            m,
            is_reference=True,
            backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
Example #20
def default_prepare_for_quant(cfg, model):
    """
    Default implementation of preparing a model for quantization. This function will
    be called before training if QAT is enabled, or before calibration during PTQ if
    the model is not already quantized.

    NOTE:
        - This is the simplest implementation; most meta-archs need their own version.
        - For eager mode models, the user should make sure the returned model has
            Quant/DeQuant stubs inserted. This can be done by wrapping the model or
            defining the model with quant stubs.
        - QAT/PTQ can be determined by model.training.
        - Currently the input model can be changed in place since we won't re-use the
            input model.
        - Currently this API doesn't include the final torch.ao.quantization.prepare(_qat)
            call since existing use cases don't have further steps after it.

    Args:
        model (nn.Module): a non-quantized model.
        cfg (CfgNode): config

    Return:
        nn.Module: a ready model for QAT training or PTQ calibration
    """
    qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)

    if cfg.QUANTIZATION.EAGER_MODE:
        model = fuse_utils.fuse_model(
            model,
            is_qat=cfg.QUANTIZATION.QAT.ENABLED,
            inplace=True,
        )

        model.qconfig = qconfig
        # TODO(future diff): move the torch.ao.quantization.prepare(...) call
        # here, to be consistent with the FX branch
    else:  # FX graph mode quantization
        qconfig_dict = {"": qconfig}
        # TODO[quant-example-inputs]: needs follow up to change the api
        example_inputs = (torch.rand(1, 3, 3, 3), )
        if model.training:
            model = prepare_qat_fx(model, qconfig_dict, example_inputs)
        else:
            model = prepare_fx(model, qconfig_dict, example_inputs)

    logger.info("Setup the model with qconfig:\n{}".format(qconfig))

    return model
Example #21
    def test_s_prep_q_prep_fx_ref(self):
        r"""
        This checks that the ordering sparse prepare -> prepare_fx -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_to_reference_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.nn.quantized._reference.Linear))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
    def test_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        qconfig_dict = {"": default_qconfig, "module_name": [("conv1", None)]}
        m = prepare_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        m = convert_fx(m)
        # first conv is quantized, second conv is not quantized
        self._compare_script_and_mobile(m, input=data)
Example #23
    def test_quantize_model(self, config):
        """
        Test that the model builds from a config using either model_params or
        model_name, and exercises the FX graph mode quantization APIs
        """
        if get_torch_version() < [1, 8]:
            self.skipTest(
                "FX Graph Modee Quantization is only availablee from 1.8")
        if get_torch_version() >= [1, 11]:
            import torch.ao.quantization as tq
            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
        else:
            import torch.quantization as tq
            from torch.quantization.quantize_fx import convert_fx, prepare_fx

        model = build_model(config)
        assert isinstance(model, RegNet)
        model.eval()
        model.stem = prepare_fx(model.stem, {"": tq.default_qconfig})
        model.stem = convert_fx(model.stem)
Example #24
    def test_cat(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.cat([x, x], 1)

        m = M().eval()
        prepared = prepare_fx(m, {"": self.qconfig},
                              backend_config_dict=self.trt_backend_config_dict)
        self.assertTrue(len(dict(prepared.named_children())) == 1)
        quantized = _convert_fx_do_not_use(prepared, is_reference=True)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_function(torch.cat): 1,
            ns.call_method("dequantize"): 2,
        }
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=node_occurrence)
    def test_submodule(self):
        # test quantizing complete module, submodule and linear layer
        configs = [
            {},
            {
                "module_name": [("subm", None)]
            },
            {
                "module_name": [("fc", None)]
            },
        ]
        for config in configs:
            model = LinearModelWithSubmodule().eval()
            qconfig_dict = {
                "": torch.ao.quantization.get_default_qconfig("qnnpack"),
                **config,
            }
            model = prepare_fx(model, qconfig_dict)
            quant = convert_fx(model)

            x = torch.randn(5, 5)
            self._compare_script_and_mobile(quant, input=x)
Example #26
    def test_input_weight_equalization_weights_bias(self):
        """ After applying the equalization functions check if the weights and
        biases are as expected
        """

        tests = [
            SingleLayerLinearModel, TwoLayerLinearModel,
            SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel
        ]

        x = torch.rand((5, 5))
        for M in tests:
            m = M().eval()
            exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
            exp_weights, exp_bias = self.get_expected_weights_bias(
                m,
                x.detach().numpy(), exp_eq_scales)

            example_inputs = (x, )
            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                example_inputs=example_inputs,
                _equalization_config=default_equalization_qconfig_dict)
            prepared(x)
            convert_ref = _convert_equalization_ref(prepared)
            convert_ref(x)

            modules = dict(convert_ref.named_modules(remove_duplicate=False))
            counter = 0
            for node in convert_ref.graph.nodes:
                if node.op == 'call_module' and isinstance(
                        modules[str(node.target)], nn.Linear):
                    self.assertEqual(modules[str(node.target)].weight,
                                     exp_weights[counter])
                    self.assertEqual(modules[str(node.target)].bias,
                                     exp_bias[counter])
                    counter += 1
Example #27
    def test_nested_detection_case(self):
        class SingleLinear(torch.nn.Module):
            def __init__(self):
                super(SingleLinear, self).__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                x = self.linear(x)
                return x

        class TwoBlockNet(torch.nn.Module):
            def __init__(self):
                super(TwoBlockNet, self).__init__()
                self.block1 = SingleLinear()
                self.block2 = SingleLinear()

            def forward(self, x):
                x = self.block1(x)
                y = self.block2(x)
                z = x + y
                z = F.relu(z)
                return z

        # create model, example input, and qconfig mapping
        torch.backends.quantized.engine = "fbgemm"
        model = TwoBlockNet()
        example_input = torch.randint(-10, 0, (1, 3, 3, 3))
        example_input = example_input.to(torch.float)
        q_config_mapping = QConfigMapping()
        q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))

        # prep model and select observer
        model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
        obs_ctr = ModelReportObserver

        # find layer to attach to and store
        linear_fqn = "block2.linear"  # fqn of target linear

        target_linear = None
        for node in model_prep.graph.nodes:
            if node.target == linear_fqn:
                target_linear = node
                break

        # insert into both module and graph pre and post

        # set up to insert before target_linear (pre_observer)
        with model_prep.graph.inserting_before(target_linear):
            obs_to_insert = obs_ctr()
            pre_obs_fqn = linear_fqn + ".model_report_pre_observer"
            model_prep.add_submodule(pre_obs_fqn, obs_to_insert)
            model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args)

        # set up and insert after the target_linear (post_observer)
        with model_prep.graph.inserting_after(target_linear):
            obs_to_insert = obs_ctr()
            post_obs_fqn = linear_fqn + ".model_report_post_observer"
            model_prep.add_submodule(post_obs_fqn, obs_to_insert)
            model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,))

        # need to recompile module after submodule added and pass input through
        model_prep.recompile()

        num_iterations = 10
        for i in range(num_iterations):
            if i % 2 == 0:
                example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float)
            else:
                example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float)
            model_prep(example_input)

        # run it through the dynamic vs static detector
        dynam_vs_stat_str, dynam_vs_stat_dict = _detect_dynamic_vs_static(model_prep, tolerance=0.5)

        # one of the stats should be stationary, and the other non-stationary
        # as a result, dynamic should be recommended
        data_dist_info = [
            dynam_vs_stat_dict[linear_fqn]["pre_observer_data_dist"],
            dynam_vs_stat_dict[linear_fqn]["post_observer_data_dist"],
        ]

        self.assertTrue("stationary" in data_dist_info)
        self.assertTrue("non-stationary" in data_dist_info)
        self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"])
Example #28
    def test_selective_equalization(self):
        """ Tests that we are able to run numeric suite on the equalized model
        and construct a valid equalization_qconfig_dict equalizing only the top
        4 layers with the highest quantization errors.
        """

        torch.manual_seed(1)

        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
                self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

            def forward(self, x):
                x = self.bot(x)
                x = torch.add(x, 5)
                x = self.top(x)
                return x

        float_model = M().eval()
        # Hard coded so that the top layer has a higher quantization error
        x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                          [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                          [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                          [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                          [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

        # Quantize the float model
        prepared_model = prepare_fx(copy.deepcopy(float_model),
                                    specific_qconfig_dict)
        prepared_model(x)
        quantized_model = convert_fx(copy.deepcopy(prepared_model))

        # Get the SQNR between the float and quantized model
        layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model),
                                                 quantized_model, x)

        # Construct the equalization_qconfig_dict equalizing layers with the highest
        # quantization errors
        selective_equalization_qconfig_dict = get_equalization_qconfig_dict(
            layer_to_sqnr_dict, 1)

        # Create the selectively equalized model
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            equalization_qconfig_dict=selective_equalization_qconfig_dict,
        )
        prepared_model(x)
        equalized_model = convert_fx(prepared_model)

        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_model,
                                   expected_node_list=node_list)
Example #29
    def test_input_weight_equalization_graphs(self):
        """ Tests that the modified model for equalization has the same graph
        structure as the model without equalization (before and after
        quantization).
        """

        linear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        linearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_method('dequantize')
        ]

        linearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_method('dequantize')
        ]

        functionalLinearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        conv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        conv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        convRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_method('dequantize')
        ]

        convReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConvRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_method('dequantize')
        ]

        functionalConvReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        tests = [
            (SingleLayerLinearModel, linear_node_list),
            (LinearAddModel, linearAdd_node_list),
            (TwoLayerLinearModel, linear2_node_list),
            (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
            (FunctionalLinearAddModel, functionalLinearAdd_node_list),
            (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
            (LinearReluModel, linearRelu_node_list),
            (LinearReluLinearModel, linearReluLinear_node_list),
            (FunctionalLinearReluModel, functionalLinearRelu_node_list),
            (FunctionalLinearReluLinearModel,
             functionalLinearReluLinear_node_list),
            (ConvModel, conv_node_list), (TwoLayerConvModel, conv2_node_list),
            (SingleLayerFunctionalConvModel, functionalConv_node_list),
            (TwoLayerFunctionalConvModel, functionalConv2_node_list),
            (ConvReluModel, convRelu_node_list),
            (ConvReluConvModel, convReluConv_node_list),
            (FunctionalConvReluModel, functionalConvRelu_node_list),
            (FunctionalConvReluConvModel, functionalConvReluConv_node_list)
        ]

        for (M, node_list) in tests:
            m = M().eval()
            prepared = prepare_fx(
                m,
                specific_qconfig_dict,
                equalization_qconfig_dict=default_equalization_qconfig_dict)
            equalized_quantized_model = convert_fx(prepared)

            # Check the order of nodes in the graph
            self.checkGraphModuleNodes(equalized_quantized_model,
                                       expected_node_list=node_list)
            def prepare(self, model, configs, attrs):
                example_inputs = (torch.randn(1, 2), )
                model.another_layer = prepare_fx(model.another_layer,
                                                 configs[""], example_inputs)

                return model