def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Apply post-training static quantization to a float model.

    The pipeline has three stages:

    1. prepare the model (switched to eval mode) for calibration,
    2. invoke ``run_fn(model, *run_args)`` so the caller can feed calibration data,
    3. convert the calibrated model into its quantized form.

    Args:
        model: the float model to quantize.
        run_fn: calibration callable; receives the prepared model as its
            first argument.
        run_args: positional arguments forwarded to ``run_fn`` after the model.
        mapping: table from original module types to their quantized
            counterparts; the default static quantization mapping is used
            when this is ``None``.
        inplace: when ``True`` the given model is transformed and mutated
            directly; otherwise a deep copy is transformed and the input is
            left untouched.

    Return:
        The quantized model (the input object itself when ``inplace`` is true).
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    module_mapping = (
        get_default_static_quant_module_mappings() if mapping is None else mapping
    )
    # Work on a private copy unless the caller explicitly opted into mutation.
    target = model if inplace else copy.deepcopy(model)
    target.eval()
    prepare(target, inplace=True)
    run_fn(target, *run_args)
    # NOTE(review): a one-argument `convert(model)` is defined further down in
    # this view; if both functions share a module, this three-argument call
    # would bind to that one. Confirm which `convert` is in scope here.
    convert(target, module_mapping, inplace=True)
    return target
def add_d2_quant_mapping(mappings):
    """HACK: Add d2 specific module mapping for eager model quantization.

    Register each ``source_type -> target_type`` pair from ``mappings`` in
    torch's module-level default tables for static quantization and QAT.
    Entries those tables already contain are never overridden.

    Args:
        mappings: dict from source module types to the module types they
            should be swapped to during quantization.
    """
    import torch.ao.quantization.quantization_mappings as qm

    # Take each default table's snapshot once. The getters build a fresh copy
    # of the whole table, and calling them per key made this loop copy both
    # tables on every iteration. Dict keys are unique, so checking every key
    # against a snapshot taken before the loop gives the same result as
    # re-fetching it each time.
    static_defaults = qm.get_default_static_quant_module_mappings()
    qat_defaults = qm.get_default_qat_module_mappings()
    for k, v in mappings.items():
        if k not in static_defaults:
            qm.DEFAULT_STATIC_QUANT_MODULE_MAPPINGS[k] = v
        if k not in qat_defaults:
            qm.DEFAULT_QAT_MODULE_MAPPINGS[k] = v
def convert(model: torch.nn.Module) -> torch.nn.Module:
    r"""Turn a prepared DBR quantization model into its quantized form.

    Child modules are swapped using the default static and dynamic
    quantization module mappings. The result is then passed through
    ``add_auto_convert``, which adds dynamic handling for quants/dequants,
    functions and methods.

    Args:
        model: a model previously prepared for DBR quantization. Presumably
            it is modified in place by the child swap (TODO confirm against
            ``_swap_child_modules``).

    Returns:
        Whatever ``add_auto_convert`` returns for the swapped model.
    """
    # Static mapping is fetched before the dynamic one, matching argument order.
    _swap_child_modules(
        model,
        get_default_static_quant_module_mappings(),
        get_default_dynamic_quant_module_mappings(),
    )
    return add_auto_convert(model)
def _convert(module, mapping=None, inplace=False, is_reference=False, convert_custom_config_dict=None):
    r"""Recursively swap submodules of ``module`` for their mapped counterparts.

    Every direct child is replaced by ``swap_module(child, mapping, ...)``.
    Before that, the child's own children are converted recursively, unless
    the child is a ``_FusedModule`` or an observed custom module. Those two
    kinds are swapped as one unit.

    Args:
        module: input module whose children are converted.
        mapping: dict from source module type to target module type; it can
            be overridden to swap user-defined modules. When ``None``, the
            default static quant mapping is used, or the reference-module
            mapping if ``is_reference`` is set.
        inplace: when true ``module`` itself is mutated; otherwise a deep
            copy is converted and returned.
        is_reference: select quantized reference modules for the default
            mapping.
        convert_custom_config_dict: optional config dict. Only its
            ``"observed_to_quantized_custom_module_class"`` entry is read here.

    Returns:
        The converted module (``module`` itself when ``inplace`` is true).
    """
    if mapping is None:
        if is_reference:
            mapping = get_default_static_quant_reference_module_mappings()
        else:
            mapping = get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    custom_module_class_mapping = convert_custom_config_dict.get(
        "observed_to_quantized_custom_module_class", {})

    if not inplace:
        module = copy.deepcopy(module)

    # Collect replacements first so the child dict is not mutated mid-iteration.
    replacements = {}
    for child_name, child in module.named_children():
        swapped_as_unit = (
            isinstance(child, _FusedModule)
            or type_before_parametrizations(child) in custom_module_class_mapping
        )
        if not swapped_as_unit:
            _convert(child, mapping, True, is_reference, convert_custom_config_dict)
        replacements[child_name] = swap_module(child, mapping, custom_module_class_mapping)

    for child_name, new_child in replacements.items():
        module._modules[child_name] = new_child
    return module