コード例 #1
0
def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    Return ``node`` itself, or its closest ancestor that is not an observer.

    If node is not an observer, returns it.  If node is an observer,
    navigates up the graph (at most two hops) and returns the first parent
    which is not an observer.  For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs

    Args:
        node: the node to start from.
        gm: the graph module that owns ``node``; observer modules are looked
            up on it by the node's ``target``.

    Returns:
        The starting node, or the node reached after skipping up to two
        consecutive observer nodes.
    """
    # At most two hops are followed, matching the patterns listed in the
    # docstring (a single observer, or two stacked observers).
    for _ in range(2):
        if node.op != "call_module":
            # Only call_module nodes can refer to an observer module.  For
            # any other op the target is not a module name (it may not even
            # be a string), so there is nothing to look up.
            break
        assert isinstance(node.target, str)
        node_obj = getattr_from_fqn(gm, node.target)
        if not is_activation_post_process(node_obj):
            break
        # An observer is expected to have exactly one Node input.
        assert len(node.args) == 1
        assert isinstance(node.args[0], Node)
        node = node.args[0]
    return node
コード例 #2
0
def maybe_get_observer_for_node(
        node: Node,
        modules: Dict[str, torch.nn.Module]) -> Optional[torch.nn.Module]:
    """
    Look up the observer module attached to ``node``, if there is one.

    Scans the users of ``node`` and returns the module of the first
    ``call_module`` user that is an activation post process.  Returns
    ``None`` when no user qualifies.
    """
    for user in node.users:
        if user.op != 'call_module':
            continue
        candidate = modules[str(user.target)]
        if is_activation_post_process(candidate):
            return candidate
    return None
コード例 #3
0
def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    fqn = None
    if hasattr(gm, '_node_name_to_scope'):
        # fqn on observers is not present, because they do not
        # exist when the fqns are created during tracing. If this is
        # an observer, get the fqn of the node being observed.
        node_to_use_for_fqn = node
        if node.op == 'call_module':
            assert isinstance(node.target, str)
            module = getattr_from_fqn(gm, node.target)
            if is_activation_post_process(module):
                node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
        fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][
            0]  # type: ignore[index]
    return fqn  # type: ignore[return-value]
コード例 #4
0
def generate_qconfig_map(
        root: torch.nn.Module, modules: Dict[str, torch.nn.Module],
        input_graph: Graph, qconfig_dict: Any,
        node_name_to_scope: Dict[str, Tuple[str,
                                            type]]) -> Dict[str, QConfigAny]:
    """
    Build a mapping from node name to the qconfig selected for that node.

    Every node of ``input_graph`` gets an entry, except ``call_module``
    nodes whose module is an activation post process, which are skipped.
    Nodes whose op is not one of ``get_attr``, ``call_function``,
    ``call_method`` or ``call_module`` map to ``None``.

    Side effect: for each non-observer ``call_module`` node, the resolved
    qconfig is also written to ``modules[node.target].qconfig``.

    Args:
        root: not used by this function.
        modules: mapping from module name to module instance.
        input_graph: graph whose nodes are visited in order.
        qconfig_dict: qconfig configuration; the entry under ``""`` is the
            global fallback.
        node_name_to_scope: node name -> (module path, module type).

    Returns:
        Dict from node name to the resolved qconfig (or ``None``).
    """
    fallback_qconfig = qconfig_dict.get("", None)
    name_to_qconfig = {}

    # Invocation counters per scope, for example
    #
    #   {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
    #
    # which reads: inside submodule 'foo.bar', 0 F.linear and 1 F.conv2d
    # invocations have been seen so far.
    seen_counts: Dict[str, Dict[Callable, int]] = defaultdict(
        lambda: defaultdict(int))

    for node in input_graph.nodes:
        op = node.op

        if op == "get_attr":
            owner_name, _ = _parent_name(node.target)
            picked = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, type(modules[owner_name]), owner_name,
                fallback_qconfig)

        elif op == "call_function":
            # Precedence: module_name_qconfig > function_qconfig
            # > global_qconfig, i.e. a module-name entry wins over the
            # qconfig registered for the function itself.
            per_function = get_object_type_qconfig(
                qconfig_dict, node.target, fallback_qconfig)
            scope_path, scope_type = node_name_to_scope[node.name]
            picked = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, scope_type, scope_path, per_function)

            # Read the running index before bumping it.
            nth_call = seen_counts[scope_path][node.target]
            seen_counts[scope_path][node.target] += 1
            picked = maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_dict, scope_path, node.target, nth_call, picked)

        elif op == "call_method":
            scope_path, scope_type = node_name_to_scope[node.name]
            # Look the method up by its name (node.target is a string) so
            # that configs such as "object_type": [("reshape", qconfig)]
            # are honoured ...
            picked = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, node.target, scope_path, fallback_qconfig)
            # ... then fall back to the config of the module that contains
            # the call_method node when the method has no config of its own.
            picked = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, scope_type, scope_path, picked)
            # Order-based overrides are not applied to call_method nodes.

        elif op == 'call_module':
            # Observers never get an entry in the qconfig map.
            if is_activation_post_process(modules[node.target]):
                continue
            picked = maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_dict, type(modules[node.target]), node.target,
                fallback_qconfig)

            scope_path, scope_type = node_name_to_scope[node.name]
            # For call_module the scope path is the module's own name, so
            # invocations are counted in the parent module to be meaningful.
            enclosing_name, _ = _parent_name(scope_path)
            nth_call = seen_counts[enclosing_name][scope_type]
            seen_counts[enclosing_name][scope_type] += 1
            picked = maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_dict, enclosing_name, scope_type, nth_call, picked)

        else:
            # Any other op (e.g. placeholder, output) carries no qconfig.
            name_to_qconfig[node.name] = None
            continue

        finalized = add_module_to_qconfig_obs_ctr(
            picked, modules.get(node.target, None))

        if op == 'call_module':
            # Regex is not supported by eager mode propagate_qconfig_, so the
            # qconfig is set explicitly here in case a regex was used.
            modules[node.target].qconfig = finalized

        name_to_qconfig[node.name] = finalized
    return name_to_qconfig
コード例 #5
0
ファイル: utils.py プロジェクト: JonghyunBae/FlashNeuron
def all_node_args_have_no_tensors(node: Node, modules: Dict[str,
                                                            torch.nn.Module],
                                  cache: Dict[Node, bool]) -> bool:
    """
    Conservatively decide whether ``node`` is free of tensor inputs.

    If we know for sure that all of this node's args have no
    tensors (are primitives), return True.  If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.

    Args:
        node: the node to inspect.  Anything that is not a ``Node`` (a
            primitive argument) is reported as having no tensors.
        modules: mapping from module name to module instance, used to detect
            observer modules on ``call_module`` nodes.
        cache: memoization table from node to result.  It is read and
            updated in place whenever it is not ``None``, including when it
            starts out empty.

    Returns:
        True only when every input is known to be a non-tensor.
    """
    if not isinstance(node, Node):
        # Primitives are never cached: they may be unhashable and the answer
        # is trivially known.
        return True

    # Test against None rather than truthiness: an empty dict is a valid
    # cache that simply has not been filled yet.
    if cache is not None and node in cache:
        return cache[node]

    result = False
    if node.op == 'placeholder':
        result = False
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        if is_activation_post_process(modules[node.target]):
            # Observers are transparent: answer for the observed value.
            result = all_node_args_have_no_tensors(
                node.args[0], modules, cache)  # type: ignore[arg-type]
        else:
            # Any other module call is treated as "not sure".
            result = False
    elif node.op == 'call_function' and node.target is operator.getitem:
        result = all_node_args_have_no_tensors(node.args[0], modules,
                                               cache)  # type: ignore[arg-type]
    elif node.op == 'get_attr':
        result = False
    elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
        # x1 = x0.ndim
        result = True
    elif node.op == 'call_method' and node.target == 'size':
        # x1 = x0.size(0)
        result = True
    else:
        found_one_tensor = False
        for arg in node.args:
            if isinstance(arg, list):
                # Only Node elements are inspected; other list elements are
                # ignored.  any() stops at the first element with a tensor.
                found_one_tensor = any(
                    not all_node_args_have_no_tensors(list_el, modules, cache)
                    for list_el in arg if isinstance(list_el, Node))
            elif isinstance(arg, int):
                pass
            elif isinstance(arg, Node):
                found_one_tensor = not all_node_args_have_no_tensors(
                    arg, modules, cache)
            else:
                # Unknown kind of argument: assume it may hold a tensor.
                found_one_tensor = True
            if found_one_tensor:
                # The final answer can no longer become True, so stop
                # recursing.
                # TODO(future PR): remove this entire function and
                # change to dtype inference without recursion.
                break
        # A node without any args stays at the conservative answer False.
        result = bool(node.args) and not found_one_tensor

    if cache is not None:
        cache[node] = result
    return result
コード例 #6
0
def _convert_do_not_use(
        model: GraphModule,
        is_reference: bool = False,
        convert_custom_config_dict: Dict[str, Any] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """
    Convert an observed model into a reference quantized model.

    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized module, this requires us to know whether the dtype configured for the
       weight is supported in the backend, this is done in prepare step and the result
       is stored in observed_node_names, we can decide whether we need to swap the
       module based on this set

    standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Args:
        model: the observed graph module; its graph and submodules are
            modified in place.
        is_reference: must be True, otherwise an assertion fails.
        convert_custom_config_dict: optional configuration; the keys read
            here are "observed_to_quantized_custom_module_class" and
            "preserved_attributes".  ``None`` is treated as an empty dict.
        is_standalone_module: not read anywhere in this function body.
        _remove_qconfig_flag: when True, ``_remove_qconfig(model)`` is called
            before the result is wrapped.

    Returns:
        A ``QuantizedGraphModule`` built from ``model`` and its graph.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details

    Raises:
        Exception: if "input_quantized_idxs" or "output_quantized_idxs" in the
            prepare custom config selects a placeholder/output node; these
            overrides are not supported here.
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    # State stored on the model earlier (presumably by the prepare step).
    patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(
        model)
    qconfig_map: Dict[
        str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    assert is_reference, "_convert_do_not_use only supports reference option"

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    # NOTE(review): `matches` is not read again in this function.
    matches = find_matches(model.graph,
                           modules,
                           patterns,
                           qconfig_map,
                           custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # Names of the placeholder nodes of the graph.
    # NOTE(review): `graph_inputs` is not read again in this function.
    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def replace_observer_with_quantize_dequantize_node(
            graph: Graph, node: Node, modules: Dict[str,
                                                    torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

        Observers whose dtype is torch.float32 are simply removed from the
        graph.  Observers with any dtype other than float32, quint8, qint8 or
        float16 are left untouched.
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        # The root module is registered under the empty name.
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        elif observer_module.dtype in [
                torch.quint8, torch.qint8, torch.float16
        ]:
            node_type, quantize_op, qparams = get_quantize_node_info(
                observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                # Arguments of the quantize op: the observed value first, then
                # the quantization parameters in the order qparams yields them.
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(
                            root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op,
                                                   tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize",
                                                     args=(quantized_node, ))
                # Route all consumers of the observer to the dequantized value,
                # then drop the observer call.
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    # Iterate over a snapshot of the nodes because the graph is mutated
    # (nodes inserted and erased) inside the loop.
    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # TODO: remove the quantize node for the placeholder
                raise Exception("input_quantized_idxs is not supported yet")
        elif node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                # TODO: remove dequantize node if any
                raise Exception("output_quantized_idxs is not supported yet")
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                replace_observer_with_quantize_dequantize_node(
                    model.graph, node, modules)
            elif type(modules[node.target]) in set(
                    WEIGHTED_MODULE_CLASSES).union(QAT_MODULE_CLASSES).union(
                        FUSED_MODULE_CLASSES):
                # TODO: refactor this part to a function
                original_module = modules[node.target]
                qconfig = original_module.qconfig

                is_observed = node.name in observed_node_names
                is_weight_quantized = weight_is_statically_quantized(qconfig)
                # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
                # Leave the module as is unless it has a qconfig, was observed
                # and has a statically quantized weight.
                if qconfig is None or not is_observed or not is_weight_quantized:
                    continue

                float_module = original_module
                fused_module = None
                if isinstance(original_module, QAT_MODULE_CLASSES):
                    # case 1. converting qat module to
                    # a float module, we need to attach
                    # weight fake_quant to the module,
                    # weight fake_quant is assumed to be run during
                    # QAT so we don't need to run it again here
                    float_module = original_module.to_float(
                    )  # type: ignore[operator]
                    # change qat conv to conv
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, float_module)
                    if isinstance(float_module,
                                  torch.nn.intrinsic._FusedModule):
                        # The weighted module is the first element of the
                        # fused module.
                        fused_module = float_module
                        float_module = fused_module[0]
                    weight_post_process = original_module.weight_fake_quant
                else:
                    # case 2. converting a float module/fused float module
                    # to float module, we need to attach
                    # weight observer to the conv module and run it
                    # with conv weight
                    if isinstance(original_module,
                                  torch.nn.intrinsic._FusedModule):
                        fused_module = original_module
                        float_module = fused_module[0]  # type: ignore[index]
                    assert qconfig is not None
                    weight_post_process = qconfig.weight()
                    # run weight observer
                    weight_post_process(
                        float_module.weight)  # type: ignore[operator]
                weight_qparams = get_qparam_dict(weight_post_process)
                ref_qmodule_cls = get_static_quant_module_class(
                    type(float_module), is_reference=True)
                ref_qmodule = ref_qmodule_cls.from_float(
                    float_module, weight_qparams)
                # Install the reference module: inside the fused module when
                # there is one, otherwise directly on the parent module.
                if fused_module is not None:
                    fused_module[0] = ref_qmodule
                else:
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, ref_qmodule)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, model.graph, preserved_attributes)
    return model
コード例 #7
0
def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
    """
    Tell whether ``node`` is a ``call_module`` node targeting an observer.

    Returns False for anything that is not a ``torch.fx.Node`` or whose op is
    not ``"call_module"``; otherwise returns the result of
    ``is_activation_post_process`` for the module the node targets.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != "call_module":
        return False
    target_module = modules[str(node.target)]
    return is_activation_post_process(target_module)
コード例 #8
0
ファイル: convert.py プロジェクト: yuguo68/pytorch
def convert(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Dict[str, Any] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        convert_qconfig_dict: Dict[str, Any] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """
    Convert an observed model into a quantized model.

    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized module, this requires us to know whether the dtype configured for the
       weight is supported in the backend, this is done in prepare step and the result
       is stored in observed_node_names, we can decide whether we need to swap the
       module based on this set

    When ``is_reference`` is False the reference model is additionally passed
    through the lowering steps at the end of this function.

    standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Args:
        model: the observed graph module; its graph and submodules are
            modified in place, and it is moved to CPU when ``is_reference`` is
            False.
        is_reference: when True, stop after producing the reference model.
        convert_custom_config_dict: optional configuration; the keys read
            here are "observed_to_quantized_custom_module_class" and
            "preserved_attributes".  ``None`` is treated as an empty dict.
        is_standalone_module: not read anywhere in this function body.
        _remove_qconfig_flag: when True, ``_remove_qconfig`` is called on the
            result before returning.
        convert_qconfig_dict: optional qconfig dict used to regenerate the
            qconfig map; each regenerated entry must either equal the entry
            from prepare or be None, otherwise an assertion fails.
        backend_config_dict: backend configuration; defaults to the result of
            ``get_native_backend_config_dict()`` when None.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    # State stored on the model earlier (presumably by the prepare step).
    node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # TODO this should be removed now that gpu support for quantization is being supported.
    # however in practice, as of 7/22/2021, certain functions that get called by convert expect
    # only cpu arguments.
    # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
    # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
    if not is_reference:
        model.cpu()

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if convert_qconfig_dict:
        prepare_qconfig_dict: Dict[str, Dict[Any, Any]] = model._qconfig_dict  # type: ignore[assignment]
        # generate_qconfig_map is run on a deep copy of the modules so that the
        # real modules are not touched by it.
        modules_copy = copy.deepcopy(modules)
        # The return value is discarded; presumably this rewrites
        # convert_qconfig_dict in place -- TODO confirm.
        convert_dict_to_ordered_dict(convert_qconfig_dict)
        if model._is_qat:
            convert_qconfig_dict = update_qconfig_for_qat(convert_qconfig_dict, {})
        convert_qconfig_dict = update_qconfig_for_fusion(model, convert_qconfig_dict)

        compare_prepare_convert_qconfig_dict(prepare_qconfig_dict, convert_qconfig_dict)  # type: ignore[arg-type]
        convert_qconfig_map = generate_qconfig_map(model, modules_copy, model.graph, convert_qconfig_dict, node_name_to_scope)
        # check the convert_qconfig_map generated and ensure that all the values either match what was set in prepare qconfig_map
        # or are set to None in the convert_qconfig_map.
        for k, v in qconfig_map.items():
            assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format(k)
            if convert_qconfig_map[k] is not None:
                assert qconfig_equals(v, convert_qconfig_map[k]), 'Expected k {} to have the same value in prepare qconfig_dict \
                and convert qconfig_dict, found {} updated to {}.'.format(k, v, convert_qconfig_map[k])
        # From here on the regenerated map replaces the one from prepare.
        qconfig_map = convert_qconfig_map

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    # Names of the placeholder nodes of the graph.
    # NOTE(review): `graph_inputs` is not read again in this function.
    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    # TODO: move this outside of this function
    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module,
            graph: Graph,
            node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

        The observer is removed without a replacement when no quantize info is
        available for it, or when all of its producers and consumers have a
        None qconfig.
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
        observer_module = modules[node.target]
        maybe_quantize_node_info = get_quantize_node_info(observer_module)
        # Skip replacing observers to quant/dequant nodes if the qconfigs of all
        # consumers and producers of this observer are None
        skip_replacement = all([
            has_none_qconfig(n, qconfig_map) for n in
            list(node.args) + list(node.users.keys())])
        if skip_replacement or maybe_quantize_node_info is None:
            # didn't find corresponding quantize op and info for the observer_module
            # so we just remove the observer
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        else:
            # otherwise, we can convert the observer module call to quantize/dequantize node
            node_type, quantize_op, qparams = maybe_quantize_node_info
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                # Arguments of the quantize op: the observed value first, then
                # the quantization parameters in the order qparams yields them.
                inputs = [input_node]
                for key, value in qparams.items():
                    # TODO: we can add the information of whether a value needs to
                    # be registered as an attribute in qparams dict itself
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                # Route all consumers of the observer to the dequantized value,
                # then drop the observer call.
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    # this is a temporary hack for custom module, we may want to implement
    # this properly after the custom module class design is finalized
    def replace_observer_with_dequantize_node(node: Node, graph: Graph):
        # Drop the observer that follows a custom module call and insert a
        # dequantize node after the custom module call instead.
        call_custom_module_node = node.args[0]
        assert isinstance(call_custom_module_node, Node), \
            f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
        node.replace_all_uses_with(call_custom_module_node)
        graph.erase_node(node)
        insert_dequantize_node(call_custom_module_node, graph)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    if backend_config_dict is None:
        backend_config_dict = get_native_backend_config_dict()
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config_dict)
    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config_dict)
    fused_module_classes = get_fused_module_classes(backend_config_dict)
    # Passed to convert_custom_module below and consulted when an observer is
    # reached, to pick the dequantize-only replacement.
    statically_quantized_custom_module_nodes: Set[Node] = set()

    # Iterate over a snapshot of the nodes because the graph is mutated
    # (nodes inserted and erased) inside the loop.
    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators took
                # floating point inputs in reference quantized models
                insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if len(output_quantized_idxs) == 0:
                continue
            # Result are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the dequantize operator for the node in the end if any
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    replace_observer_with_dequantize_node(node, model.graph)
                else:
                    replace_observer_with_quantize_dequantize_node(
                        model, model.graph, node, modules, node_name_to_scope,
                        qconfig_map)
            elif is_observed_standalone_module(modules[node.target]):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config_dict)
            elif type(modules[node.target]) in set(
                    root_module_classes).union(qat_module_classes).union(fused_module_classes):
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if type(modules[node.target]) in fused_module_classes and \
                   type(modules[node.target][0]) not in root_module_classes:
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, qconfig_map, backend_config_dict)
            elif type(modules[node.target]) in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    # The result is built from a deep copy of the converted graph.
    model = QuantizedGraphModule(model, copy.deepcopy(model.graph), preserved_attributes)

    # remove deadcode after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model.recompile()

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        # Lowering steps applied, in this order, to the reference model.
        model = duplicate_dequantize_node(model)
        model = duplicate_quantize_dynamic_node(model)
        model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)
        model = remove_quant_dequant_pairs(model)
        model = remove_extra_dequantize(model)
    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    return model
