Example #1
0
def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Apply post-training static quantization to a float model.

    The work happens in three stages. First, the model (or a deep copy of it)
    is switched to eval mode and prepared for calibration. Second,
    ``run_fn(model, *run_args)`` is invoked to calibrate it. Third, the
    calibrated model is converted using ``mapping``.

    Args:
        model: input float model
        run_fn: calibration callable, invoked as ``run_fn(model, *run_args)``
        run_args: positional arguments forwarded to ``run_fn``
        mapping: correspondence between original module types and quantized
            counterparts; ``get_default_static_quant_module_mappings()`` is
            used when this is ``None``
        inplace: when ``True`` the given ``model`` object is transformed;
            otherwise a deep copy is transformed and returned

    Return:
        Quantized model (the very same object as ``model`` when ``inplace``).
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    module_mapping = (get_default_static_quant_module_mappings()
                      if mapping is None else mapping)
    target = model if inplace else copy.deepcopy(model)

    # Calibration must run with eval-mode behaviour.
    target.eval()
    prepare(target, inplace=True)
    run_fn(target, *run_args)
    convert(target, module_mapping, inplace=True)
    return target
Example #2
0
def add_d2_quant_mapping(mappings):
    """HACK: Add d2 specific module mapping for eager model quantization.

    Each ``float_type -> quant_type`` pair in ``mappings`` is written into the
    global ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS`` and
    ``DEFAULT_QAT_MODULE_MAPPINGS`` of
    ``torch.quantization.quantization_mappings``. A pair is written only when
    that key is not already present in the corresponding default mapping, so
    existing defaults are never overridden.

    Args:
        mappings: dict from float module type to its replacement module type.

    Note:
        This mutates module-level globals, so it affects every later
        quantization call in the process.
    """
    import torch.quantization.quantization_mappings as qm

    # The ``get_default_*_module_mappings()`` getters may build a fresh copy
    # on every call. Take one snapshot of each up front instead of one per
    # key. Keys of ``mappings`` are unique, so an entry inserted during the
    # loop can never change the membership test for a later key.
    existing_static = qm.get_default_static_quant_module_mappings()
    existing_qat = qm.get_default_qat_module_mappings()

    for float_type, quant_type in mappings.items():
        if float_type not in existing_static:
            qm.DEFAULT_STATIC_QUANT_MODULE_MAPPINGS[float_type] = quant_type
        if float_type not in existing_qat:
            qm.DEFAULT_QAT_MODULE_MAPPINGS[float_type] = quant_type
Example #3
0
def _convert(module,
             mapping=None,
             inplace=False,
             convert_custom_config_dict=None):
    r"""Swap the submodules of ``module`` for the types given by ``mapping``.

    The swap is done by calling ``swap_module`` on each child (which uses the
    ``from_float`` method of the target module class). Children are handled
    depth-first. A child that is neither a ``_FusedModule`` nor an observed
    custom module (a type listed in the custom-module class mapping) first has
    its own children converted in place. Every child is then replaced by the
    result of ``swap_module``.

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type; defaults to
                 ``get_default_static_quant_module_mappings()`` and can be
                 overwritten to allow swapping user defined Modules
        inplace: carry out model transformations in-place, the original module
                 is mutated; otherwise a deep copy is converted
        convert_custom_config_dict: optional configuration dict. Its
                 ``"observed_to_quantized_custom_module_class"`` entry
                 (default ``{}``) is forwarded to ``swap_module``, and types
                 listed there are not recursed into.

    Return:
        The converted module (``module`` itself when ``inplace`` is ``True``).
    """
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    config = {} if convert_custom_config_dict is None else convert_custom_config_dict
    custom_classes = config.get(
        "observed_to_quantized_custom_module_class", {})

    root = module if inplace else copy.deepcopy(module)

    replacements = {}
    for child_name, child in root.named_children():
        # Fused modules and observed custom modules are swapped as a single
        # unit, so their own children are left as they are.
        swap_as_unit = isinstance(child, _FusedModule) or \
            type(child) in custom_classes
        if not swap_as_unit:
            # The child belongs to ``root``, which is already ours to mutate.
            _convert(child, mapping, True, config)
        replacements[child_name] = swap_module(child, mapping, custom_classes)

    # Replacements are applied after iteration, presumably so that
    # ``_modules`` is not modified while ``named_children()`` walks it.
    for child_name, replacement in replacements.items():
        root._modules[child_name] = replacement

    return root
