Esempio n. 1
0
def compare_prepare_convert_qconfig_dict(
        prepare_qconfig_dict: Dict[str, Dict[Any, Any]],
        convert_qconfig_dict: Dict[str, Dict[Any, Any]]) -> None:
    r""" Compare the qconfig_dict passed in convert to the one from prepare and check the values

    Every qconfig specified in ``prepare_qconfig_dict`` must also appear in
    ``convert_qconfig_dict``. There it must either be ``None`` or compare equal
    to the prepare-time qconfig under ``qconfig_equals``.

    Supported top-level keys:

    * ``''`` (global), which maps directly to a qconfig.
    * ``'object_type'``, ``'module_name'`` and ``'module_name_regex'``, which
      each map names to qconfigs.

    Args:
      `prepare_qconfig_dict`: configuration dictionary for prepare quantization step
      `convert_qconfig_dict`: configuration dictionary for convert quantization step

    Raises:
      AssertionError: if an entry is missing from ``convert_qconfig_dict``, if an
        entry was changed to a different non-``None`` qconfig, or if
        ``prepare_qconfig_dict`` contains an unsupported top-level key.
    """
    for k in prepare_qconfig_dict:
        if k == '':
            assert k in convert_qconfig_dict, \
                "Missing key {} from convert qconfig_dict when it was present in prepare".format(k)
            assert (
                convert_qconfig_dict[k] is None or qconfig_equals(
                    prepare_qconfig_dict[k], convert_qconfig_dict[k])  # type: ignore[arg-type]
            ), (
                "Expected convert qconfig_dict have the same qconfig as prepare qconfig_dict or None. "
                "Updated qconfig {} to {} for key {}".format(
                    prepare_qconfig_dict[k], convert_qconfig_dict[k], k))
        elif k in ('object_type', 'module_name', 'module_name_regex'):
            # A sub-dict missing from the convert side is reported through the
            # per-name assertion below instead of surfacing as a bare KeyError.
            convert_sub_dict = convert_qconfig_dict.get(k, {})
            for name, prepare_qconfig in prepare_qconfig_dict[k].items():
                assert name in convert_sub_dict, (
                    "Missing key {} {} from convert qconfig_dict "
                    "when it was present in prepare".format(k, name))
                convert_qconfig = convert_sub_dict[name]
                assert convert_qconfig is None or qconfig_equals(prepare_qconfig, convert_qconfig), (
                    "Expected convert qconfig_dict have the same qconfig as prepare qconfig_dict or None. "
                    "Updated qconfig {} to {} for key {} {}".format(
                        prepare_qconfig, convert_qconfig, k, name))
        else:
            # `assert "<message>"` is always true (a non-empty string is truthy),
            # so the original check could never fire; raise explicitly instead.
            raise AssertionError("Unsupported key in convert_qconfig_dict {}".format(k))
Esempio n. 2
0
def compare_prepare_convert_qconfig_mappings(
        prepare_qconfig_mapping: QConfigMapping,
        convert_qconfig_mapping: QConfigMapping):
    r""" Compare the qconfig_mapping passed in convert to the one from prepare and check the values

    The global qconfigs must be equal under ``qconfig_equals``. Every entry in
    the prepare mapping's object-type, module-name and module-name-regex
    tables must also exist in the matching convert table. There it must be
    ``None`` or equal to the prepare-time qconfig.

    Args:
      `prepare_qconfig_mapping`: configuration for prepare quantization step
      `convert_qconfig_mapping`: configuration for convert quantization step

    Raises:
      AssertionError: if the global qconfigs differ, an entry is missing from the
        convert mapping, or an entry was changed to a different non-``None`` qconfig.
    """
    assert qconfig_equals(prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig), \
        "Expected global qconfigs to be the same in the prepare and convert quantization configs"

    # (key name used in error messages, prepare-side table, convert-side table)
    table_triples = (
        (OBJECT_TYPE_DICT_KEY,
         prepare_qconfig_mapping.object_type_qconfigs,
         convert_qconfig_mapping.object_type_qconfigs),
        (MODULE_NAME_DICT_KEY,
         prepare_qconfig_mapping.module_name_qconfigs,
         convert_qconfig_mapping.module_name_qconfigs),
        (MODULE_NAME_REGEX_DICT_KEY,
         prepare_qconfig_mapping.module_name_regex_qconfigs,
         convert_qconfig_mapping.module_name_regex_qconfigs),
    )
    for dict_name, prepare_dict, convert_dict in table_triples:
        for name, prepare_qconfig in prepare_dict.items():
            assert name in convert_dict, (
                "Missing key {} {} in convert QConfigMapping "
                "when it was present in prepare".format(dict_name, name))
            convert_qconfig = convert_dict[name]
            assert convert_qconfig is None or qconfig_equals(prepare_qconfig, convert_qconfig), (
                "Expected convert QConfigMapping to have the same qconfig as prepare for key {} {}; "
                "prepare: {}; convert: {}".format(dict_name, name, prepare_qconfig, convert_qconfig))