コード例 #9
0
def convert(model: GraphModule,
            is_reference: bool = False,
            convert_custom_config_dict: Optional[Dict[str, Any]] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """ Convert an observed GraphModule into a QuantizedGraphModule.

    The nodes of `model.graph` are walked once and re-emitted into a fresh
    graph: root nodes of matched patterns are handed to their
    QuantizeHandler, activation_post_process calls are turned into quantize
    ops, and everything else is copied.  A second pass drops any
    activation_post_process calls that are still left in the new graph.

    standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Args:
        model: observed model; its saved state is read back through
            `restore_state` and `model._qconfig_map`.
        is_reference: when False the model is moved to cpu first and the
            result goes through `fold_weight` and `lower_to_fbgemm`.
        convert_custom_config_dict: optional configuration; the keys read
            here are "observed_to_quantized_custom_module_class" and
            "preserved_attributes".  None is treated as an empty dict.
        is_standalone_module: not read by the body of this function.
        _remove_qconfig_flag: when True, `_remove_qconfig` is called on the
            model before it is wrapped.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(
        model)
    qconfig_map: Dict[
        str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # TODO this should be removed now that gpu support for quantization is being supported.
    # however in practice, as of 7/22/2021, certain functions that get called by convert expect
    # only cpu arguments.
    # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
    # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
    if not is_reference:
        model.cpu()

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(model.graph,
                           modules,
                           patterns,
                           qconfig_map,
                           custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    quantized_graph = Graph()
    # env maps: original node name -> dtype -> corresponding node in
    # quantized_graph (the dtype key None is used for non-Tensor values)
    env: Dict[str, Dict[Optional[torch.dtype], Node]] = defaultdict(
        lambda: defaultdict(Node))  # type: ignore[arg-type]

    def load_non_quantized(n: Node) -> Node:
        """ Return the float (or dtype-less) version of `n` from env,
        inserting and caching a dequantize call if only a quantized
        version exists """
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        dtype_to_node = env[n.name]
        if torch.float in dtype_to_node:
            return dtype_to_node[torch.float]
        elif None in dtype_to_node:
            return dtype_to_node[None]
        else:
            quantized_node = None
            for dtype in [torch.quint8, torch.qint8, torch.float16]:
                if dtype in dtype_to_node:
                    quantized_node = dtype_to_node[dtype]
                    break
            assert quantized_node is not None, "Did not find a supported quantized dtype:{}".format(
                dtype_to_node)
            # cache the dequantized node so that it is only created once
            env[n.name][torch.float] = Proxy(quantized_node).dequantize().node
            return env[n.name][torch.float]

    def load_quantized(dtype: torch.dtype):
        """ Return a loader that fetches the `dtype` version of a node """
        def load_quantized_impl(n: Node):
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + \
                n.name + ' in environment:' + str(env)
            dtype_to_node = env[n.name]
            local_dtype: Optional[torch.dtype] = dtype
            if local_dtype == torch.float and local_dtype not in dtype_to_node:
                local_dtype = None
            if local_dtype in [torch.float, None]:
                return load_non_quantized(n)
            assert local_dtype in dtype_to_node, f'Expecting {dtype} in {dtype_to_node}'
            return dtype_to_node[local_dtype]

        return load_quantized_impl

    def load_x(n: Node) -> Node:
        """ Return whichever version of `n` exists in env, preferring the
        quantized dtypes over the float / dtype-less ones """
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        dtype_to_node = env[n.name]
        dtypes = [
            torch.quint8, torch.qint8, torch.float16, torch.float32, None
        ]
        for dtype in dtypes:
            if dtype in dtype_to_node:
                return dtype_to_node[dtype]
        # report the whole candidate list: the loop variable would always be
        # the last candidate (None) at this point
        raise Exception(
            f'none of the dtypes {dtypes} found in environment: {dtype_to_node} for node {n.name}'
        )

    def load_arg(
        quantized: Optional[Union[List[int], Dict[int, torch.dtype],
                                  torch.dtype, Tuple[int, ...]]]
    ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, torch.dtype, list or tuple
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a dtype, then all args will be
            quantized to the specific dtype
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=torch.float)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized to torch.quint8
          - if quantized is a dict, it maps the positional index of an arg
            to the dtype that arg is loaded as; other args are loaded as float


        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, dict, torch.dtype)), type(quantized)
        if isinstance(quantized, (tuple, list, dict)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = torch.float

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], torch.dtype,
                                              Dict[int, torch.dtype],
                                              Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
               len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                if 0 in quantized:
                    updated_quantized = torch.quint8

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, torch.dtype):
                return map_arg(arg_or_args, load_quantized(updated_quantized))
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        # Currently it's hardcoded to torch.quint8, we can extend this
                        # in the future to support all quantized
                        # dtypes
                        loaded_args.append(
                            map_arg(a, load_quantized(torch.quint8)))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
            elif isinstance(updated_quantized, dict):
                loaded_args = []
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(
                            map_arg(a, load_quantized(updated_quantized[i])))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)

        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        """ Check whether a node argument (a Node or a list of them) has a
        quantized version in env; anything else counts as not quantized.
        Raises if a list mixes quantized and non-quantized entries """
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                dtype_to_node = env[node_arg.name]
                return any([
                    x in dtype_to_node
                    for x in [torch.quint8, torch.qint8, torch.float16]
                ])
            else:
                return False
        elif isinstance(node_arg, list):
            # Materialize the per-element results.  With a lazy `map`
            # iterator, `all` consumes elements up to the first falsy one and
            # `any` then only sees the leftovers, so a mixed list such as
            # [quantized, float] was reported as "not quantized" instead of
            # being rejected.
            quantized = [node_arg_is_quantized(a) for a in node_arg]
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler,
                            qconfig: QConfigAny,
                            modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # for some ops the output is quantized only when `is_reference` is True
        # and when `is_reference` is False, it has limited qconfig
        # support, for example `add`
        # ideally this check should not happen here, it should happen either in
        # prepare or during lowering, we don't need this check
        # after the default path is changed to produce reference patterns
        quantized = obj.is_output_quantized(qconfig)

        # Need to get correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(
                qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
           not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node,
                             modules: Dict[str, torch.nn.Module]) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name][torch.float] = quantized_graph.node_copy(
                node, load_non_quantized)
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            prev_dtype_to_node: Dict[Optional[torch.dtype],
                                     Node] = env[prev_node.name]
            current_dtype: Optional[
                torch.
                dtype] = observer_module.dtype  # type: ignore[assignment]
            if current_dtype in prev_dtype_to_node:
                env[node.
                    name][current_dtype] = prev_dtype_to_node[current_dtype]
            else:
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name][observer_dtype] = \
                    quantize_node(
                        load_non_quantized(prev_node),
                        observer_module, node, modules, quantized_graph,
                        node_name_to_scope, is_input=True)
        else:
            # replace activation post process with quantization ops
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name][dtype] = \
                quantize_node(
                    load_non_quantized(node.args[0]),
                    observer_module, node, modules,
                    quantized_graph,
                    node_name_to_scope, is_input=True)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module'
                and is_observed_standalone_module(modules[node.target]))
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = modules[
                        node.
                        target]._standalone_module_output_quantized_idxs.tolist(
                        )  # noqa: B950
                    assert len(
                        out_quant_idxs
                    ) <= 1, "Currently standalone only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                # Note: load_arg can be overwritten in the convert method when used to
                # create Node in graph
                result = obj.convert(
                    node,
                    qconfig,
                    modules,
                    quantized_graph,
                    node_name_to_scope,
                    load_arg,
                    is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig,
                                                    modules)

            if quantized:
                env[node.name][activation_dtype(qconfig)] = result
            else:
                env[node.name][torch.float] = result
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(node, load_non_quantized)
                env[node.name][torch.float] = result
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                env[node.name][torch.quint8] = quantized_graph.node_copy(
                    node, load_non_quantized)
            else:
                env[node.name][torch.float] = \
                    quantized_graph.node_copy(node, load_non_quantized)
        else:
            # copy quantized or non-quantized node
            # get_tensor_info_node like shape works for both
            # quantized and non-quantized input and output a non-Tensor
            # (we use None for dtype currently for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name][None] = \
                    quantized_graph.node_copy(node, load_x)
            else:
                env[node.name][torch.float] = \
                    quantized_graph.node_copy(node, load_non_quantized)

    # remove activation post process
    act_post_process_removed_graph = Graph()
    # maps node name in quantized_graph -> node in the cleaned-up graph
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_remove))
            continue
        if node.op == 'call_module' and \
           is_activation_post_process(modules[node.target]):
            # remove activation post process node
            # (users of the observer are pointed at the observer's input)
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, act_post_process_removed_graph,
                                 preserved_attributes)
    if not is_reference:
        model = fold_weight(model, node_name_to_scope)
        model = lower_to_fbgemm(model)
    return model
コード例 #10
0
def _convert_do_not_use(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """
    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized module, this requires us to know whether the dtype configured for the
       weight is supported in the backend, this is done in prepare step and the result
       is stored in observed_node_names, we can decide whether we need to swap the
       module based on this set

    standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Args:
        model: observed model; `model.graph` is modified in place.
        is_reference: must be True, anything else fails the assertion below.
        convert_custom_config_dict: optional configuration; None is treated
            as an empty dict.  "preserved_attributes" is read from it.
        is_standalone_module: not read by the body of this function.
        _remove_qconfig_flag: when True, `_remove_qconfig` is called on the
            model before it is wrapped.
        backend_config_dict: optional backend configuration, None is treated
            as an empty dict; used to look up the reference module mapping.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    # NOTE(review): qconfig_map and custom_module_classes below are not used
    # by this function; kept because the lookups may act as sanity checks.
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    assert is_reference, "_convert_do_not_use only supports reference option"

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    def replace_observer_with_quantize_dequantize_node(graph: Graph, node: Node, modules: Dict[str, torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        elif observer_module.dtype in [torch.quint8, torch.qint8, torch.float16]:
            node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)
        # observers with any other dtype are left in the graph untouched

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    if backend_config_dict is None:
        backend_config_dict = {}
    quantized_reference_module_mapping = get_quantized_reference_module_mapping(backend_config_dict)
    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
    weighted_module_classes = tuple(quantized_reference_module_mapping.keys())
    # module types that are handed to convert_weighted_module; built once
    # here instead of once per call_module node
    swappable_module_types = set(
        weighted_module_classes).union(QAT_MODULE_CLASSES).union(FUSED_MODULE_CLASSES)

    # iterate over a snapshot, the graph is mutated inside the loop
    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specifid the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators took
                # floating point inputs in reference quantized models
                insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                # Remove the dequantize operator in the end
                maybe_dequantize_node = node.args[0]
                if isinstance(maybe_dequantize_node, Node) and \
                   maybe_dequantize_node.op == "call_method" and \
                   maybe_dequantize_node.target == "dequantize":
                    quantize_node = maybe_dequantize_node.args[0]
                    # Only rewire the output node: the dequantize node may
                    # also feed other nodes, and those still expect the
                    # dequantized value.
                    node.replace_input_with(maybe_dequantize_node, quantize_node)
                    if not maybe_dequantize_node.users:
                        model.graph.erase_node(maybe_dequantize_node)
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                replace_observer_with_quantize_dequantize_node(model.graph, node, modules)
            elif is_observed_standalone_module(modules[node.target]):
                # TODO: move this to a separate function
                convert_standalone_module(node, modules, model, is_reference, backend_config_dict)

            elif type(modules[node.target]) in swappable_module_types:
                convert_weighted_module(node, modules, observed_node_names, quantized_reference_module_mapping)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, model.graph, preserved_attributes)
    return model