Example #4
0
    def test_op_relationship_mapping(self):
        """
        Tests that the mapping of op relationships is complete.

        The test runs four checks:

        1. Every (fp32, int8) pair in the default static-quant module
           mappings has a relationship entry (quant/dequant stubs and
           FloatFunctional are skipped).
        2. Every pair in the default float-to-quantized operator mappings
           has a relationship entry.
        3. Every pair in the default dynamic-quant module mappings has a
           relationship entry (a few RNN types are skipped for now).
        4. For the default quant patterns, the base op handled by each
           QuantizeHandler class appears in the base sets of related ops,
           except for explicitly skipped ops and handler classes.
        """
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        type_a_related_to_b = \
            get_type_a_related_to_b(base_name_to_sets_of_related_ops)

        def _assert_mappings_related(mappings, types_to_skip=()):
            # Every (fp32, int8) pair whose fp32 type is not skipped must be
            # present in the relationship mapping.
            for fp32_type, int8_type in mappings.items():
                if fp32_type in types_to_skip:
                    continue
                self.assertTrue(
                    (fp32_type, int8_type) in type_a_related_to_b,
                    f"{fp32_type} and {int8_type} need a relationship mapping")

        def _op_in_base_sets_of_related_ops(op):
            return any(
                op in ops for ops in base_name_to_sets_of_related_ops.values())

        # 1. check static quant module mappings
        # skip quants and dequants, for the purposes of Numerical Suite
        _assert_mappings_related(
            get_default_static_quant_module_mappings(),
            types_to_skip=(
                torch.quantization.QuantStub,
                torch.quantization.DeQuantStub,
                nnq.FloatFunctional,
            ))

        # 2. check static quant op mappings
        _assert_mappings_related(
            get_default_float_to_quantized_operator_mappings())

        # 3. check dynamic quant mappings
        # TODO(future PR): enable correct weight extraction for these
        # and remove from this list.
        _assert_mappings_related(
            get_default_dynamic_quant_module_mappings(),
            types_to_skip=(
                nn.GRUCell,
                nn.GRU,
                nn.LSTMCell,
                nn.RNNCell,
            ))

        # 4. go through the ops mapped to each QuantizeHandler type, and verify
        # correctness.
        qhandler_cls_all_ops_quantizeable = (
            qp.CatQuantizeHandler,
            qp.ConvReluQuantizeHandler,
            qp.LinearReLUQuantizeHandler,
            qp.BatchNormQuantizeHandler,
            qp.EmbeddingQuantizeHandler,
            qp.RNNDynamicQuantizeHandler,
            qp.ELUQuantizeHandler,
        )
        qhandler_cls_quant_op_same_signature = (
            qp.FixedQParamsOpQuantizeHandler,
            qp.CopyNodeQuantizeHandler,
        )
        # these ops do not have quantized equivalents
        binary_ops_to_skip = (
            torch.bmm, torch.sum, torch.div, torch.sub,
            operator.truediv, operator.sub,
        )
        default_node_ops_to_skip = (
            torch.nn.SiLU,
            torch.nn.functional.silu,
        )

        default_quant_patterns = get_default_quant_patterns()
        for pattern, qhandler_cls in default_quant_patterns.items():
            if isinstance(pattern, str):
                # TODO(future PR): add handling for these
                continue
            # For tuple patterns the last element is treated as the base op.
            base_op = pattern[-1] if isinstance(pattern, tuple) else pattern

            if qhandler_cls is qp.BinaryOpQuantizeHandler:
                if base_op in binary_ops_to_skip:
                    continue
                self.assertTrue(_op_in_base_sets_of_related_ops(base_op),
                                f"{base_op} not in sets of related ops")
            elif qhandler_cls is qp.RNNDynamicQuantizeHandler:
                # TODO(future PR): add support for all classes in
                # RNNDynamicQuantizeHandler
                pass
            elif qhandler_cls is qp.DefaultNodeQuantizeHandler:
                if base_op in default_node_ops_to_skip:
                    continue
                self.assertTrue(_op_in_base_sets_of_related_ops(base_op),
                                f"{base_op} not in sets of related ops")
            elif qhandler_cls in qhandler_cls_quant_op_same_signature:
                # these ops use the same op signature for fp32 and quantized
                # tensors
                pass
            elif qhandler_cls in qhandler_cls_all_ops_quantizeable:
                self.assertTrue(_op_in_base_sets_of_related_ops(base_op),
                                f"{base_op} not in sets of related ops")
            else:
                raise AssertionError(
                    f"handling for {qhandler_cls} not implemented")