Example 1
def bias_correction(float_model,
                    quantized_model,
                    img_data,
                    target_modules=_supported_modules_quantized,
                    neval_batches=None):
    ''' Using Numeric Suite shadow modules, the expected outputs of the floating
    point and quantized modules are recorded. Using that data, the biases of
    supported modules are shifted to compensate for the drift caused by
    quantization.
    Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2)

    Args:
        float_model: a trained model that serves as a reference for what bias correction should aim for
        quantized_model: quantized form of float_model that bias correction is to be applied to
        img_data: calibration data used to estimate the expected output (used to find the quantization error)
        target_modules: specifies which submodules in quantized_model need bias correction (can be extended to
                unquantized submodules)
        neval_batches: a cap on the number of batches used for estimating the expected output
    '''
    ns.prepare_model_with_stubs(float_model, quantized_model,
                                _supported_modules, MeanShadowLogger)

    uncorrected_modules = {}
    for name, submodule in quantized_model.named_modules():
        if type(submodule) in target_modules:
            uncorrected_modules[name] = submodule

    for uncorrected_module in uncorrected_modules:
        quantized_submodule = get_module(quantized_model, uncorrected_module)
        bias = get_param(quantized_submodule, 'bias')
        if bias is not None:

            count = 0
            for data in img_data:
                quantized_model(data[0])
                count += 1
                if count == neval_batches:
                    break
            ob_dict = ns.get_logger_dict(quantized_model)
            parent_name, _ = parent_child_names(uncorrected_module)

            float_data = ob_dict[parent_name + '.stats']['float']
            quant_data = ob_dict[parent_name + '.stats']['quantized']

            # Expected error: the per-output-channel mean of the difference
            # between the quantized and float outputs.
            quantization_error = quant_data - float_data
            dims = list(range(quantization_error.dim()))
            # Note: we don't want to take the mean over the output channel dimension
            dims.remove(1)
            expected_error = torch.mean(quantization_error, dims)

            updated_bias = bias.data - expected_error

            bias.data = updated_bias

            # Resets the data contained in the loggers
            for name, submodule in quantized_model.named_modules():
                if isinstance(submodule, MeanShadowLogger):
                    submodule.clear()
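
A minimal usage sketch, assuming a torchvision build whose quantized model
constructors still accept the pretrained=/quantize= flags; the models and the
calibration batches below are placeholders, not part of the function above:

import torch
from torchvision.models import quantization as qmodels

# Hypothetical reference model and its quantized counterpart.
float_model = qmodels.resnet18(pretrained=True, quantize=False).eval()
quantized_model = qmodels.resnet18(pretrained=True, quantize=True).eval()

# Placeholder calibration data: 8 batches of (image, label) pairs standing in
# for a real calibration set.
img_data = [(torch.rand(4, 3, 224, 224), torch.randint(0, 1000, (4,)))
            for _ in range(8)]

# Shifts the biases of supported submodules in quantized_model in place,
# capping the expected-output estimate at 4 batches.
bias_correction(float_model, quantized_model, img_data, neval_batches=4)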
Example 2
def compare_model_stub_fx(prepared_float_model,
                          q_model,
                          module_swap_list,
                          *data,
                          Logger=ShadowLogger):
    r"""Compare quantized module in a model with its floating point counterpart,
    feeding both of them the same input. Return a dict with key corresponding to
    module names and each entry being a dictionary with two keys 'float' and
    'quantized', containing the output tensors of quantized and its matching
    float shadow module. This dict can be used to compare and compute the module
    level quantization error.

    Note prepared_float module is a float module which has been prepared by calling prepare_fx.

    This function first call prepare_model_with_stubs_fx() to swap the quantized
    module that we want to compare with the Shadow module, which takes quantized
    module, corresponding float module and logger as input, and creates a forward
    path inside to make the float module to shadow quantized module sharing the
    same input. The logger can be customizable, default logger is ShadowLogger
    and it will save the outputs of the quantized module and float module that
    can be used to compute the module level quantization error.

    Example usage:
        module_swap_list = [nn.Linear]
        ob_dict = compare_model_stub_fx(prepared_float_model,qmodel,module_swap_list, data)
        for key in ob_dict:
            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))

    Args:
        prepared_float_model: float model which has been prepared
        q_model: model quantized from float_model
        module_swap_list: list of float module types at which shadow modules will
            be attached.
        data: input data used to run the prepared q_model
        Logger: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite.compare_model_stub_fx")
    prepared_float_model = remove_qconfig_observer_fx(prepared_float_model)
    prepare_model_with_stubs_fx(prepared_float_model, q_model,
                                module_swap_list, Logger)
    q_model(*data)
    ob_dict = get_logger_dict(q_model)
    return ob_dict
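
A hedged end-to-end sketch of how compare_model_stub_fx might be called. The toy
model, qconfig dict, and calibration tensor are placeholders, and the FX
quantization entry points have shifted across PyTorch releases, so this assumes
a version where prepare_fx takes only a model and a qconfig dict:

import copy
import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Hypothetical float model and calibration input.
float_model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).eval()
data = torch.rand(8, 16)

qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared_float_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model(data)  # calibrate the observers
q_model = convert_fx(copy.deepcopy(prepared_float_model))

# Attach a float shadow at every Linear and collect paired outputs.
module_swap_list = [nn.Linear]
ob_dict = compare_model_stub_fx(prepared_float_model, q_model, module_swap_list, data)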
Example 3
print(ob_dict['layer1.0.stats']['float'][0].shape)
print(ob_dict['layer1.0.stats']['quantized'][0].shape)

##############################################################################
# This dict can then be used to compare and compute the module-level quantization error.
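# ``compute_error`` below is not defined in this snippet; in the Numeric Suite
# tutorial it is a signal-to-quantized-noise ratio (SQNR) helper, sketched here.

def compute_error(x, y):
    # SQNR in dB between the float reference x and the dequantized quantized output y.
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)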

for key in ob_dict:
    print(key, compute_error(ob_dict[key]['float'][0], ob_dict[key]['quantized'][0].dequantize()))

##############################################################################
# If we want to do the comparison for more than one batch of input data, we can do the following.

ns.prepare_model_with_stubs(float_model, qmodel, module_swap_list, ns.ShadowLogger)
for data in img_data:
    qmodel(data[0])
ob_dict = ns.get_logger_dict(qmodel)

##############################################################################
# The default logger used in the above APIs is ``ShadowLogger``, which logs the
# outputs of the quantized module and its matching float shadow module. We can
# inherit from the base ``Logger`` class and create our own logger to perform
# different functionality. For example, we can make a new ``MyShadowLogger``
# class as below.

class MyShadowLogger(ns.Logger):
    r"""Customized logger class
    """

    def __init__(self):
        super(MyShadowLogger, self).__init__()

    def forward(self, x, y):
        # x is the output of the quantized module; y is the output of its float
        # shadow module. Add custom processing or recording here.
        # ...
        return x
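
##############################################################################
# Then we can pass ``MyShadowLogger`` in place of ``ShadowLogger`` when
# attaching the stubs, and each shadow module will process its paired outputs
# with the custom logger.

ns.prepare_model_with_stubs(float_model, qmodel, module_swap_list, MyShadowLogger)
for data in img_data:
    qmodel(data[0])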