def test_input_weight_equalization_results(self):
    """ Tests that for small models, the results of quantized models that
    have been equalized are very close to models that have not been equalized.
    """
    tests = [SingleLayerLinearModel, TwoLayerLinearModel, LinearAddModel,
             SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel]

    x = torch.rand((5, 5))
    for M in tests:
        m = M().eval()

        # No equalization
        prepared = prepare_fx(copy.deepcopy(m), specific_qconfig_dict, equalization_qconfig_dict={})
        prepared(x)
        quantized = convert_fx(prepared)  # Check if compile
        quantized_output = quantized(x)

        # With equalization
        prepared = prepare_fx(
            copy.deepcopy(m),
            specific_qconfig_dict,
            equalization_qconfig_dict=default_equalization_qconfig_dict)
        prepared(x)
        equalized_and_quantized = convert_fx(prepared)  # Check if compile
        equalized_and_quantized_output = equalized_and_quantized(x)
        self.assertEqual(quantized_output, equalized_and_quantized_output, rtol=1e-5, atol=0.1)
def test_input_weight_equalization_branching(self):
    """ Tests that graphs containing branches are prepared correctly.
    Specifically, equalization observers should not be inserted in front of
    branches in which both initial layers in the branches plan to be
    quantized.
    """

    # Tests that we do not add an equalization observer due to both initial
    # nodes in the branch containing layers that need to be equalized.
    # Note that this should print out 2 warning messages for not being able
    # to equalize layers linear1 and linear2 because they are part of a branch
    class TestBranchingWithoutEqualizationModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = nn.Linear(5, 5)
            self.linear2 = nn.Linear(5, 5)

        def forward(self, x):
            y = self.linear1(x)
            z = self.linear2(x)
            return torch.add(y, z)

    no_eq_branching_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 0,
        ns.call_module(MinMaxObserver): 3,
    }

    m = TestBranchingWithoutEqualizationModel().eval()
    prepared = prepare_fx(
        m, specific_qconfig_dict,
        equalization_qconfig_dict=default_equalization_qconfig_dict)
    self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)

    # Tests that we will add an equalization observer because there is only
    # one initial node in the branch that needs to be equalized
    class TestBranchingWithEqualizationModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = nn.Linear(5, 5)

        def forward(self, x):
            y = self.linear1(x)
            z = torch.add(x, 5)
            return torch.add(y, z)

    eq_branching_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 1,
        ns.call_module(MinMaxObserver): 2,
    }

    m = TestBranchingWithEqualizationModel().eval()
    prepared = prepare_fx(
        m, specific_qconfig_dict,
        equalization_qconfig_dict=default_equalization_qconfig_dict)
    self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)
def _test_quantize_model(self, model_config):
    if get_torch_version() >= [1, 11]:
        import torch.ao.quantization as tq
        from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
    else:
        import torch.quantization as tq
        from torch.quantization.quantize_fx import convert_fx, prepare_fx

    # quantize model
    model = build_model(model_config)
    model.eval()
    input = torch.ones([1, 3, 32, 32])

    heads = model.get_heads()
    # since prepare_fx changes the code of ClassyBlock, we need to clear the
    # heads first and reattach them later to avoid caching
    model.clear_heads()

    prepare_custom_config_dict = {}
    head_path_from_blocks = [
        _find_block_full_path(model.features, block_name)
        for block_name in heads.keys()
    ]
    # we need to keep the modules used in the heads standalone since they
    # will be accessed by path name directly during execution
    prepare_custom_config_dict["standalone_module_name"] = [
        (
            head,
            {"": tq.default_qconfig},
            {"input_quantized_idxs": [0], "output_quantized_idxs": []},
            None,
        )
        for head in head_path_from_blocks
    ]
    model.initial_block = prepare_fx(model.initial_block, {"": tq.default_qconfig})
    model.features = prepare_fx(
        model.features,
        {"": tq.default_qconfig},
        prepare_custom_config_dict,
    )
    model.set_heads(heads)

    # calibration
    model(input)

    heads = model.get_heads()
    model.clear_heads()
    model.initial_block = convert_fx(model.initial_block)
    model.features = convert_fx(model.features)
    model.set_heads(heads)

    output = model(input)
    self.assertEqual(output.size(), (1, 1000))
def test_linear(self):
    class LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_module_input = torch.rand(8, 5)

    m = LinearModule().eval()
    qconfig = torch.quantization.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
        weight=torch.quantization.default_weight_observer)
    prepared = prepare_fx(
        m, {"": qconfig},
        backend_config_dict=get_tensorrt_backend_config_dict())
    # calibration
    prepared(linear_module_input)
    quantized = convert_fx(prepared, is_reference=True)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method("dequantize"): 1,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
    # lower to trt
    trt_mod = lower_to_trt(
        quantized, linear_module_input,
        [((1, *linear_module_input.shape[1:]),
          (5, *linear_module_input.shape[1:]),
          (10, *linear_module_input.shape[1:]))])
    # make sure it runs
    trt_mod(linear_module_input.cuda())
def test_addmm(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(5, 5)
            self.bias = torch.randn(5)

        def forward(self, x):
            return torch.addmm(self.bias, x, self.weight)

    m = M().eval()
    prepared = prepare_fx(
        m, {"": self.qconfig},
        backend_config_dict=self.trt_backend_config_dict)
    node_occurrence = {
        # weight
        ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
        # activation
        ns.call_module(torch.ao.quantization.HistogramObserver): 2,
    }
    self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
    quantized = _convert_fx_do_not_use(
        prepared, is_reference=True,
        backend_config_dict=self.trt_backend_config_dict)
    node_occurrence = {
        # input activation, output activation and weight
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_function(torch.addmm): 1,
        ns.call_method("dequantize"): 3,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
def test_ops(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.linear = torch.nn.Linear(5, 5)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv(x)
            x = self.linear(x)
            x = x + 3
            x = self.relu(x)
            x = x + 6
            return x

    m = M().eval()
    m = prepare_fx(m, {"": default_qconfig})
    m = _convert_fx_do_not_use(m, is_reference=True)
    expected_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 5,
        ns.call_method("dequantize"): 5,
        ns.call_module(torch.nn.quantized._reference.Linear): 1,
        ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)
def build_int8_trt(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # data = torch.randn(1, 32)
    # data = torch.randn(1, 64, 10, 10)
    # TensorRT only supports symmetric quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        # weight=torch.ao.quantization.default_weight_observer
        # uncomment to check per channel quant works
        weight=torch.quantization.default_per_channel_weight_observer
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared, is_reference=True)
    ref_res = quantized_rn18(data)
    print("quantized model:", quantized_rn18)

    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[assignment]
    interp = TRTInterpreter(
        quantized_rn18,
        [InputTensorSpec(
            torch.Size([-1, *data.shape[1:]]), torch.float,
            shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))],
            has_batch_dim=True)],
        explicit_batch_dimension=True,
        explicit_precision=True,
        logger_level=trt.Logger.VERBOSE)
    interpreter_result = interp.run(fp16_mode=False, int8_mode=True)
    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
    trt_res = trt_mod(data.cuda())
    print("explicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
def test_unsupported_qconfig(self):
    """ Check that we won't quantize the model if the qconfig is not supported
    """
    class LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_module_input = torch.rand(8, 5)

    m = LinearModule().eval()
    trt_unsupported_qconfig = default_qconfig
    prepared = prepare_fx(
        m, {"": trt_unsupported_qconfig},
        backend_config_dict=self.trt_backend_config_dict)
    # calibration
    prepared(linear_module_input)
    quantized = _convert_fx_do_not_use(prepared, is_reference=True)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 0,
        ns.call_method("dequantize"): 0,
        ns.call_module(torch.nn.Linear): 1,
        ns.call_module(torch.nn.quantized._reference.Linear): 0,
    }
    # check model is not quantized
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
def test_input_weight_equalization_equalization_scales(self):
    """ After applying the equalization functions, check if the equalization
    scales are the expected values
    """
    tests = [SingleLayerLinearModel, TwoLayerLinearModel,
             SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel]

    x = torch.rand((5, 5))
    for M in tests:
        m = M().eval()
        exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())

        prepared = prepare_fx(
            m, specific_qconfig_dict,
            equalization_qconfig_dict=default_equalization_qconfig_dict)
        prepared(x)
        convert_ref = _convert_equalization_ref(prepared)
        convert_ref(x)

        counter = 0
        for node in convert_ref.graph.nodes:
            if 'equalization_scale' in node.name and node.op == 'get_attr':
                self.assertEqual(
                    convert_ref.get_buffer(str(node.target)).reshape(-1),
                    exp_eq_scales[counter])
                counter += 1
def test_conv(self):
    class Conv2d(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            self.conv = torch.nn.Conv2d(*args)

        def forward(self, x):
            return self.conv(x)

    conv2d_input = torch.rand(1, 3, 224, 224)
    conv2d_module_args = (3, 3, 3)

    m = Conv2d(*conv2d_module_args).eval()
    qconfig = torch.quantization.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
        weight=torch.quantization.default_weight_observer)
    prepared = prepare_fx(
        m, {"": qconfig},
        backend_config_dict=get_tensorrt_backend_config_dict())
    # calibration
    prepared(conv2d_input)
    quantized = convert_fx(prepared, is_reference=True)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method("dequantize"): 1,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
    # lower to trt
    trt_mod = lower_to_trt(
        quantized, conv2d_input,
        [((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))])
    # make sure it runs
    trt_mod(conv2d_input.cuda())
def test_embedding(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    model = M().eval()
    indices = torch.randint(low=0, high=10, size=(20,))

    quantized_node = ns.call_module(nnq.Embedding)
    configs = [
        (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
        (None, ns.call_module(nn.Embedding)),
        (default_qconfig, ns.call_module(nn.Embedding)),
    ]

    for qconfig, node in configs:
        qconfig_dict = {"": qconfig}
        m = prepare_fx(model, qconfig_dict)
        m = convert_fx(m)
        self._compare_script_and_mobile(m, input=indices)
def test_input_weight_equalization_activation_values(self):
    """ After applying the equalization functions, check if the input
    observers' min/max values are as expected
    """
    tests = [SingleLayerLinearModel, TwoLayerLinearModel, SingleLayerFunctionalLinearModel]

    x = torch.rand((5, 5))
    torch.manual_seed(0)
    for M in tests:
        m = M().eval()
        exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
        exp_weights, exp_bias = self.get_expected_weights_bias(
            m, x.detach().numpy(), exp_eq_scales)
        exp_inp_act_vals = self.get_expected_inp_act_vals(
            m, x, exp_eq_scales, exp_weights, exp_bias)
        exp_weight_act_vals = self.get_expected_weight_act_vals(exp_weights)

        example_inputs = (x,)
        prepared = prepare_fx(
            m, specific_qconfig_dict,
            example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        prepared(x)
        convert_ref = _convert_equalization_ref(prepared)
        convert_ref(x)

        modules = dict(convert_ref.named_modules(remove_duplicate=False))
        inp_counter = 0
        weight_counter = 0
        for node in convert_ref.graph.nodes:
            users = list(node.users)
            if node.op == 'call_module' and isinstance(modules[str(node.target)], MinMaxObserver):
                if len(users) == 1 and users[0].target == torch.nn.functional.linear and users[0].args[1] == node:
                    # Check min/max values of weight activation layers
                    exp_min_val, exp_max_val = exp_weight_act_vals[weight_counter]
                    self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
                    self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
                    weight_counter += 1
                else:
                    # Check min/max values of input activation layers
                    exp_min_val, exp_max_val = exp_inp_act_vals[inp_counter]
                    self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
                    self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
                    inp_counter += 1
def build_int8_trt_implicit_quant(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True),
        weight=torch.ao.quantization.default_per_channel_weight_observer)
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]), logger_level=trt.Logger.VERBOSE)
    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
    trt_mod = TRTModule(engine, input_names, output_names)
    trt_res = trt_mod(data.cuda())
    print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
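# A minimal usage sketch for the two builders above (assumptions: torchvision
# is installed and a CUDA + TensorRT environment is available; the function
# name and variables below are illustrative, not part of the original file).
# It builds both the explicit- and implicit-quantization TRT modules from the
# same ResNet-18 and compares their outputs on one batch.
def _compare_explicit_implicit_trt():
    import torchvision.models as models

    rn18 = models.resnet18(pretrained=True).eval()
    data = torch.randn(1, 3, 224, 224).cuda()

    explicit_trt = build_int8_trt(rn18)  # explicit quantize/dequantize nodes in the graph
    implicit_trt = build_int8_trt_implicit_quant(rn18)  # TensorRT calibrates int8 internally
    print("explicit vs implicit diff max",
          torch.max(torch.abs(explicit_trt(data) - implicit_trt(data))))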
def _test_module(
        self,
        m,
        inputs,
        shape_ranges,
        no_prepare=None,
        no_convert=None):
    """
    Args:
      m: the float module we want to test
      inputs: list of inputs for the module
      shape_ranges: a list of shape_range, where every shape_range is a tuple
        of three tuples ((min_input_shape), (optimized_input_shape),
        (max_input_shape)). Each shape_range is used to populate a TensorRT
        optimization profile. e.g. If the input shape varies from (1, 224)
        to (100, 224) and we want to optimize for (25, 224) because it's the
        most common input shape, then we set shape_ranges to
        ((1, 224), (25, 224), (100, 224))
      no_prepare: node occurrence after prepare
      no_convert: node occurrence after convert
    """
    m = m.eval()
    prepared = prepare_fx(m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
    self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_prepare)
    # calibration
    prepared(*inputs)
    quantized = _convert_fx_do_not_use(prepared, is_reference=True)
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=no_convert)
    # lower to trt
    trt_mod = lower_to_trt(quantized, inputs, shape_ranges)
    inputs_cuda = [i.cuda() for i in inputs]
    # make sure it runs
    trt_mod(*inputs_cuda)
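# A minimal, hypothetical usage sketch of _test_module, mirroring the
# shape_ranges example in its docstring (the test name, model, and shapes are
# illustrative, not part of the original suite): the batch dimension varies
# from 1 to 100 and 25 is the most common size, so a single optimization
# profile of ((1, 224), (25, 224), (100, 224)) is passed.
def test_linear_shape_ranges_sketch(self):
    m = torch.nn.Linear(224, 224)
    inputs = [torch.randn(25, 224)]
    shape_ranges = [((1, 224), (25, 224), (100, 224))]
    self._test_module(m, inputs, shape_ranges)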
def test_input_weight_equalization_prepare(self):
    """ Tests that the graphs created after prepare_fx are as expected
    """

    single_nn_layer_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 1,
        ns.call_module(MinMaxObserver): 2,
    }

    two_nn_layer_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 2,
        ns.call_module(MinMaxObserver): 3,
    }

    single_F_layer_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 1,
        ns.call_module(_WeightEqualizationObserver): 1,
        ns.call_module(MinMaxObserver): 3,
    }

    two_F_layer_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 2,
        ns.call_module(_WeightEqualizationObserver): 2,
        ns.call_module(MinMaxObserver): 5,
    }

    fp_F_layer_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 2,
        ns.call_module(_WeightEqualizationObserver): 2,
        ns.call_module(MinMaxObserver): 6,
    }

    tests = [
        (SingleLayerLinearModel, single_nn_layer_node_occurrence),
        (TwoLayerLinearModel, two_nn_layer_node_occurrence),
        (TwoLayerFunctionalLinearModel, two_F_layer_node_occurrence),
        (FunctionalLinearAddModel, fp_F_layer_node_occurrence),
        (LinearReluModel, single_nn_layer_node_occurrence),
        (LinearReluLinearModel, two_nn_layer_node_occurrence),
        (FunctionalLinearReluModel, single_F_layer_node_occurrence),
        (FunctionalLinearReluLinearModel, two_F_layer_node_occurrence),
        (ConvModel, single_nn_layer_node_occurrence),
        (TwoLayerConvModel, two_nn_layer_node_occurrence),
        (TwoLayerFunctionalConvModel, two_F_layer_node_occurrence),
        (ConvReluModel, single_nn_layer_node_occurrence),
        (ConvReluConvModel, two_nn_layer_node_occurrence),
        (FunctionalConvReluModel, single_F_layer_node_occurrence),
        (FunctionalConvReluConvModel, two_F_layer_node_occurrence),
    ]

    for (M, node_occurrence) in tests:
        m = M().eval()
        example_inputs = m.get_example_inputs()
        prepared = prepare_fx(
            m, specific_qconfig_dict,
            example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
def test_q_prep_fx_before_s_prep(self):
    r"""
    This test checks that the ordering of prepare_fx -> sparse prepare -> convert_fx
    composes cleanly without issue and that the final result is sparsified without
    having to call squash mask between sparse prepare and convert_fx. This also tests
    the automatic fusion that occurs during prepare_fx.
    """
    (
        mod,
        sparsifier,
        _,
    ) = _get_model_and_sparsifier_and_sparse_config()

    example = torch.randn(1, 4, 4, 4)
    qconfig = tq.get_default_qconfig("fbgemm")
    qconfig_mapping = tq.QConfigMapping() \
        .set_module_name("4", qconfig) \
        .set_module_name("5", qconfig)

    mod = prepare_fx(mod, qconfig_mapping, (example,))

    # the pre-fusion fqn is broken by auto fusion in fx,
    # but this will still work if you put the correct (post-fusion) fqn in
    sparse_config = [
        {
            "tensor_fqn": "5.0.weight",
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        },
        {"tensor_fqn": "0.0.weight"},
    ]
    sparsifier.prepare(mod, config=sparse_config)

    # check that correct modules had parametrizations added and
    # that none were lost during prepare
    self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
    self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

    # check that correct observers were inserted and that matching
    # occurred successfully
    self.assertTrue(_module_has_activation_post_process(mod, "5"))
    sparsifier.step()
    sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
    mod(example)
    mod = convert_fx(mod)

    # check that final module is the expected quantized module and that the model runs
    self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU))
    self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

    # check that module was actually sparsified
    cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
    self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
    self.assertGreaterAlmostEqual(sparsity_level, sparse_config[0]["sparsity_level"])
    self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
def test_input_weight_equalization_convert(self):
    """ Tests that the modified model for equalization (before quantization)
    returns the same output as the original model
    """
    tests = [
        (SingleLayerLinearModel, 2), (LinearAddModel, 2), (TwoLayerLinearModel, 2),
        (SingleLayerFunctionalLinearModel, 2), (FunctionalLinearAddModel, 2),
        (TwoLayerFunctionalLinearModel, 2),
        (LinearReluModel, 2), (LinearReluLinearModel, 2), (LinearReluAddModel, 2),
        (FunctionalLinearReluModel, 2), (FunctionalLinearReluLinearModel, 2),
        (ConvModel, 4), (TwoLayerConvModel, 4), (SingleLayerFunctionalConvModel, 4),
        (TwoLayerFunctionalConvModel, 4),
        (ConvReluModel, 4), (ConvReluConvModel, 4), (ConvReluAddModel, 4),
        (FunctionalConvReluModel, 4), (FunctionalConvReluConvModel, 4),
    ]

    for (M, ndim) in tests:
        m = M().eval()

        if ndim == 2:
            x = torch.rand((5, 5))
        elif ndim == 4:
            x = torch.rand((16, 3, 224, 224))

        example_inputs = (x,)
        prepared = prepare_fx(
            copy.deepcopy(m),
            specific_qconfig_dict,
            example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        output = prepared(x)

        convert_ref = _convert_equalization_ref(prepared)
        convert_ref_output = convert_ref(x)

        prepared = prepare_fx(
            m, specific_qconfig_dict,
            example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        prepared(x)
        convert_fx(prepared)  # Check if compile
        self.assertEqual(output, convert_ref_output)
def _do_quant_transforms(
    m: torch.nn.Module,
    input_tensor: torch.Tensor,
) -> torch.nn.Module:
    # do the quantization transforms and save the result
    qconfig = torch.quantization.get_default_qconfig('fbgemm')
    mp = quantize_fx.prepare_fx(m, {'': qconfig})
    mp(input_tensor)
    mq = quantize_fx.convert_fx(mp)
    return mq
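# A minimal usage sketch of _do_quant_transforms (the wrapper function, module,
# and input below are hypothetical; assumes the 'fbgemm' backend is available):
def _do_quant_transforms_sketch():
    m = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
    x = torch.randn(2, 4)
    mq = _do_quant_transforms(m, x)  # prepare, calibrate on x, then convert
    print(mq(x).shape)  # torch.Size([2, 4])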
def test_conv_add(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            return self.conv(x) + y

    weighted_op_qint8_dtype_config = {
        # optional, input activation dtype
        "input_dtype": torch.qint8,
        # optional, weight dtype
        "weight_dtype": torch.qint8,
        # optional, bias dtype
        "bias_dtype": torch.float,
        # optional, output activation dtype
        "output_dtype": torch.qint8,
    }

    conv_add_config = {
        "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode),
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [
            weighted_op_qint8_dtype_config,
        ],
        "root_module": torch.nn.Conv2d,
        "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
    }

    m = M().eval()
    modified_backend_config_dict = copy.deepcopy(self.trt_backend_config_dict)
    modified_backend_config_dict["configs"].insert(0, conv_add_config)
    m = prepare_fx(m, {"": self.qconfig}, backend_config_dict=modified_backend_config_dict)
    node_occurrence = {
        ns.call_module(torch.ao.quantization.HistogramObserver): 3,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
    m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=modified_backend_config_dict)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_method("dequantize"): 3,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
def default_prepare_for_quant(cfg, model):
    """
    Default implementation of preparing a model for quantization. This function will
    be called before training if QAT is enabled, or before calibration during PTQ if
    the model is not already quantized.

    NOTE:
        - This is the simplest implementation; most meta-archs need their own version.
        - For an eager-mode model, the user should make sure the returned model has
          Quant/DeQuant inserted. This can be done by wrapping the model or defining
          the model with quant stubs.
        - QAT/PTQ can be determined by model.training.
        - Currently the input model can be changed in-place since we won't re-use the
          input model.
        - Currently this API doesn't include the final torch.ao.quantization.prepare(_qat)
          call since existing use cases don't have further steps after it.

    Args:
        cfg (CfgNode): config
        model (nn.Module): a non-quantized model.

    Return:
        nn.Module: a model ready for QAT training or PTQ calibration
    """
    qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)

    if cfg.QUANTIZATION.EAGER_MODE:
        model = fuse_utils.fuse_model(
            model,
            is_qat=cfg.QUANTIZATION.QAT.ENABLED,
            inplace=True,
        )

        model.qconfig = qconfig
        # TODO(future diff): move the torch.ao.quantization.prepare(...) call
        # here, to be consistent with the FX branch
    else:
        # FX graph mode quantization
        qconfig_dict = {"": qconfig}
        # TODO[quant-example-inputs]: needs follow up to change the api
        example_inputs = (torch.rand(1, 3, 3, 3),)
        if model.training:
            model = prepare_qat_fx(model, qconfig_dict, example_inputs)
        else:
            model = prepare_fx(model, qconfig_dict, example_inputs)

    logger.info("Setup the model with qconfig:\n{}".format(qconfig))

    return model
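# A minimal caller-side PTQ sketch for the eager-mode branch above, since the
# docstring notes the final prepare call is not included here (the function
# name and `calibration_loader` are hypothetical; the FX branch would instead
# finish with torch.ao.quantization.quantize_fx.convert_fx):
def _ptq_calibration_sketch(cfg, model, calibration_loader):
    model = default_prepare_for_quant(cfg, model)
    model = torch.ao.quantization.prepare(model, inplace=True)  # eager mode only
    with torch.no_grad():
        for batch in calibration_loader:  # run representative data through observers
            model(batch)
    return torch.ao.quantization.convert(model, inplace=True)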
def test_s_prep_q_prep_fx_ref(self):
    r"""
    This checks that the ordering: sparse prepare -> prepare_fx ->
    convert_to_reference_fx composes cleanly without issue and that the final result
    is sparsified without having to call squash mask before convert_to_reference_fx.
    """
    (
        mod,
        sparsifier,
        sparse_config,
    ) = _get_model_and_sparsifier_and_sparse_config()
    sparsifier.prepare(mod, config=sparse_config)

    example = torch.randn(1, 4, 4, 4)
    qconfig = tq.get_default_qconfig("fbgemm")
    qconfig_mapping = tq.QConfigMapping() \
        .set_module_name("4", qconfig) \
        .set_module_name("5", qconfig)
    mod = prepare_fx(mod, qconfig_mapping, (example,))

    # check that correct modules had parametrizations added and
    # that none were lost during prepare
    self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
    self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

    # check that correct observers were inserted and that matching
    # occurred successfully
    self.assertTrue(_module_has_activation_post_process(mod, "5"))
    sparsifier.step()
    sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
    mod(example)
    mod = convert_to_reference_fx(mod)

    # check that final module is the expected quantized module and that the model runs
    self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.LinearReLU))
    self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
    self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.nn.quantized._reference.Linear))

    # check that module was actually sparsified
    cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
    self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
    self.assertGreaterAlmostEqual(sparsity_level, sparse_config[0]["sparsity_level"])
    self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
def test_conv2d(self):
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv1 = nn.Conv2d(1, 1, 1)
            self.conv2 = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            return x

    m = M().eval()
    qconfig_dict = {"": default_qconfig, "module_name": [("conv1", None)]}
    m = prepare_fx(m, qconfig_dict)
    data = torch.randn(1, 1, 1, 1)
    m = convert_fx(m)
    # first conv is not quantized (its qconfig is set to None),
    # second conv is quantized under the global qconfig
    self._compare_script_and_mobile(m, input=data)
def test_quantize_model(self, config):
    """
    Test that the model builds using a config using either model_params or
    model_name and calls fx graph mode quantization apis
    """
    if get_torch_version() < [1, 8]:
        self.skipTest("FX Graph Mode Quantization is only available from 1.8")
    if get_torch_version() >= [1, 11]:
        import torch.ao.quantization as tq
        from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
    else:
        import torch.quantization as tq
        from torch.quantization.quantize_fx import convert_fx, prepare_fx
    model = build_model(config)
    assert isinstance(model, RegNet)
    model.eval()
    model.stem = prepare_fx(model.stem, {"": tq.default_qconfig})
    model.stem = convert_fx(model.stem)
def test_cat(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.cat([x, x], 1)

    m = M().eval()
    prepared = prepare_fx(
        m, {"": self.qconfig},
        backend_config_dict=self.trt_backend_config_dict)
    self.assertTrue(len(dict(prepared.named_children())) == 1)
    quantized = _convert_fx_do_not_use(prepared, is_reference=True)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 2,
        ns.call_function(torch.cat): 1,
        ns.call_method("dequantize"): 2,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
def test_submodule(self):
    # test quantizing the complete module, then skipping quantization of
    # either the submodule or the linear layer
    configs = [
        {},
        {"module_name": [("subm", None)]},
        {"module_name": [("fc", None)]},
    ]

    for config in configs:
        model = LinearModelWithSubmodule().eval()
        qconfig_dict = {
            "": torch.ao.quantization.get_default_qconfig("qnnpack"),
            **config,
        }
        model = prepare_fx(model, qconfig_dict)
        quant = convert_fx(model)

        x = torch.randn(5, 5)
        self._compare_script_and_mobile(quant, input=x)
def test_input_weight_equalization_weights_bias(self):
    """ After applying the equalization functions, check if the weights and
    biases are as expected
    """
    tests = [SingleLayerLinearModel, TwoLayerLinearModel,
             SingleLayerFunctionalLinearModel, TwoLayerFunctionalLinearModel]

    x = torch.rand((5, 5))
    for M in tests:
        m = M().eval()
        exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
        exp_weights, exp_bias = self.get_expected_weights_bias(
            m, x.detach().numpy(), exp_eq_scales)

        example_inputs = (x,)
        prepared = prepare_fx(
            m, specific_qconfig_dict,
            example_inputs=example_inputs,
            _equalization_config=default_equalization_qconfig_dict)
        prepared(x)
        convert_ref = _convert_equalization_ref(prepared)
        convert_ref(x)

        modules = dict(convert_ref.named_modules(remove_duplicate=False))
        counter = 0
        for node in convert_ref.graph.nodes:
            if node.op == 'call_module' and isinstance(modules[str(node.target)], nn.Linear):
                self.assertEqual(modules[str(node.target)].weight, exp_weights[counter])
                self.assertEqual(modules[str(node.target)].bias, exp_bias[counter])
                counter += 1
def test_nested_detection_case(self):
    class SingleLinear(torch.nn.Module):
        def __init__(self):
            super(SingleLinear, self).__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.linear(x)
            return x

    class TwoBlockNet(torch.nn.Module):
        def __init__(self):
            super(TwoBlockNet, self).__init__()
            self.block1 = SingleLinear()
            self.block2 = SingleLinear()

        def forward(self, x):
            x = self.block1(x)
            y = self.block2(x)
            z = x + y
            z = F.relu(z)
            return z

    # create model, example input, and qconfig mapping
    torch.backends.quantized.engine = "fbgemm"
    model = TwoBlockNet()
    example_input = torch.randint(-10, 0, (1, 3, 3, 3))
    example_input = example_input.to(torch.float)
    q_config_mapping = QConfigMapping()
    q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))

    # prep model and select observer
    model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
    obs_ctr = ModelReportObserver

    # find layer to attach to and store
    linear_fqn = "block2.linear"  # fqn of target linear

    target_linear = None
    for node in model_prep.graph.nodes:
        if node.target == linear_fqn:
            target_linear = node
            break

    # insert into both module and graph pre and post

    # set up to insert before target_linear (pre_observer)
    with model_prep.graph.inserting_before(target_linear):
        obs_to_insert = obs_ctr()
        pre_obs_fqn = linear_fqn + ".model_report_pre_observer"
        model_prep.add_submodule(pre_obs_fqn, obs_to_insert)
        model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args)

    # set up and insert after the target_linear (post_observer)
    with model_prep.graph.inserting_after(target_linear):
        obs_to_insert = obs_ctr()
        post_obs_fqn = linear_fqn + ".model_report_post_observer"
        model_prep.add_submodule(post_obs_fqn, obs_to_insert)
        model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,))

    # need to recompile module after submodule added and pass input through
    model_prep.recompile()

    num_iterations = 10
    for i in range(num_iterations):
        if i % 2 == 0:
            example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float)
        else:
            example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float)
        model_prep(example_input)

    # run it through the dynamic vs static detector
    dynam_vs_stat_str, dynam_vs_stat_dict = _detect_dynamic_vs_static(model_prep, tolerance=0.5)

    # one of the stats should be stationary and the other non-stationary;
    # as a result, dynamic quantization should be recommended
    data_dist_info = [
        dynam_vs_stat_dict[linear_fqn]["pre_observer_data_dist"],
        dynam_vs_stat_dict[linear_fqn]["post_observer_data_dist"],
    ]

    self.assertTrue("stationary" in data_dist_info)
    self.assertTrue("non-stationary" in data_dist_info)
    self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"])
def test_selective_equalization(self):
    """ Tests that we are able to run numeric suite on the equalized model
    and construct a valid equalization_qconfig_dict equalizing only the
    layers with the highest quantization errors.
    """

    torch.manual_seed(1)

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
            self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

        def forward(self, x):
            x = self.bot(x)
            x = torch.add(x, 5)
            x = self.top(x)
            return x

    float_model = M().eval()
    # Hard coded so that the top layer has a higher quantization error
    x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                      [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                      [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                      [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                      [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

    # Quantize the float model
    prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
    prepared_model(x)
    quantized_model = convert_fx(copy.deepcopy(prepared_model))

    # Get the SQNR between the float and quantized model
    layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model), quantized_model, x)

    # Construct the equalization_qconfig_dict equalizing layers with the highest
    # quantization errors
    selective_equalization_qconfig_dict = get_equalization_qconfig_dict(layer_to_sqnr_dict, 1)

    # Create the selectively equalized model
    prepared_model = prepare_fx(
        copy.deepcopy(float_model),
        specific_qconfig_dict,
        equalization_qconfig_dict=selective_equalization_qconfig_dict,
    )
    prepared_model(x)
    equalized_model = convert_fx(prepared_model)

    node_list = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
    ]

    # Check the order of nodes in the graph
    self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
def test_input_weight_equalization_graphs(self):
    """ Tests that the modified model for equalization has the same graph
    structure as the model without equalization (before and after
    quantization).
    """

    linear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
    ]

    linearAdd_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
    ]

    linear2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
    ]

    functionalLinear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize'),
    ]

    functionalLinearAdd_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize'),
    ]

    functionalLinear2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize'),
    ]

    linearRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.LinearReLU),
        ns.call_method('dequantize'),
    ]

    linearReluLinear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.LinearReLU),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
    ]

    functionalLinearRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear_relu),
        ns.call_method('dequantize'),
    ]

    functionalLinearReluLinear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear_relu),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize'),
    ]

    conv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]

    conv2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]

    functionalConv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_method('dequantize'),
    ]

    functionalConv2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_method('dequantize'),
    ]

    convRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.ConvReLU2d),
        ns.call_method('dequantize'),
    ]

    convReluConv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.ConvReLU2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]

    functionalConvRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d_relu),
        ns.call_method('dequantize'),
    ]

    functionalConvReluConv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d_relu),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_method('dequantize'),
    ]

    tests = [
        (SingleLayerLinearModel, linear_node_list),
        (LinearAddModel, linearAdd_node_list),
        (TwoLayerLinearModel, linear2_node_list),
        (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
        (FunctionalLinearAddModel, functionalLinearAdd_node_list),
        (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
        (LinearReluModel, linearRelu_node_list),
        (LinearReluLinearModel, linearReluLinear_node_list),
        (FunctionalLinearReluModel, functionalLinearRelu_node_list),
        (FunctionalLinearReluLinearModel, functionalLinearReluLinear_node_list),
        (ConvModel, conv_node_list),
        (TwoLayerConvModel, conv2_node_list),
        (SingleLayerFunctionalConvModel, functionalConv_node_list),
        (TwoLayerFunctionalConvModel, functionalConv2_node_list),
        (ConvReluModel, convRelu_node_list),
        (ConvReluConvModel, convReluConv_node_list),
        (FunctionalConvReluModel, functionalConvRelu_node_list),
        (FunctionalConvReluConvModel, functionalConvReluConv_node_list),
    ]

    for (M, node_list) in tests:
        m = M().eval()
        prepared = prepare_fx(
            m, specific_qconfig_dict,
            equalization_qconfig_dict=default_equalization_qconfig_dict)
        equalized_quantized_model = convert_fx(prepared)

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_quantized_model, expected_node_list=node_list)
def prepare(self, model, configs, attrs):
    example_inputs = (torch.randn(1, 2),)
    model.another_layer = prepare_fx(model.another_layer, configs[""], example_inputs)
    return model