Example #1
# Imports needed by this snippet; the private helpers live in the DBR
# prototype package (paths per torch.ao.quantization._dbr at the time).
import torch
from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_module_mappings,
    get_default_dynamic_quant_module_mappings,
)
from torch.ao.quantization._dbr.module_swap_utils import _swap_child_modules
from torch.ao.quantization._dbr.auto_trace import add_auto_convert

def convert(model: torch.nn.Module) -> torch.nn.Module:
    r"""Converts a prepared DBR quantization model to a quantized form.

    TODO(future PR): better docblock
    """
    static_mappings = get_default_static_quant_module_mappings()
    dynamic_mappings = get_default_dynamic_quant_module_mappings()
    # swap the modules
    _swap_child_modules(model, static_mappings, dynamic_mappings)
    # add dynamic handling for quants/dequants, functions and methods
    model = add_auto_convert(model)
    return model
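
The static/dynamic mappings used above are plain dicts from float module classes to their quantized replacements. A minimal inspection sketch (the exact class path printed varies across PyTorch releases):

import torch.nn as nn
from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
)

# nn.Linear maps to the dynamically quantized Linear module
mappings = get_default_dynamic_quant_module_mappings()
print(mappings[nn.Linear])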
Example #2
# Imports needed to run this function standalone; these names come from
# torch.ao.quantization (torch.quantization on releases before 1.10).
import copy
import itertools

import torch
import torch.nn as nn
from torch.ao.quantization import convert, propagate_qconfig_
from torch.ao.quantization.qconfig import (
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
)
from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
)

def quantize_dynamic(model,
                     qconfig_spec=None,
                     dtype=torch.qint8,
                     mapping=None,
                     inplace=False):
    r"""Converts a float model to dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.

    For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
    by default is performed for layers with large weights size - i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
    If `qconfig` is provided, the `dtype` argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps the name or type of a submodule to a
              quantization configuration. A qconfig applies to all submodules of
              a given module unless a qconfig is specified for those submodules
              (i.e. when a submodule already has a qconfig attribute). Entries in
              the dictionary need to be QConfigDynamic instances.

            - A set of types and/or submodule names to apply dynamic quantization
              to, in which case the `dtype` argument is used to specify the
              bit-width.

        dtype: desired dtype of the quantized weights; one of torch.qint8,
            torch.float16, torch.quint8, or torch.quint4x2 (the latter two cover
            Embedding/EmbeddingBag). Ignored when `qconfig_spec` is a dictionary.
        inplace: carry out model transformations in-place; the original module
            is mutated
        mapping: maps the type of a submodule to the type of the corresponding
            dynamically quantized version with which the submodule will be replaced

    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
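        # no spec given: fall back to per-dtype defaults covering the common weight-heavy layers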
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear: default_dynamic_qconfig,
                nn.LSTM: default_dynamic_qconfig,
                nn.GRU: default_dynamic_qconfig,
                nn.LSTMCell: default_dynamic_qconfig,
                nn.RNNCell: default_dynamic_qconfig,
                nn.GRUCell: default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear: float16_dynamic_qconfig,
                nn.LSTM: float16_dynamic_qconfig,
                nn.GRU: float16_dynamic_qconfig,
                nn.LSTMCell: float16_dynamic_qconfig,
                nn.RNNCell: float16_dynamic_qconfig,
                nn.GRUCell: float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig,
                nn.Embedding: float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                f"Don't know how to quantize with default settings for {dtype}. "
                "Please provide a full qconfig")
    elif isinstance(qconfig_spec, set):
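        # a set of types/names: expand it into a dict mapping every entry to the per-dtype default qconfig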
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            raise RuntimeError(
                f'Unknown dtype specified for quantize_dynamic: {dtype}')
        qconfig_spec = dict(
            zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
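    # attach qconfigs to the matching submodules, then swap them for their dynamically quantized versions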
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model
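
A minimal usage sketch (the toy model and layer sizes are made up for illustration):

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

float_model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

# default path: with qconfig_spec=None and dtype=torch.qint8, Linear and RNN
# variants get weight-only qint8 quantization
qmodel = quantize_dynamic(float_model)

# set form: only the listed types are touched; dtype selects the default qconfig
qmodel_fp16 = quantize_dynamic(float_model, qconfig_spec={nn.Linear}, dtype=torch.float16)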