Example #1
def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    First it prepares the model for calibration, then it calls
    `run_fn` to run the calibration step, and finally it converts the
    model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        mapping: correspondence between original module types and quantized counterparts
        inplace: carry out model transformations in-place, the original module is mutated

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, mapping, inplace=True)
    return model
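A minimal usage sketch (not from the original page) of the quantize() entry point above: FloatModel, calibrate, and the random calibration batches are made-up placeholders, and it assumes the fbgemm backend is available.

import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qconfig, quantize

class FloatModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # tensors enter the quantized region here
        self.fc = nn.Linear(16, 4)
        self.dequant = DeQuantStub()  # and leave it here

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

def calibrate(model, data):
    # run_fn: feed representative inputs through the prepared model
    with torch.no_grad():
        for batch in data:
            model(batch)

model = FloatModel()
model.qconfig = get_default_qconfig("fbgemm")         # x86 backend config
calib_data = [torch.randn(8, 16) for _ in range(4)]   # made-up calibration batches
quantized = quantize(model, calibrate, (calib_data,))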
Example #2
def add_d2_quant_mapping(mappings):
    """HACK: Add d2 specific module mapping for eager model quantization"""
    import torch.ao.quantization.quantization_mappings as qm

    for k, v in mappings.items():
        if k not in qm.get_default_static_quant_module_mappings():
            qm.DEFAULT_STATIC_QUANT_MODULE_MAPPINGS[k] = v
        if k not in qm.get_default_qat_module_mappings():
            qm.DEFAULT_QAT_MODULE_MAPPINGS[k] = v
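A hedged sketch of driving this helper: MyFloatGate is a made-up float module, and the existing quantized Sigmoid is reused as its swap target because it already provides from_float() (import paths follow recent PyTorch versions).

import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_module_mappings,
)

class MyFloatGate(nn.Sigmoid):
    """Made-up project-specific float module."""
    pass

# Register the custom pair; the quantized target must implement from_float().
add_d2_quant_mapping({MyFloatGate: nnq.Sigmoid})

assert MyFloatGate in get_default_static_quant_module_mappings()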
Example #3
def convert(model: torch.nn.Module) -> torch.nn.Module:
    r"""Converts a prepared DBR quantization model to a quantized form.

    TODO(future PR): better docblock
    """
    static_mappings = get_default_static_quant_module_mappings()
    dynamic_mappings = get_default_dynamic_quant_module_mappings()
    # swap the modules
    _swap_child_modules(model, static_mappings, dynamic_mappings)
    # add dynamic handling for quants/dequants, functions and methods
    model = add_auto_convert(model)
    return model
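As a hedged illustration of what the two mapping tables consulted above resolve to (not the DBR workflow itself), the lookups behave roughly like this in recent PyTorch versions:

import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_module_mappings,
    get_default_dynamic_quant_module_mappings,
)

static_mappings = get_default_static_quant_module_mappings()
dynamic_mappings = get_default_dynamic_quant_module_mappings()

# The static table swaps nn.Linear for the statically quantized module,
# while the dynamic table points at the dynamically quantized variant.
assert static_mappings[nn.Linear] is nnq.Linear
assert dynamic_mappings[nn.Linear] is nnqd.Linear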
Example #4
def _convert(module,
             mapping=None,
             inplace=False,
             is_reference=False,
             convert_custom_config_dict=None):
    r"""Converts submodules in input module to a different module according to `mapping`
    by calling `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type, can be overwritten to allow swapping user defined
                 Modules
        inplace: carry out model transformations in-place, the original module
                 is mutated
        is_reference: a flag to enable quantized reference module

    """
    if mapping is None:
        mapping = get_default_static_quant_reference_module_mappings() if is_reference \
            else get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    custom_module_class_mapping = convert_custom_config_dict.get(
        "observed_to_quantized_custom_module_class", {})

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if not isinstance(mod, _FusedModule) and \
           type_before_parametrizations(mod) not in custom_module_class_mapping:
            _convert(
                mod,
                mapping,
                True,  # inplace
                is_reference,
                convert_custom_config_dict)
        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)

    for key, value in reassign.items():
        module._modules[key] = value

    return module
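For context, a hedged end-to-end sketch of the public eager-mode prepare/convert path that ultimately drives this recursive swap; the tiny model and the single random calibration pass are made up, and the fbgemm backend is assumed to be available.

import torch
import torch.nn as nn
from torch.ao.quantization import (
    QuantStub, DeQuantStub, get_default_qconfig, prepare, convert,
)

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.fc = nn.Linear(8, 2)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

m = Tiny().eval()
m.qconfig = get_default_qconfig("fbgemm")
prepared = prepare(m, inplace=False)          # insert observers
prepared(torch.randn(4, 8))                   # one calibration pass
quantized = convert(prepared, inplace=False)  # swap modules via from_float
print(type(quantized.fc))                     # expected: torch.ao.nn.quantized.Linear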