Esempio n. 3
0
    def test_embedding_qat_qconfig_equal(self):
        """Check that qconfig_equals matches the embedding QAT qconfig after prepare_qat.

        The embedding QAT qconfig pairs a NoopObserver class (activation) with a
        FakeQuant (weight). This exercises qconfig comparison when a qconfig mixes
        a partial function and a class.
        """
        qat_model = prepare_qat(ManualEmbeddingBagLinear().train())
        emb_qconfig = qat_model.emb.qconfig
        self.assertTrue(qconfig_equals(emb_qconfig, default_embedding_qat_qconfig))
Esempio n. 4
0
def update_qconfig_for_fusion(
    model: GraphModule,
    qconfig_dict: Any,
) -> Any:
    """
    Update the qconfig_dict to account for fused modules such as LinearReLU.

    Each ``call_module`` node whose target is a ``_FusedModule`` is handled the
    same way. The ``"object_type"`` qconfig registered for the type of the
    module's first child is copied to the fused module's own type. This happens
    only if every child type maps to an equal qconfig and that qconfig is not
    ``None``. ``qconfig_dict`` is modified in place and also returned. It is
    returned untouched when it has no ``"object_type"`` entry.

    Raises:
        LookupError: if the children of a fused module have differing
            ``object_type`` qconfigs.
    """
    type_to_qconfig = qconfig_dict.get("object_type", None)
    if type_to_qconfig is None:
        return qconfig_dict

    named_modules = dict(model.named_modules())

    for node in model.graph.nodes:
        if node.op != 'call_module' or node.target not in named_modules:
            continue
        fused_module = named_modules[str(node.target)]
        if not isinstance(fused_module, _FusedModule):
            continue

        child_types = [type(child) for child in fused_module._modules.values()]
        leading_qconfig = type_to_qconfig.get(child_types[0], None)

        # TODO: only nn.Module children and object_type entries are handled;
        # functionals such as torch.nn.functional.relu and other qconfig_dict
        # sections are not considered yet.
        for child_type in child_types[1:]:
            if not qconfig_equals(type_to_qconfig.get(child_type, None), leading_qconfig):
                raise LookupError(
                    "During fusion, we need to specify the same "
                    f"qconfigs for all module types in {type(fused_module)} "
                    f"offending type: {child_type}")

        if leading_qconfig is not None:
            type_to_qconfig[type(fused_module)] = leading_qconfig

    return qconfig_dict
Esempio n. 5
0
def update_qconfig_for_fusion(
    model: GraphModule,
    qconfig_dict: Any,
) -> Any:
    """
    Update the qconfig_dict to account for fused modules such as LinearReLU.

    Only ``call_module`` nodes whose module type equals a value of
    ``DEFAULT_OP_LIST_TO_FUSER_METHOD`` are handled. For each matching op
    pattern, the ``"object_type"`` qconfig of the pattern's first op is copied
    to the fused module type. This happens only if every op in the pattern has
    an equal qconfig and that qconfig is not ``None``.

    Args:
        model: traced model whose ``call_module`` nodes are inspected.
        qconfig_dict: quantization config; only its ``"object_type"`` entry is
            read and updated (in place).

    Returns:
        The same ``qconfig_dict`` object. It is returned untouched when it has
        no ``"object_type"`` entry.

    Raises:
        LookupError: if the ops of a fusion pattern have different qconfigs.
    """
    object_type_dict = qconfig_dict.get("object_type", None)
    if object_type_dict is None:
        return qconfig_dict

    modules = dict(model.named_modules())
    # Loop-invariant: build the membership list once rather than once per node.
    fuser_methods = list(DEFAULT_OP_LIST_TO_FUSER_METHOD.values())

    for node in model.graph.nodes:
        if node.op == 'call_module' and node.target in modules:
            module_type = type(modules[str(node.target)])
            if module_type not in fuser_methods:
                continue

            for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
                if module_type == fuser:
                    fused_qconfig = object_type_dict.get(ops[0], None)

                    # Raise an error if the modules in the fused module have
                    # different qconfigs specified in the qconfig_dict.
                    # ops[0] supplied fused_qconfig, so only the rest need checking.
                    for op in ops[1:]:
                        if not qconfig_equals(object_type_dict.get(op, None),
                                              fused_qconfig):
                            raise LookupError(
                                "During fusion, we need to specify the same " +
                                f"qconfigs for both modules in {module_type}.")

                    if fused_qconfig is not None:
                        object_type_dict[module_type] = fused_qconfig

    return qconfig_dict