# Example 1
    def _convert(self, model: GraphModule, debug: bool = False,
                 convert_custom_config_dict: Dict[str, Any] = None,
                 is_standalone_module: bool = False) -> GraphModule:
        """ standalone_module means it is a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        Returns a quantized standalone module which accepts float input
        and produces float output.
        """
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        # always run weight observers in the top level forward method
        # for dynamic quant ops or weight only quant ops

        # move to cpu since we only have quantized cpu kernels
        self.modules = dict(model.named_modules())

        custom_module_classes = get_custom_module_class_keys(
        assert self.patterns is not None
        matches = self._find_matches(
            model.graph, self.modules, self.patterns,

        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
            self._find_quants(model.graph, matches)

        self.quantized_graph = Graph()
        env: Dict[str, Node] = {}
        quant_env: Dict[str, Node] = {}

        graph_inputs: List[str] = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':

        def load_non_quantized(n: Node) -> Node:
            """Return the float entry for ``n`` from ``env``.

            If ``n`` is not yet in ``env`` it must be present in
            ``quant_env``; in that case a ``dequantize()`` call is built on
            the quantized entry through a Proxy, and the resulting node is
            stored in ``env`` under the same name before being returned.
            """
            key = n.name
            if key in env:
                return env[key]
            assert key in quant_env, (
                'trying to load float node but did not find ' +
                'node:' + key +
                ' in quantized or non quantized environment, env: ' +
                str(env) + ' quant_env:' + str(quant_env))
            dequantized = Proxy(quant_env[key]).dequantize().node
            env[key] = dequantized
            return dequantized

        def load_quantized(n: Node) -> Node:
            """Return the entry stored for ``n`` in ``quant_env``.

            Asserts that ``n.name`` is a key of ``quant_env``.
            """
            key = n.name
            assert key in quant_env, (
                'trying to load quantized node but did not find node:' +
                key + ' in quant environment:' + str(quant_env))
            return quant_env[key]

        def load_x(n: Node) -> Node:
            """Return the entry stored for ``n``, whichever environment has it.

            ``quant_env`` is consulted first; ``env`` is the fallback.
            Asserts that ``n.name`` is present in at least one of the two.
            """
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            # Fallback for nodes that only exist in the float environment.
            # Without this being a separate branch the return was unreachable
            # and such nodes produced None.
            return env[n.name]

        def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]]
                     ) -> Callable[[Node], Argument]:
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and
                the args with corresponding indexes will be quantized
              - if quantized is a boolean, then all args will be
                quantized/not quantized
              - if quantized is None, then we'll load the node as long as it
                exists in either environment (see load_x)

            Output: fn which takes arg_or_args, and loads them from the
                corresponding environment depending on the value of quantized.
            """
            assert quantized is None or \
                isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    # one loader for every argument
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        # exactly one load per argument: quantized for the
                        # listed indexes, float for all others
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    # preserve the container type (list vs tuple) of the input
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl

        def node_arg_is_quantized(node_arg: Any) -> bool:
            """Tell whether ``node_arg`` is currently held in quantized form.

            - Node: True iff its name is a key of ``quant_env`` (asserts the
              name is in ``env`` or ``quant_env``).
            - list: True if every element is quantized, False if none is;
              raises Exception when the list is only partially quantized.
            - anything else: False.
            """
            if isinstance(node_arg, Node):
                assert node_arg.name in env or node_arg.name in quant_env, \
                    'Expecting node_arg to be in the environment'
                # there might be nodes appearing in both environments, but
                # quant_env will take precedence
                return node_arg.name in quant_env
            elif isinstance(node_arg, list):
                # Materialize the results: a bare map() iterator would be
                # (partly) consumed by all() before any() gets to look at it.
                quantized = [node_arg_is_quantized(a) for a in node_arg]
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")
            else:
                return False

        def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
            """ Check if output node is quantized or not """
            assert self.modules is not None
            # by default the output is expected to be quantized
            quantized = True

            # Need to get correct quantized/non-quantized state for the output
            # of CopyNode
            if type(obj) in [
                assert node.op in [
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'
                quantized = node_arg_is_quantized(node.args[0])

            if not activation_is_statically_quantized(qconfig) or \
               not input_output_observed(obj):
                quantized = False

            return quantized

        def insert_quantize_node(node: Node) -> None:
            """ Given a activation_post_process module call node, insert a
            quantize node

            Exactly one of three things happens, depending on the observer
            module found at ``self.modules[node.target]`` and on the node's
            first argument:
              - observer dtype is torch.float16: the node is copied into
                ``self.quantized_graph`` and recorded in ``env``
              - the first argument is a Node already in ``quant_env``: the
                node aliases that existing quantized entry
              - otherwise: the result of ``quantize_node`` on the float input
                is recorded in ``quant_env``
            """
            assert self.modules is not None
            assert isinstance(node.target, str)
            observer_module = self.modules[node.target]
            prev_node = node.args[0]
            if observer_module.dtype == torch.float16:
                # activations are not quantized for
                # fp16 dynamic quantization
                # copy the activation_post_process node here
                # since we may need it when we insert prepack
                # op for weight of linear, this will be removed
                # later in a separate pass
                env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)
            elif isinstance(prev_node, Node) and prev_node.name in quant_env:
                # if previous node is already quantized, we'll just remove the
                # activation_post_process
                quant_env[node.name] = quant_env[prev_node.name]
            else:
                # replace activation post process with quantization ops
                root_module = self.modules[""]
                assert isinstance(prev_node, Node)
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(prev_node), observer_module)

        # additional state to override inputs to be quantized, if specified
        # by the user
        placeholder_node_seen_cnt = 0
        output_node_seen_cnt = 0
        input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "input_quantized_idxs", [])
        output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "output_quantized_idxs", [])

        for node in model.graph.nodes:
            if node.op == 'output':
                cur_output_node_idx = output_node_seen_cnt
                output_node_seen_cnt += 1
                if cur_output_node_idx in output_quantized_idxs:
                    # Results are kept quantized if the user specified the
                    # output_quantized_idxs override.
                    graph_output = map_arg(node.args[0], load_x)
                    graph_output = map_arg(node.args[0], load_non_quantized)
            root_node, matched, matched_pattern, obj, qconfig = \
                matches.get(node.name, (None, None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(
                        node, load_non_quantized)
                    quantized = False
                    assert obj is not None
                    is_standalone_module_node = (
                        node.op == 'call_module' and
                            self.modules[node.target])  # type: ignore
                    result = obj.convert(
                        self, node, load_arg, debug=debug,
                    if is_standalone_module_node:
                        quantized = False
                        quantized = is_output_quantized(node, obj)

                if quantized:
                    quant_env[node.name] = result
                    env[node.name] = result
            elif root_node is not None:

            # handle activation post process calls
            if node.op == 'call_module' and \
            elif node.op == 'placeholder':
                cur_placeholder_node_idx = placeholder_node_seen_cnt
                placeholder_node_seen_cnt += 1
                if cur_placeholder_node_idx in input_quantized_idxs:
                    quant_env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
                    env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
                # copy quantized or non-quantized node
                env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)

        # remove activation post process
        act_post_process_removed_graph = Graph()
        env = {}

        def load_arg_simple(a: Argument) -> Argument:
            """Run ``map_arg`` over ``a`` with a by-name lookup into ``env``."""
            def lookup(node):
                return env[node.name]

            return map_arg(a, lookup)
        for node in self.quantized_graph.nodes:
            if node.op == 'output':
                    map_arg(node.args[0], load_arg_simple))
            if node.op == 'call_module' and \
                # remove activation post process node
                env[node.name] = env[node.args[0].name]
                env[node.name] = act_post_process_removed_graph.node_copy(
                    node, load_arg_simple)

        # removes qconfig and activation_post_process modules
        model = GraphModule(model, act_post_process_removed_graph)
        return model
# Example 2
def convert(model: GraphModule,
            is_reference: bool = False,
            convert_custom_config_dict: Dict[str, Any] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """ standalone_module means it is a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(
    qconfig_map: Dict[
        str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]
    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops

    # move to cpu since we only have quantized cpu kernels
    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
    matches = find_matches(model.graph,

    quantized_graph = Graph()
    env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':

    def load_non_quantized(n: Node) -> Node:
        """Return a float node for ``n`` from ``env``.

        ``env`` maps a node name to a ``(node, dtype)`` pair. When the
        recorded dtype is truthy and not ``torch.float``, a ``dequantize()``
        call is built on the stored node through a Proxy and the entry is
        replaced by ``(dequantized_node, torch.float)`` before returning.
        Asserts that ``n.name`` is a key of ``env``.
        """
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        stored_node, dtype = env[n.name]
        if dtype and dtype != torch.float:
            env[n.name] = Proxy(stored_node).dequantize().node, torch.float
        return env[n.name][0]

    def load_quantized(n: Node) -> Node:
        """Return the node stored for ``n`` in ``env``, requiring a quantized dtype.

        Asserts that ``n.name`` is a key of ``env`` and that the recorded
        dtype is one of torch.quint8, torch.qint8 or torch.float16.
        """
        key = n.name
        assert key in env, (
            'trying to load quantized node but did not find node:' +
            key + ' in environment:' + str(env))
        entry = env[key]
        quantized_node, dtype = entry
        assert dtype in (torch.quint8, torch.qint8, torch.float16), \
            f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
        return quantized_node

    def load_x(n: Node) -> Node:
        """Return the node stored for ``n`` in ``env``, whatever its dtype.

        Asserts that ``n.name`` is a key of ``env``.
        """
        key = n.name
        assert key in env, (
            'node ' + key + ' does not exist in environment')
        entry = env[key]
        return entry[0]

    def load_arg(
        quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
    ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is None, then we'll load the node as long as it
            exists in the environment (see load_x)
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized

        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)
        if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = False

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], bool,
                                              Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
               len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                updated_quantized = 0 in quantized

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, bool):
                return map_arg(
                    arg_or_args, load_quantized
                    if updated_quantized else load_non_quantized)
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    # exactly one load per argument: quantized for the
                    # listed indexes, float for all others
                    if i in updated_quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                # preserve the container type (list vs tuple) of the input
                return type(arg_or_args)(loaded_args)

        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        """Tell whether ``node_arg`` is currently held in quantized form.

        - Node: True iff the dtype recorded for it in ``env`` differs from
          ``torch.float`` (asserts the name is in ``env``).
        - list: True if every element is quantized, False if none is;
          raises Exception when the list is only partially quantized.
        - anything else: False.
        """
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                _, dtype = env[node_arg.name]
                return dtype != torch.float
            else:
                return False
        elif isinstance(node_arg, list):
            # Materialize the results: a bare map() iterator would be
            # (partly) consumed by all() before any() gets to look at it.
            quantized = [node_arg_is_quantized(a) for a in node_arg]
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler,
                            qconfig: QConfigAny,
                            modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # by default the output for a quantizable node is expected to be quantized
        quantized = True

        # Need to get correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(
            assert node.op in [
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
           not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node,
                             modules: Dict[str, torch.nn.Module]) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node

        Exactly one ``env[node.name]`` assignment happens, depending on the
        observer module at ``modules[node.target]`` and the node's first
        argument:
          - observer dtype is torch.float32: the observer node is copied
            into ``quantized_graph`` and recorded with dtype torch.float
          - the first argument is a Node already in ``env``: the node
            aliases the previous entry when the dtypes agree, otherwise the
            float input is passed through ``quantize_node`` and recorded
            with the observer dtype
          - otherwise: the float input is passed through ``quantize_node``
            and recorded with the observer dtype
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name] = quantized_graph.node_copy(
                node, load_non_quantized), torch.float
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            _, prev_dtype = env[prev_node.name]
            current_dtype = observer_module.dtype
            if prev_dtype == current_dtype:
                env[node.name] = env[prev_node.name]
            else:
                root_module = modules[""]
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                # NOTE(review): the quantize_node argument list looks
                # truncated in this copy of the file - confirm against the
                # quantize_node signature before relying on this call.
                env[node.name] = (quantize_node(load_non_quantized(prev_node),
                                                is_input=True), observer_dtype)
        else:
            # replace activation post process with quantization ops
            root_module = modules[""]
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            # NOTE(review): same possibly-truncated quantize_node call as
            # in the dtype-mismatch branch - confirm the arguments.
            env[node.name] = (quantize_node(load_non_quantized(node.args[0]),
                                            is_input=True), dtype)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Results are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
                graph_output = map_arg(node.args[0], load_non_quantized)
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module'
                and is_observed_standalone_module(modules[node.target]))
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(node, load_non_quantized)
                quantized = False
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = modules[
                        )  # type: ignore[operator] # noqa: B950
                    assert len(
                    ) <= 1, "Currently standalone only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                result = obj.convert(
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig,

            if quantized:
                env[node.name] = result, activation_dtype(qconfig)
                env[node.name] = result, torch.float
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(node, load_non_quantized)
                env[node.name] = result, torch.float

        # handle activation post process calls
        if node.op == 'call_module' and \
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                env[node.name] = \
                        node, load_non_quantized), torch.quint8
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float
            # copy quantized or non-quantized node
            # get_tensor_info_node like shape works for both
            # quantized and non-quantized input and output a non-Tensor
            # (we use None for dtype currently for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name] = \
                    quantized_graph.node_copy(node, load_x), None
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float

    # remove activation post process
    act_post_process_removed_graph = Graph()
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        """Run ``map_arg`` over ``a`` with a by-name lookup into ``remove_env``."""
        def lookup(node):
            return remove_env[node.name]

        return map_arg(a, lookup)

    for node in quantized_graph.nodes:
        if node.op == 'output':
                map_arg(node.args[0], load_arg_remove))
        if node.op == 'call_module' and \
            # remove activation post process node
            remove_env[node.name] = remove_env[node.args[0].name]
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, act_post_process_removed_graph,
    if not is_reference:
        model = fold_weight(model, node_name_to_scope)
    return model