Example #1
def update_qconfig_for_fusion(
    model: GraphModule,
    qconfig_dict: Any,
) -> Any:
    """
    Update the qconfig_dict to account for fused modules such as LinearReLU.
    """
    object_type_dict = qconfig_dict.get("object_type", None)
    if object_type_dict is None:
        return qconfig_dict

    modules = dict(model.named_modules())

    for node in model.graph.nodes:
        if node.op == 'call_module':
            module_type = type(modules[str(node.target)])
            if module_type not in list(DEFAULT_OP_LIST_TO_FUSER_METHOD.values()):
                continue

            for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
                if module_type == fuser:
                    fused_qconfig = object_type_dict.get(ops[0], None)

                    # Raise an error if the modules in the fused module have
                    # different qconfigs specified in the qconfig_dict
                    for op in ops:
                        if not qconfig_equals(object_type_dict.get(op, None), fused_qconfig):
                            raise LookupError("During fusion, we need to specify the same " +
                                              f"qconfigs for both modules in {module_type}.")

                    if fused_qconfig is not None:
                        object_type_dict[module_type] = fused_qconfig

    return qconfig_dict
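A hedged usage sketch of the helper above. The small module, the call to fuse_fx, and the dict-style "object_type" mapping (which mirrors the internal, post-normalization form this helper reads from) are illustrative assumptions, not the exact call sites.

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import fuse_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

fused_model = fuse_fx(M().eval())   # traced graph now contains nni.LinearReLU
qconfig = get_default_qconfig("fbgemm")
# Both entries must share the same qconfig, otherwise a LookupError is raised.
qconfig_dict = {"object_type": {nn.Linear: qconfig, nn.ReLU: qconfig}}
qconfig_dict = update_qconfig_for_fusion(fused_model, qconfig_dict)
print(qconfig_dict["object_type"].get(nni.LinearReLU) is qconfig)  # expected: True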
Example #2
def get_matching_activations_a_shadows_b(
    gm_a_shadows_b: GraphModule,
    logger_cls: Callable,
) -> Dict[str, Dict[str, List[torch.Tensor]]]:
    """
    Same thing as get_matching_activations, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    results: Dict[str, Dict[str, List[torch.Tensor]]] = \
        collections.defaultdict(dict)
    for name, mod in gm_a_shadows_b.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            # If logger_obj.other_node_name is populated, then this logger
            # is from model A, and other_node_name is the name from model B.
            if mod.other_node_name is None:
                results[mod.node_name + '.stats'][mod.model_name] = mod.stats
            else:
                results[mod.other_node_name + '.stats'][mod.model_name] = mod.stats
    return dict(results)
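For orientation, a hedged sketch of how the collected results are typically consumed. `gm_a_shadows_b` and `OutputLogger` are placeholders for the instrumented shadow model and the logger class used when it was built.

results = get_matching_activations_a_shadows_b(gm_a_shadows_b, OutputLogger)
# Layout: {'<node_name>.stats': {'<model_name>': [tensor, ...], ...}, ...}
for layer_key, stats_per_model in results.items():
    for model_name, tensors in stats_per_model.items():
        print(layer_key, model_name, len(tensors))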
Example #3
    def _find_matches(
        self, root: GraphModule, graph: Graph,
        patterns: Dict[Pattern, Callable]
    ) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]]:
        modules = dict(root.named_modules())
        # node name -> (root_node, pattern, matched_node_pattern, handler)
        match_map: Dict[str, Tuple[Node, Pattern, NodePattern,
                                   FuseHandler]] = {}

        def apply_match(pattern, node, match, matched_node_pattern):
            if isinstance(pattern, tuple):
                s, *args = pattern
                current_node_pattern: List[Node] = []
                apply_match(s, node, match, current_node_pattern)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match, current_node_pattern)
                matched_node_pattern.append(tuple(current_node_pattern))
            else:
                # the first pattern match takes precedence
                if node.name not in match_map:
                    matched_node_pattern.append(node)
                    root_node, pattern, handler = match
                    match_map[node.name] = (root_node, pattern,
                                            matched_node_pattern, handler)

        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    matched_node_pattern: List[Node] = []
                    if is_match(modules, node, pattern):
                        apply_match(pattern, node,
                                    (node, pattern, value(self, node)),
                                    matched_node_pattern)

        return match_map
Example #4
    def _find_matches(
            self, root: GraphModule, graph: Graph,
            patterns: Dict[Pattern, Callable]
    ) -> Dict[str, Tuple[Node, FuseHandler]]:
        modules = dict(root.named_modules())
        match_map: Dict[str, Tuple[Node, FuseHandler]] = {}  # node name -> (root_node, match_value)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                # the first pattern match takes precedence
                if node.name not in match_map:
                    match_map[node.name] = match

        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if is_match(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map
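To make the recursion above concrete, here is a small self-contained sketch (fake node objects, not torch.fx) of how a nested fusion pattern such as ("ReLU", "Linear"), written outer-op-first, is walked so that every node in the pattern ends up pointing at the same match entry.

from typing import Any, Dict, NamedTuple, Tuple

class FakeNode(NamedTuple):
    # stand-in for torch.fx.Node, just enough for the recursion
    name: str
    op_type: str
    args: Tuple["FakeNode", ...] = ()

def apply_match(pattern, node, match, match_map: Dict[str, Any]) -> None:
    if isinstance(pattern, tuple):
        outer, *inner = pattern
        apply_match(outer, node, match, match_map)      # match the outer op on this node
        for subpattern, arg in zip(inner, node.args):   # then recurse into its inputs
            apply_match(subpattern, arg, match, match_map)
    elif node.name not in match_map:                    # first match wins
        match_map[node.name] = match

linear = FakeNode("linear1", "Linear")
relu = FakeNode("relu1", "ReLU", (linear,))
match_map: Dict[str, Any] = {}
apply_match(("ReLU", "Linear"), relu, (relu, "fuse_linear_relu_handler"), match_map)
print(match_map)   # both relu1 and linear1 map to the same (root_node, handler) entry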
Example #5
def _convert_equalization_ref(model: GraphModule):
    """ Reference function which applies changes needed for equalization, but
    does not quantize the nodes
    """
    modules = dict(model.named_modules(remove_duplicate=False))

    # Calculate the equalization scale, update the observers with the scaled
    # inputs, and scale the weight
    weight_eq_obs_dict = update_obs_for_equalization(model, modules)
    convert_eq_obs(model, modules, weight_eq_obs_dict)

    return GraphModule(model, model.graph)
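A hedged sketch of where this reference conversion sits in the equalization flow. `prepared_model` and `calib_loader` are placeholders, and the prepare step is assumed to have already inserted the equalization observers that update_obs_for_equalization expects.

prepared_model.eval()
for images, _ in calib_loader:
    prepared_model(images)                    # calibrate the observers
equalized_ref = _convert_equalization_ref(prepared_model)
# equalized_ref is still a float model: the equalization scales have been
# folded into the observers and weights, but no quantize/dequantize ops have
# been inserted yet.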
def add_activation_info_to_dict(
    model_name: str,
    model: GraphModule,
    results: Dict[str, Dict[str, List[torch.Tensor]]],
    logger_cls: Callable,
) -> None:
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore
            or (isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'))
        if is_logger:
            key = mod.ref_node_name + '.stats'
            if key not in results:
                results[key] = {}
            results[key][model_name] = mod.stats
Example #7
def _find_matches(
    root: GraphModule, graph: Graph, patterns: Dict[Pattern, Callable]
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler,
          Dict[Node, Any]]]:
    modules = dict(root.named_modules())
    # node name -> (root_node, pattern, matched_node_pattern, handler,
    #               node_to_subpattern)
    match_map: Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler,
                               Dict[Node, Any]]] = {}
    # a map from node to the matched subpattern
    node_to_subpattern: Dict[Node, Any] = {}

    # TODO: dedup with quantization matching function in match_utils.py
    def apply_match(pattern, node, match, matched_node_pattern,
                    node_to_subpattern):
        if isinstance(pattern, tuple):
            s, *args = pattern
            current_node_pattern: List[Node] = []
            apply_match(s, node, match, current_node_pattern,
                        node_to_subpattern)
            for subpattern, arg in zip(args, node.args):
                apply_match(subpattern, arg, match, current_node_pattern,
                            node_to_subpattern)
            matched_node_pattern.append(tuple(current_node_pattern))
        else:
            # the first pattern match takes precedence
            if node.name not in match_map:
                matched_node_pattern.append(node)
                # MatchAllNode here is actually MatchAllInputNode which should not
                # be added to match_map
                if pattern is not MatchAllNode:
                    node_to_subpattern[node] = pattern
                    root_node, pattern, handler = match
                    match_map[node.name] = (root_node, pattern,
                                            matched_node_pattern, handler,
                                            node_to_subpattern)

    for node in reversed(graph.nodes):
        if node.name not in match_map:
            for pattern, value in patterns.items():
                matched_node_pattern: List[Node] = []
                if is_match(modules, node, pattern):
                    apply_match(pattern, node, (node, pattern, value(node)),
                                matched_node_pattern, node_to_subpattern)
                    break

    return match_map
Example #8
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, and adds loggers to the
    output of each node in node_to_instrument_to_ref_node_name. Returns a
    GraphModule with the new graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(
                modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif node in node_to_instrument_to_ref_node_name:
            other_node_name = node_to_instrument_to_ref_node_name[node]
            # ensure env is populated with base node
            env[node.name] = new_graph.node_copy(node, load_arg)
            # add the logger after the base node
            env[node.name] = _insert_logger_after_node(env[node.name], gm,
                                                       logger_cls,
                                                       '_ns_logger_',
                                                       model_name,
                                                       other_node_name)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
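A hedged usage sketch; `gm`, `OutputLogger`, and `example_input` are placeholders, and instrumenting every call_module node is just one possible selection.

# Log the output of every call_module node under its own name (a value of
# None means other_node_name stays unset on the logger).
nodes_to_instrument = {
    node: None
    for node in gm.graph.nodes
    if node.op == 'call_module'
}
instrumented = remove_observers_add_loggers(
    gm, nodes_to_instrument, OutputLogger, model_name='model_b')
instrumented(example_input)   # forward pass populates each logger's stats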
Example #9
def update_qconfig_for_fusion(
    model: GraphModule,
    qconfig_dict: Any,
) -> Any:
    """
    Update the qconfig_dict to account for fused modules such as LinearReLU.
    """
    object_type_dict = qconfig_dict.get("object_type", None)
    if object_type_dict is None:
        return qconfig_dict

    modules = dict(model.named_modules())

    for node in model.graph.nodes:
        if node.op == 'call_module' and node.target in modules:
            maybe_fused_module = modules[str(node.target)]
            if not isinstance(maybe_fused_module, _FusedModule):
                continue

            ops = list(maybe_fused_module._modules.values())
            fused_qconfig = object_type_dict.get(type(ops[0]), None)

            # Raise an error if the modules in the fused module have
            # different qconfigs specified in the qconfig_dict
            # TODO: currently it only works for modules,
            # need to make this work for torch.nn.functional.relu
            # TODO: currently it only works for object_type configurations,
            # ideally it should work for different types of configurations,
            # maybe we want to redesign this part
            for op in ops[1:]:
                if not qconfig_equals(object_type_dict.get(type(op), None),
                                      fused_qconfig):
                    raise LookupError(
                        "During fusion, we need to specify the same " +
                        f"qconfigs for all module types in {type(maybe_fused_module)} "
                        + f"offending type: {type(op)}")

            if fused_qconfig is not None:
                object_type_dict[type(maybe_fused_module)] = fused_qconfig

    return qconfig_dict
Example #10
def add_activation_info_to_dict(
    model: GraphModule,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore
            or (isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'))
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            results[key][mod.model_name] = {
                'type': NSSingleResultValuesType.NODE_OUTPUT.value,
                'values': mod.stats,
                'node_name': mod.node_name,
                'node_target_type': mod.node_target_type,
            }
def get_matching_activations_a_shadows_b(
    gm_a_shadows_b: GraphModule,
    logger_cls: Callable,
) -> NSResultsType:
    """
    Same thing as get_matching_activations, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    results: NSResultsType = collections.defaultdict(dict)
    for name, mod in gm_a_shadows_b.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore
            or (isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'))
        if is_logger:
            results[mod.ref_name][mod.model_name] = {
                'type': NSSingleResultValuesType.NODE_OUTPUT.value,
                'values': mod.stats,
                'node_name': mod.node_name,
                'node_target_type': mod.node_target_type,
            }
    return dict(results)
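For reference, a hedged sketch of the nested NSResultsType layout that both helpers above produce; the key names and target types are illustrative only.

example_results = {
    'linear1': {                          # ref_name shared by the matched subgraphs
        'model_a': {
            'type': 'node_output',        # NSSingleResultValuesType.NODE_OUTPUT.value
            'values': [],                 # list of torch.Tensor captured by the logger
            'node_name': 'linear1_shadow_copy_0',
            'node_target_type': 'Linear',
        },
        'model_b': {
            'type': 'node_output',
            'values': [],
            'node_name': 'linear1',
            'node_target_type': 'QuantizedLinear',
        },
    },
}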
Example #12
    def _convert(self, model: GraphModule, debug: bool = False,
                 convert_custom_config_dict: Dict[str, Any] = None,
                 is_standalone_module: bool = False) -> GraphModule:
        """ standalone_module means it a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        Returns a quantized standalone module which accepts float input
        and produces float output.
        """
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        self.restore_state(model)
        # always run weight observers in the top level forward method
        # for dynamic quant ops or weight only quant ops
        self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        custom_module_classes = get_custom_module_class_keys(
            convert_custom_config_dict,
            "observed_to_quantized_custom_module_class")
        assert self.patterns is not None
        matches = self._find_matches(
            model.graph, self.modules, self.patterns,
            custom_module_classes=custom_module_classes)

        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
            self._find_quants(model.graph, matches)

        self.quantized_graph = Graph()
        env: Dict[str, Node] = {}
        quant_env: Dict[str, Node] = {}

        graph_inputs: List[str] = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        def load_non_quantized(n: Node) -> Node:
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find ' + \
                    'node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + \
                    str(env) + ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n: Node) -> Node:
            assert n.name in quant_env, \
                'trying to load quantized node but did not find node:' + \
                n.name + ' in quant environment:' + str(quant_env)
            return quant_env[n.name]

        def load_x(n: Node) -> Node:
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]]
                     ) -> Callable[[Node], Argument]:
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and
                the args with corresponding indexes will be quantized
              - if quantized is a boolean, then all args will be
                quantized/not quantized
              - if quantized is None, then we'll load the node as long as it
                exists

            Output: fn which takes arg_or_args, and loads them from the
                corresponding environment depending on the value of quantized.
            """
            assert quantized is None or \
                isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl

        def node_arg_is_quantized(node_arg: Any) -> bool:
            if isinstance(node_arg, Node):
                assert node_arg.name in env or node_arg.name in quant_env, \
                    'Expecting node_arg to be in the environment'
                # there might be nodes appearing in both environments, but
                # quant_env will take precedence
                if node_arg.name in quant_env:
                    return True
                elif node_arg.name in env:
                    return False
                else:
                    return False
            elif isinstance(node_arg, list):
                quantized = map(node_arg_is_quantized, node_arg)
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")
            else:
                return False

        def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
            """ Check if output node is quantized or not """
            assert self.modules is not None
            # by default the output is expected to be quantized
            quantized = True

            # Need to get correct quantized/non-quantized state for the output
            # of CopyNode
            if type(obj) in [
                    CopyNode,
                    FixedQParamsOpQuantizeHandler
            ]:
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'
                quantized = node_arg_is_quantized(node.args[0])

            if not activation_is_statically_quantized(qconfig) or \
               not input_output_observed(obj):
                quantized = False

            return quantized

        def insert_quantize_node(node: Node) -> None:
            """ Given a activation_post_process module call node, insert a
            quantize node"""
            assert self.modules is not None
            assert isinstance(node.target, str)
            observer_module = self.modules[node.target]
            prev_node = node.args[0]
            if observer_module.dtype == torch.float16:
                # activations are not quantized for fp16 dynamic quantization.
                # Copy the activation_post_process node here, since we may
                # need it when we insert the prepack op for the weight of
                # linear; this will be removed later in a separate pass.
                env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)
            elif isinstance(prev_node, Node) and prev_node.name in quant_env:
                # if previous node is already quantized, we'll just remove the
                # activation_post_process
                quant_env[node.name] = quant_env[prev_node.name]
            else:
                # replace activation post process with quantization ops
                root_module = self.modules[""]
                assert isinstance(node.args[0], Node)
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)

        # additional state to override inputs to be quantized, if specified
        # by the user
        placeholder_node_seen_cnt = 0
        output_node_seen_cnt = 0
        input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "input_quantized_idxs", [])
        output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "output_quantized_idxs", [])

        for node in model.graph.nodes:
            if node.op == 'output':
                cur_output_node_idx = output_node_seen_cnt
                output_node_seen_cnt += 1
                if cur_output_node_idx in output_quantized_idxs:
                    # Results are kept quantized if the user specified the
                    # output_quantized_idxs override.
                    graph_output = map_arg(node.args[0], load_x)
                else:
                    graph_output = map_arg(node.args[0], load_non_quantized)
                self.quantized_graph.output(graph_output)
                continue
            root_node, matched, matched_pattern, obj, qconfig = \
                matches.get(node.name, (None, None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(
                        node, load_non_quantized)
                    quantized = False
                else:
                    assert obj is not None
                    is_standalone_module_node = (
                        node.op == 'call_module' and
                        is_observed_standalone_module(
                            self.modules[node.target])  # type: ignore
                    )
                    result = obj.convert(
                        self, node, load_arg, debug=debug,
                        convert_custom_config_dict=convert_custom_config_dict)
                    if is_standalone_module_node:
                        quantized = False
                    else:
                        quantized = is_output_quantized(node, obj)

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module' and \
                    is_activation_post_process(self.modules[node.target]):
                insert_quantize_node(node)
            elif node.op == 'placeholder':
                cur_placeholder_node_idx = placeholder_node_seen_cnt
                placeholder_node_seen_cnt += 1
                if cur_placeholder_node_idx in input_quantized_idxs:
                    quant_env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
                else:
                    env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
            else:
                # copy quantized or non-quantized node
                env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)

        # remove activation post process
        act_post_process_removed_graph = Graph()
        env = {}

        def load_arg_simple(a: Argument) -> Argument:
            return map_arg(a, lambda node: env[node.name])
        for node in self.quantized_graph.nodes:
            if node.op == 'output':
                act_post_process_removed_graph.output(
                    map_arg(node.args[0], load_arg_simple))
                continue
            if node.op == 'call_module' and \
               is_activation_post_process(self.modules[node.target]):
                # remove activation post process node
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = act_post_process_removed_graph.node_copy(
                    node, load_arg_simple)

        # removes qconfig and activation_post_process modules
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model
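The pass above threads every value through two environments. The following self-contained sketch (plain dicts and strings, not the real torch.fx objects) shows just that bookkeeping: quantized outputs live in quant_env, float outputs in env, and a float consumer of a quantized value triggers an implicit dequantize.

from typing import Any, Dict

env: Dict[str, Any] = {}         # node name -> float value
quant_env: Dict[str, Any] = {}   # node name -> quantized value

def record(name: str, value: Any, quantized: bool) -> None:
    (quant_env if quantized else env)[name] = value

def load_non_quantized(name: str) -> Any:
    if name not in env:
        # stand-in for Proxy(quant_env[name]).dequantize().node
        env[name] = ('dequantize', quant_env[name])
    return env[name]

record('conv1', 'quantized_conv1_out', quantized=True)
print(load_non_quantized('conv1'))   # ('dequantize', 'quantized_conv1_out')
print(load_non_quantized('conv1'))   # cached: the dequantize is only inserted once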
Example #13
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node,
                                                                     Node]]],
    logger_cls: Callable,
    should_log_inputs: bool,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a_and_name[node_end_b] = \
            ((node_start_a, node_end_a), match_name)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a_and_name:
            (node_start_a, node_end_a), ref_name = \
                node_b_to_matched_subgraph_a_and_name[node_b]
            if False:
                print('b')
                print_node(node_b)
                print('a')
                print_node(node_start_a)
                print_node(node_end_a)

            # if necessary, log the input of node_c
            if should_log_inputs:
                if isinstance(node_b.args[0], Node):
                    prev_node_c = env_c[node_b.args[0].name]
                    env_c[prev_node_c.name] = _insert_logger_after_node(
                        prev_node_c,
                        gm_b,
                        logger_cls,
                        '_ns_logger_b_inp_',
                        node_b.name,
                        name_b,
                        ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                elif isinstance(node_b.args[0], list):
                    # first, save the prev_node instances, because they
                    # will be overwritten in the env after the first logger
                    # is added
                    prev_node_c_list = [
                        env_c[arg.name] for arg in node_b.args[0]
                    ]

                    for arg_idx, arg in enumerate(node_b.args[0]):
                        prev_node_c = prev_node_c_list[arg_idx]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c,
                            gm_b,
                            logger_cls,
                            '_ns_logger_b_inp_',
                            node_b.name,
                            name_b,
                            ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx)
                else:
                    # logging of inputs which are neither Nodes nor lists
                    # is not supported yet
                    raise AssertionError(
                        f"type {type(node_b.args[0])} is not handled yet")
            # subgraph so far:
            #
            # (prev_node_c)+ -> (logger_c_input)?

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            # note: not inserting to env_c because all nodes which use the dtype
            #   casts are copied from graph_a
            #
            # subgraph so far:
            #
            #           (dtype_cast_node)+
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            # if input logging is enabled, log the input to the subgraph
            if should_log_inputs:
                # TODO: explain this
                ref_node_name = ''
                if isinstance(dtype_cast_node, Node):
                    dtype_cast_node = _insert_logger_after_node(
                        dtype_cast_node,
                        gm_b,
                        logger_cls,
                        '_ns_logger_a_inp_',
                        ref_node_name,
                        name_a,
                        ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                    input_logger: Union[Node, List[Node]] = dtype_cast_node
                else:
                    assert isinstance(dtype_cast_node, list)
                    new_loggers = []
                    for dtype_cast_idx, dtype_cast_node_inner in enumerate(
                            dtype_cast_node):
                        dtype_cast_logger = _insert_logger_after_node(
                            dtype_cast_node_inner,
                            gm_b,
                            logger_cls,
                            '_ns_logger_a_inp_',
                            ref_node_name,
                            name_a,
                            ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=dtype_cast_idx)
                        new_loggers.append(dtype_cast_logger)
                    dtype_cast_node = new_loggers
                    input_logger = dtype_cast_node
                # subgraph so far:
                #
                #       (dtype_cast_node)+ -> (logger_a_input)?
                #                  /
                # prev_node_c -> (logger_c_input)? -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                dtype_cast_node, node_start_a, node_end_a, gm_a, gm_b,
                node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c

            if should_log_inputs:
                # When we created the input logger, we left the ref_node_name
                # as an empty string, because the subgraph copy did not exist
                # yet. Now that the subgraph copy exists, we modify this name
                # to its true value.
                # Note: the alternative to this is to create the input logger
                # after creating the subgraph, which is slightly more
                # complicated. This is the lesser of two evils.
                # input_logger = env_c[dtype_cast_node.name]
                # Find the first node in the subgraph
                cur_node = node_a_shadows_c
                while cur_node.args[0] != input_logger:
                    cur_node = cur_node.args[0]  # type: ignore
                if isinstance(input_logger, Node):
                    input_logger_mod = getattr(gm_b, input_logger.name)
                    input_logger_mod.ref_node_name = cur_node.name
                else:
                    assert isinstance(input_logger, list)
                    for input_logger_inner in input_logger:
                        input_logger_mod = getattr(gm_b,
                                                   input_logger_inner.name)
                        input_logger_mod.ref_node_name = cur_node.name

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name],
                gm_b,
                logger_cls,
                '_ns_logger_b_',
                node_b.name,
                name_b,
                ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0)
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy
            #                  /
            # (prev_node_c+) -> (logger_c_input)? -> node_c -> logger_c

            # hook up a logger to the mod_a copy
            # Note: we pass node_b.name to this logger, for easy matching later
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name],
                gm_b,
                logger_cls,
                '_ns_logger_a_',
                node_a_shadows_c.name,
                name_a,
                ref_name,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0)
            # subgraph so far:
            #
            #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
            #                  /
            # (prev_node_c)+ -> (logger_c_input)? -> node_c -> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
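A hedged end-to-end sketch of how this function is typically driven. `float_model`, `quantized_model`, `example_input`, `OutputLogger`, and `get_matching_subgraph_pairs` mirror the Numeric Suite FX flow, but the exact entry points vary across PyTorch versions.

gm_a = torch.fx.symbolic_trace(float_model)     # model A: fp32 reference
gm_b = quantized_model                          # model B: already a GraphModule
matched_pairs = get_matching_subgraph_pairs(gm_a, gm_b)
gm_a_shadows_b = create_a_shadows_b(
    'fp32', gm_a, 'int8', gm_b, matched_pairs, OutputLogger,
    should_log_inputs=False)
gm_a_shadows_b(example_input)                   # run representative data
results = get_matching_activations_a_shadows_b(gm_a_shadows_b, OutputLogger)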
Example #14
def convert(model: GraphModule,
            is_reference: bool = False,
            convert_custom_config_dict: Dict[str, Any] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """ standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(
        model)
    qconfig_map: Dict[
        str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]
    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(model.graph,
                           modules,
                           patterns,
                           qconfig_map,
                           custom_module_classes=custom_module_classes)

    quantized_graph = Graph()
    env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def load_non_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        quantized_node, dtype = env[n.name]
        if dtype and dtype != torch.float:
            env[n.name] = Proxy(quantized_node).dequantize().node, torch.float
        return env[n.name][0]

    def load_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load quantized node but did not find node:' + \
            n.name + ' in environment:' + str(env)
        quantized_node, dtype = env[n.name]
        assert dtype in [torch.quint8, torch.qint8, torch.float16], \
            f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
        return quantized_node

    def load_x(n: Node) -> Node:
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        return env[n.name][0]

    def load_arg(
        quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
    ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized


        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)
        if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = False

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], bool,
                                              Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
               len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                updated_quantized = 0 in quantized

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, bool):
                return map_arg(
                    arg_or_args, load_quantized
                    if updated_quantized else load_non_quantized)
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)

        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                _, dtype = env[node_arg.name]
                return dtype != torch.float
            else:
                return False
        elif isinstance(node_arg, list):
            quantized = map(node_arg_is_quantized, node_arg)
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler,
                            qconfig: QConfigAny,
                            modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # by default the output for a quantizable node is expected to be quantized
        quantized = True

        # Need to get correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(
                qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
           not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node,
                             modules: Dict[str, torch.nn.Module]) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name] = quantized_graph.node_copy(
                node, load_non_quantized), torch.float
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            _, prev_dtype = env[prev_node.name]
            current_dtype = observer_module.dtype
            if prev_dtype == current_dtype:
                env[node.name] = env[prev_node.name]
            else:
                root_module = modules[""]
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name] = (quantize_node(load_non_quantized(prev_node),
                                                observer_module,
                                                node,
                                                modules,
                                                quantized_graph,
                                                node_name_to_scope,
                                                is_input=True), observer_dtype)
        else:
            # replace activation post process with quantization ops
            root_module = modules[""]
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name] = (quantize_node(load_non_quantized(node.args[0]),
                                            observer_module,
                                            node,
                                            modules,
                                            quantized_graph,
                                            node_name_to_scope,
                                            is_input=True), dtype)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Results are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module'
                and is_observed_standalone_module(modules[node.target]))
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = (
                        modules[node.target]
                        ._standalone_module_output_quantized_idxs
                        .tolist())  # type: ignore[operator] # noqa: B950
                    assert len(out_quant_idxs) <= 1, \
                        "Currently standalone modules only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                result = obj.convert(
                    node,
                    qconfig,
                    modules,
                    quantized_graph,
                    node_name_to_scope,
                    load_arg,
                    is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig,
                                                    modules)

            if quantized:
                env[node.name] = result, activation_dtype(qconfig)
            else:
                env[node.name] = result, torch.float
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(node, load_non_quantized)
                env[node.name] = result, torch.float
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                env[node.name] = \
                    quantized_graph.node_copy(
                        node, load_non_quantized), torch.quint8
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float
        else:
            # copy quantized or non-quantized node
            # get_tensor_info_node like shape works for both
            # quantized and non-quantized input and output a non-Tensor
            # (we use None for dtype currently for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name] = \
                    quantized_graph.node_copy(node, load_x), None
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float

    # remove activation post process
    act_post_process_removed_graph = Graph()
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_remove))
            continue
        if node.op == 'call_module' and \
           is_activation_post_process(modules[node.target]):
            # remove activation post process node
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, act_post_process_removed_graph,
                                 preserved_attributes)
    if not is_reference:
        model = fold_weight(model, node_name_to_scope)
    return model
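For context, a hedged sketch of the public workflow that eventually reaches this convert pass. `float_model` and `calib_loader` are placeholders, and the import paths follow the pre-torch.ao layout used in this file.

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(float_model.eval(), qconfig_dict)   # inserts observers
with torch.no_grad():
    for data, _ in calib_loader:                          # calibration data
        prepared(data)
quantized = convert_fx(prepared)                          # dispatches into convert(...) above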
Example #15
    def _prepare(self, model: GraphModule, qconfig_dict: Any,
                 prepare_custom_config_dict: Optional[Dict[str, Any]],
                 is_standalone_module: bool) -> GraphModule:
        """ standalone_module means it a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        When we are preparing a standalone module:
        both input and output are observed in prepared standalone module
        Returns:
            model(GraphModule): prepared standalone module
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}
        self.prepare_custom_config_dict = prepare_custom_config_dict

        additional_quant_patterns = \
            prepare_custom_config_dict.get("additional_quant_pattern", {})
        self.patterns = get_combined_dict(
            get_default_quant_patterns(), additional_quant_patterns)

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additional_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get(
            "standalone_module_name", None)
        standalone_module_classes = prepare_custom_config_dict.get(
            "standalone_module_class", None)
        custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict, "float_to_observed_custom_module_class")
        assert self.patterns is not None
        matches = self._find_matches(
            model.graph, self.modules, self.patterns, standalone_module_names,
            standalone_module_classes, custom_module_classes)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuantizeHandler object for each
        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
            self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env: Dict[Any, Any] = {}
        observed_graph = Graph()
        observed_node_names_set: Set[str] = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indices of the inputs that need to be observed
        standalone_module_observed_input_idxs: List[int] = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')

        placeholder_node_seen_cnt = 0
        output_node_seen_cnt = 0
        input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "input_quantized_idxs", [])
        output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "output_quantized_idxs", [])

        result_node: Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                # If this output is hardcoded to be quantized, insert an
                # observer on the previous node if it does not already
                # exist.
                cur_output_node_idx = output_node_seen_cnt
                output_node_seen_cnt += 1
                if cur_output_node_idx in output_quantized_idxs:
                    prev_node = node.args[0]
                    assert isinstance(prev_node, Node), \
                        ('hardcoding list/dict outputs to be quantized is ' +
                         'not supported')
                    if prev_node.name not in observed_node_names_set:
                        assert self.qconfig_map is not None
                        local_qconfig = self.qconfig_map[prev_node.name]
                        assert local_qconfig is not None, \
                            'qconfig of a node before a quantized output must exist'
                        insert_observer(
                            prev_node, local_qconfig.activation(),
                            model, self.activation_post_process_map,
                            env, observed_graph, load_arg, observed_node_names_set)

                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue

            if node.name in observed_node_names_set:
                continue

            root_node, matched_nodes, pattern, obj, qconfig = matches.get(
                node.name, (None, None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                # index for input of custom module that needs to be observed in
                # parent
                if qconfig is not None:
                    assert obj is not None
                    insert_observer_for_special_module(
                        obj, self.modules, prepare_custom_config_dict, qconfig,
                        node)
                    insert_observer_for_output_of_the_node(
                        node, obj, qconfig, self.modules, model, pattern,
                        self.activation_post_process_map, env,
                        observed_graph, load_arg, observed_node_names_set,
                        matched_nodes)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.op == 'placeholder':
                # skip adding observers at the graph input if the input is
                # overridden to be quantized
                cur_placeholder_node_idx = placeholder_node_seen_cnt
                placeholder_node_seen_cnt += 1
                if cur_placeholder_node_idx in input_quantized_idxs:
                    observed_node_names_set.add(node.name)
                    continue

            insert_observer_for_input_arg_of_observed_node(
                node, observed_node_names_set, quants,
                model, self.activation_post_process_map, env,
                observed_graph, load_arg)


        model = GraphModule(model, observed_graph)
        self.save_state(model)
        model = mark_observed_module(model)
        return model
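
# The _prepare above pulls all of its knobs out of prepare_custom_config_dict.
# A minimal sketch of such a dict, using only keys that are looked up in the
# code above; the concrete values are illustrative assumptions, not defaults.
prepare_custom_config_dict = {
    # extra pattern -> QuantizeHandler entries merged into the defaults
    "additional_quant_pattern": {},
    # extra float module -> QAT module swaps applied when model.training is True
    "additional_qat_module_mapping": {},
    # submodules to observe and quantize separately, as standalone units
    "standalone_module_name": ["submodule_fqn"],
    "standalone_module_class": [],
    # float -> observed swaps for custom (non-traceable) modules
    "float_to_observed_custom_module_class": {},
    # graph inputs/outputs that the caller guarantees are already quantized
    "input_quantized_idxs": [0],
    "output_quantized_idxs": [],
}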
Exemple #16
0
    def _prepare(self, model: GraphModule, qconfig_dict: Any,
                 prepare_custom_config_dict: Optional[Dict[str, Any]],
                 is_standalone_module: bool) -> GraphModule:
        """ standalone_module means it a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        When we are preparing a standalone module:
        both input and output are observed in prepared standalone module
        Returns:
            model(GraphModule): prepared standalone module
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}

        additional_quant_patterns = \
            prepare_custom_config_dict.get("additional_quant_pattern", {})
        self.patterns = get_combined_dict(get_default_quant_patterns(),
                                          additional_quant_patterns)

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additional_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get(
            "standalone_module_name", None)
        standalone_module_classes = prepare_custom_config_dict.get(
            "standalone_module_class", None)
        custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict,
            "float_to_observed_custom_module_class")
        assert self.patterns is not None
        matches = self._find_matches(model.graph, self.modules, self.patterns,
                                     standalone_module_names,
                                     standalone_module_classes,
                                     custom_module_classes)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuantizeHandler object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env: Dict[Any, Any] = {}
        observed_graph = Graph()
        observed_node_names_set: Set[str] = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that need to be observed
        standalone_module_observed_input_idxs: List[int] = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')
        model_device = assert_and_get_unique_device(model)

        result_node: Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue
            if node.name in observed_node_names_set:
                continue

            root_node, matched_nodes, pattern, obj, qconfig = matches.get(
                node.name, (None, None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                # index for input of custom module that needs to be observed in
                # parent
                if qconfig is not None:
                    assert obj is not None
                    insert_observer_for_special_module(
                        obj, self.modules, prepare_custom_config_dict, qconfig,
                        node)
                    insert_observer_for_output_of_the_node(
                        node, obj, qconfig, self.modules, model, pattern,
                        model_device, self.activation_post_process_map, env,
                        observed_graph, load_arg, observed_node_names_set,
                        matched_nodes)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            insert_observer_for_input_arg_of_observed_node(
                node, observed_node_names_set, quants, model_device, model,
                self.activation_post_process_map, env, observed_graph,
                load_arg)

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        model = mark_observed_module(model)
        return model
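
# Both versions of _prepare build the observed graph with the same FX idiom:
# an `env` dict mapping original node names to their copies, plus a `load_arg`
# helper that remaps args through `env` as each node is copied. A standalone
# sketch of that idiom (not part of the original code):
import torch.fx as fx
from torch.fx.node import map_arg

def copy_graph(gm: fx.GraphModule) -> fx.GraphModule:
    """Rebuild gm's graph node by node using the env/load_arg idiom above."""
    new_graph = fx.Graph()
    env = {}  # original node name -> corresponding node in new_graph

    def load_arg(a):
        # remap any Node references inside args/kwargs to their new-graph copies
        return map_arg(a, lambda n: env[n.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(load_arg(node.args[0]))
            continue
        # an observer-insertion pass would create extra nodes here; a plain
        # copy just mirrors the original node
        env[node.name] = new_graph.node_copy(node, load_arg)

    return fx.GraphModule(gm, new_graph)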
Exemple #17
0
def insert_observers_for_model(
    model: GraphModule,
    modules: Dict[str, torch.nn.Module],
    matches: Dict[str, MatchResult],
    qconfig_map: Dict[str, QConfigAny],
    graph: Graph,
    prepare_custom_config_dict: Dict[str, Any],
    equalization_config_map: Dict[str, Any],
    input_quantized_idxs: List[int],
    output_quantized_idxs: List[int],
) -> Optional[Node]:
    """
    Inserts observers, using the following high level algorithm:

    For each node in the graph:
      1. determine the target dtype of this node in the quantized graph, and save
           it for future steps
      2. determine the target dtype of all args and kwargs of this node
      3. if any arg or kwarg's target dtype does not match the current node's
           dtype, insert an observer
      4. if the current node needs an output observer, insert it

    For example:

    - starting graph:
        x0 -> linear -> x1

    - observed graph after processing x0:
        x0(fp32)

    - observed graph after processing linear:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)

    - observed graph after processing x1:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1

    After a node is processed, the naive observer placement is guaranteed to be
    complete for that node and all of its predecessors. There can be future
    passes which optimize the graph by deduplicating observers, etc.
    """

    node_name_to_target_dtype: Dict[str, Any] = {}
    cache_for_no_tensor_check: Dict[Node, bool] = dict()

    inputs_seen_counter = 0
    outputs_seen_counter = 0
    results_node = None

    # first, populate the dtype map based only on qconfig and qhandler
    # this assumes:
    # graph inputs are fp32 by default, and int8 where overridden
    # other nodes' output dtype is specified by the qconfig
    modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        node_name_to_target_dtype[node.name] = get_target_activation_dtype_for_node(
            node, qconfig, inputs_seen_counter, outputs_seen_counter,
            input_quantized_idxs, output_quantized_idxs, qhandler,
            modules, cache_for_no_tensor_check)

    # Second, for nodes with known input dtypes, propagate them throughout the
    # graph. For example, if there is a call such as
    #   x1 = x0.masked_fill(mask, 1)
    # we propagate the type of mask to be torch.bool
    propagate_dtypes_for_known_nodes(
        model.graph, node_name_to_target_dtype, matches)

    # After this point, the current node and all of its arguments
    # have a dtype assigned. Now, we insert observers for inputs
    # of this node (if needed for this node), and the output of this node
    # (if needed for this node).

    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    for node in nodes_before_observation:

        if node.op == 'placeholder':
            # if a graph input is in fp32, it does not need observation
            # if a graph input is in int8, we assume the observation happens
            #   outside of the graph, and no additional observation is needed
            pass

        elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
            # check for matches
            root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
                node.name, (None, None, None, None, None))
            equalization_qconfig = equalization_config_map.get(node.name, None)

            this_node_dtype = node_name_to_target_dtype[node.name]
            output_not_a_tensor = this_node_dtype is None
            # TODO(future PR): consider stopping matching getitem
            is_getitem = node.op == 'call_function' and \
                node.target == operator.getitem

            skip_inserting_observers = (
                (qconfig is None) or
                output_not_a_tensor or
                is_getitem
            ) and (not node.op == 'output')

            if not skip_inserting_observers:
                modules = dict(model.named_modules(remove_duplicate=False))
                if node.op != 'output':

                    # This is currently only used for equalization.
                    # Checks if the current node is in a branch in which the two
                    # first layers are both being quantized.
                    #
                    # ex.       conv2
                    #         /
                    #      x -> conv1
                    #
                    # If this is the case, we will not apply equalization to the
                    # initial two layers.
                    is_quantized_branch = False
                    if (
                        len(node.args) > 0 and
                        isinstance(node.args[0], Node) and
                        len(node.args[0].users) > 1
                    ):
                        for user in node.args[0].users:
                            # Checks if there exists another user being quantized
                            is_user_quantized = (
                                qconfig_map.get(user.name, None) is not None or
                                (user.op == 'call_module' and isinstance(modules[str(user.target)], ObserverBase))
                            )
                            if user != node and is_user_quantized:
                                is_quantized_branch = True

                    # this modifies node inplace
                    maybe_insert_input_observers_for_node(
                        node, qconfig, model, modules, graph,
                        node_name_to_target_dtype,
                        qhandler, prepare_custom_config_dict)

                    # Insert equalization input observers if needed
                    maybe_insert_input_equalization_observers_for_node(
                        node, equalization_qconfig, model, modules, graph,
                        node_name_to_target_dtype, is_quantized_branch)

                    is_last_node_of_pattern = root_node is node
                    is_general_tensor_value_op = \
                        (qhandler is not None and qhandler.is_general_tensor_value_op())

                    is_general_tensor_shape_op = \
                        (qhandler is not None and qhandler.is_general_tensor_shape_op())

                    if is_last_node_of_pattern and not is_general_tensor_shape_op:
                        # this returns the new observer node if it was needed
                        maybe_output_obs_node = maybe_insert_output_observer_for_node(
                            node, model, modules, graph, matches,
                            node_name_to_target_dtype, pattern, qhandler)
                        if maybe_output_obs_node is not None:
                            # Update users of original node to use the output observer
                            # instead. For example, change
                            #
                            #           next_node
                            #          /
                            #   cur_node -> obs
                            #
                            # to
                            #
                            #                 next_node
                            #                 /
                            #   cur_node -> obs
                            #
                            # We need to save orig users before updating uses because
                            # the list of users will change as we update uses
                            orig_users = list(node.users.keys())
                            for user_node in orig_users:
                                if user_node is maybe_output_obs_node:
                                    continue
                                user_node.replace_input_with(node, maybe_output_obs_node)

                            # for general tensor value ops, we modify the graph
                            # to make all inputs and outputs use the first input's
                            # observer
                            if is_general_tensor_value_op:
                                if not maybe_make_input_output_share_observers(node, model, modules):
                                    remove_output_observer(node, model, modules)

                            if isinstance(qhandler, CustomModuleQuantizeHandler):
                                swap_custom_module_to_observed(node, qconfig, modules, prepare_custom_config_dict)

                else:  # output
                    maybe_insert_observers_before_graph_output(
                        node, output_quantized_idxs,
                        node_name_to_target_dtype, qconfig_map,
                        model, modules, graph)

        #
        # After this point, the current node has input and output observers
        # that it needs for itself inserted.
        #

        # increment the counters, so future inputs and outputs are assigned
        # correct dtypes
        if node.op == 'placeholder':
            inputs_seen_counter += 1
        elif node.op == 'output':
            outputs_seen_counter += 1
            results_node = node

    return results_node
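
# The graph surgery done by maybe_insert_input_observers_for_node and
# maybe_insert_output_observer_for_node boils down to adding a call_module
# node after an existing node and rerouting that node's users, exactly like
# the `orig_users` loop above. A self-contained sketch of that mechanic on a
# toy model; ToyObserver is a simplified stand-in, not the real observer.
import torch
import torch.fx as fx

class ToyObserver(torch.nn.Module):
    """Simplified stand-in for an activation_post_process module."""
    def __init__(self):
        super().__init__()
        self.register_buffer('min_val', torch.tensor(float('inf')))
        self.register_buffer('max_val', torch.tensor(float('-inf')))

    def forward(self, x):
        self.min_val = torch.minimum(self.min_val, x.min())
        self.max_val = torch.maximum(self.max_val, x.max())
        return x

toy = fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
for i, node in enumerate(list(toy.graph.nodes)):
    if node.op != 'call_module':
        continue
    obs_name = f'activation_post_process_{i}'
    setattr(toy, obs_name, ToyObserver())
    # save users before rerouting, because the users dict changes as we go
    orig_users = list(node.users.keys())
    with toy.graph.inserting_after(node):
        obs_node = toy.graph.call_module(obs_name, args=(node,))
    for user in orig_users:
        user.replace_input_with(node, obs_node)
toy.recompile()
toy(torch.randn(2, 4))  # running the model populates the observer statistics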
def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if ((node in node_to_instrument_inputs_to_ref_node_name)
                or (node in node_to_instrument_outputs_to_ref_node_name)):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[
                    node]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = node.args[node_arg_idx]
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node,
                            gm,
                            logger_cls,
                            '_ns_logger_',
                            node.name,
                            model_name,
                            ref_name,
                            ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0,
                            index_of_arg=node_arg_idx,
                            fqn=fqn)
                    elif type(node_arg) == \
                            torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node,
                                gm,
                                logger_cls,
                                '_ns_logger_',
                                node.name,
                                model_name,
                                ref_name,
                                ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx,
                                index_of_arg=node_arg_idx,
                                fqn=fqn)
                    else:
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[
                    node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name],
                    gm,
                    logger_cls,
                    '_ns_logger_',
                    node.name,
                    model_name,
                    ref_name,
                    ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0,
                    index_of_arg=0,
                    fqn=fqn)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
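
# add_loggers_to_model is parameterized over `logger_cls`. A minimal sketch of
# what such a logger module could look like; this is a hypothetical, stripped
# down stand-in, keeping only a few of the fields passed above.
import torch

class SimpleLogger(torch.nn.Module):
    """Records every tensor that passes through it and returns the input
    unchanged, so inserting it after a node does not change numerics."""
    def __init__(self, node_name: str, model_name: str, ref_name: str):
        super().__init__()
        self.node_name = node_name    # name of the node being instrumented
        self.model_name = model_name  # which model this logger belongs to
        self.ref_name = ref_name      # name used to pair results across models
        self.stats: list = []

    def forward(self, x: torch.Tensor):
        self.stats.append(x.detach())
        return x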
Exemple #19
0
def prepare(
        model: GraphModule,
        qconfig_dict: Any,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
        equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False) -> ObservedGraphModule:
    """ standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module
    Args:
        node_name_to_scope: mapping from node name to the scope of the module which contains the node.
        The scope is a tuple of fully qualified path of the module and the type of the module
    Returns:
        model(GraphModule): prepared standalone module
        attributes:
            _standalone_module_input_quantized_idxs(List[Int]): a list of
                indexes for the graph input that is expected to be quantized,
                same as input_quantized_idxs configuration provided
                for the standalone module
            _standalone_module_output_quantized_idxs(List[Int]): a list of
                indexes for the graph output that is quantized,
                same as the output_quantized_idxs configuration provided
                for the standalone module
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    if equalization_qconfig_dict is None:
        equalization_qconfig_dict = {}
    if backend_config_dict is None:
        backend_config_dict = get_fbgemm_backend_config_dict()

    validate_backend_config_dict(backend_config_dict)

    additional_quant_patterns = \
        prepare_custom_config_dict.get("additional_quant_pattern", {})
    # mapping from a tuple of nodes in reverse order to uninitialized
    #   QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.quantization.fx.quantize.Add'>),
    # }
    quant_patterns = backend_config_dict["quant_patterns"]
    patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
        quant_patterns, additional_quant_patterns)

    convert_dict_to_ordered_dict(qconfig_dict)
    convert_dict_to_ordered_dict(equalization_qconfig_dict)
    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)

    if model.training:
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        qat_swap_modules(model, additional_qat_module_mapping)
        qconfig_dict = update_qconfig_for_qat(qconfig_dict, additional_qat_module_mapping)

    qconfig_dict = update_qconfig_for_fusion(model, qconfig_dict)
    equalization_qconfig_dict = update_qconfig_for_fusion(model, equalization_qconfig_dict)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    modules = dict(model.named_modules())

    # fill qconfig_map, a map from node name to qconfig, used in find_matches
    equalization_qconfig_map = generate_qconfig_map(model, modules, model.graph, equalization_qconfig_dict, node_name_to_scope)
    qconfig_map = generate_qconfig_map(model, modules, model.graph, qconfig_dict, node_name_to_scope)

    # match the patterns that will get quantized
    standalone_module_name_configs = prepare_custom_config_dict.get(
        "standalone_module_name", [])
    standalone_module_class_configs = prepare_custom_config_dict.get(
        "standalone_module_class", [])

    standalone_module_names = [config[0] for config in standalone_module_name_configs]
    standalone_module_classes = [config[0] for config in standalone_module_class_configs]
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")
    matches = find_matches(
        model.graph, modules, patterns, qconfig_map, standalone_module_names,
        standalone_module_classes, custom_module_classes)

    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    run_prepare_fx_on_standalone_modules(
        model, modules, matches, prepare_custom_config_dict)

    result_node = insert_observers_for_model(
        model, modules, matches, qconfig_map,
        model.graph, prepare_custom_config_dict,
        equalization_qconfig_map,
        input_quantized_idxs, output_quantized_idxs)

    save_state(model, qconfig_map, node_name_to_scope, patterns,
               prepare_custom_config_dict, equalization_qconfig_map)
    preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", []))
    model = ObservedGraphModule(model, model.graph, preserved_attributes)
    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            "standalone module only supports returning simple value currently"\
            "(not tuple, dict etc.)"
        # these inputs are observed in parent
        # converting List[int] to Tensor since module attribute is
        # Union[Tensor, Module]
        model._standalone_module_input_quantized_idxs = \
            torch.tensor(input_quantized_idxs)
        model._standalone_module_output_quantized_idxs = torch.tensor(output_quantized_idxs)
    return model
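
# For reference, the usual entry point wrapping a prepare function like the one
# above is the FX graph mode quantization API. A minimal post-training flow,
# assuming the torch.quantization.quantize_fx API from the same era as this
# code (later releases also require example_inputs):
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}

# prepare: trace the model and insert observers
prepared = prepare_fx(float_model, qconfig_dict)

# calibrate: run representative data so the observers collect statistics
with torch.no_grad():
    for _ in range(8):
        prepared(torch.randn(4, 8))

# convert: replace observers with quantize/dequantize ops and swap modules
quantized = convert_fx(prepared)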
def _convert_do_not_use(
        model: GraphModule,
        is_reference: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """
    We convert an observed model (a module with observer calls) to a reference
    quantized model. The rules are simple:
    1. for each observer module call in the graph, we convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we convert them to a reference
       quantized module; this requires knowing whether the dtype configured for the
       weight is supported by the backend, which is determined in the prepare step
       and stored in observed_node_names, so we can decide whether the module
       needs to be swapped based on that set

    standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module; whether the input/output is
    quantized is specified by input_quantized_idxs and output_quantized_idxs
    in prepare_custom_config_dict, please see the docs for prepare_fx
    for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(
        model)
    qconfig_map: Dict[
        str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    assert is_reference, "_convert_do_not_use only supports reference option"

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(model.graph,
                           modules,
                           patterns,
                           qconfig_map,
                           custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def replace_observer_with_quantize_dequantize_node(
            graph: Graph, node: Node,
            modules: Dict[str, torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        elif observer_module.dtype in [
                torch.quint8, torch.qint8, torch.float16
        ]:
            node_type, quantize_op, qparams = get_quantize_node_info(
                observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(
                            root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op,
                                                   tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize",
                                                     args=(quantized_node, ))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # TODO: remove the quantize node for the placeholder
                raise Exception("input_quantized_idxs is not supported yet")
        elif node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Results are kept quantized if the user specified the
                # output_quantized_idxs override.
                # TODO: remove dequantize node if any
                raise Exception("output_quantized_idxs is not supported yet")
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                replace_observer_with_quantize_dequantize_node(
                    model.graph, node, modules)
            elif type(modules[node.target]) in set(
                    WEIGHTED_MODULE_CLASSES).union(QAT_MODULE_CLASSES).union(
                        FUSED_MODULE_CLASSES):
                # TODO: refactor this part to a function
                original_module = modules[node.target]
                qconfig = original_module.qconfig

                is_observed = node.name in observed_node_names
                is_weight_quantized = weight_is_statically_quantized(qconfig)
                # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
                if qconfig is None or not is_observed or not is_weight_quantized:
                    continue

                float_module = original_module
                fused_module = None
                if isinstance(original_module, QAT_MODULE_CLASSES):
                    # case 1. converting a qat module to
                    # a float module; we need to attach
                    # weight fake_quant to the module,
                    # weight fake_quant is assumed to have run during
                    # QAT so we don't need to run it again here
                    float_module = original_module.to_float(
                    )  # type: ignore[operator]
                    # change qat conv to conv
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, float_module)
                    if isinstance(float_module,
                                  torch.nn.intrinsic._FusedModule):
                        fused_module = float_module
                        float_module = fused_module[0]
                    weight_post_process = original_module.weight_fake_quant
                else:
                    # case 2. converting a float module / fused float module
                    # to a float module; we need to attach a
                    # weight observer to the conv module and run it
                    # with the conv weight
                    if isinstance(original_module,
                                  torch.nn.intrinsic._FusedModule):
                        fused_module = original_module
                        float_module = fused_module[0]  # type: ignore[index]
                    assert qconfig is not None
                    weight_post_process = qconfig.weight()
                    # run weight observer
                    weight_post_process(
                        float_module.weight)  # type: ignore[operator]
                weight_qparams = get_qparam_dict(weight_post_process)
                ref_qmodule_cls = get_static_quant_module_class(
                    type(float_module), is_reference=True)
                ref_qmodule = ref_qmodule_cls.from_float(
                    float_module, weight_qparams)
                if fused_module is not None:
                    fused_module[0] = ref_qmodule
                else:
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, ref_qmodule)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, model.graph, preserved_attributes)
    return model
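
# The observer replacement above turns each activation_post_process call into
# a quantize_per_tensor / dequantize pair. In eager terms, with assumed
# qparams (in the pass above they come from the observer and are registered
# as buffers on the root module), that pair looks like this:
import torch

x = torch.randn(4)
scale, zero_point = 0.1, 0
xq = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
x_ref = xq.dequantize()  # back to fp32, now carrying the quantization error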
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str,
                                            Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN
                and node_output_type_a != NodeInputOrOutputType.UNKNOWN
                and node_input_type_b != NodeInputOrOutputType.UNKNOWN
                and node_output_type_b != NodeInputOrOutputType.UNKNOWN)
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}'
                    +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}'
                    + ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # If we are shadowing from fp32 to int8, we need to insert
            # quantize_per_tensor call with qparams from the previous node.
            # Only do this if we are able to infer these qparams from the graph.
            if (node_input_type_a == NodeInputOrOutputType.INT8
                    and node_input_type_b == NodeInputOrOutputType.FP32):
                node_a_input_qparams = get_node_input_qparams(
                    subgraph_a.start_node, gm_a, node_type_to_io_type_map)
                if not node_a_input_qparams:
                    print(
                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}'
                        +
                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}'
                        + ', unknown input qparams')
                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                    continue

            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    if isinstance(node_b.args[0], Node):
                        prev_node_c = env_c[node_b.args[0].name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c,
                            gm_b,
                            logger_cls,
                            '_ns_logger_b_inp_',
                            node_b.name,
                            name_b,
                            ref_name,
                            ref_node_type_b,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0,
                            index_of_arg=0,
                            fqn=fqn_base_b)
                    elif isinstance(node_b.args[0], list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [
                            env_c[arg.name] for arg in node_b.args[0]
                        ]

                        for arg_idx, arg in enumerate(node_b.args[0]):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[
                                prev_node_c.name] = _insert_logger_after_node(
                                    prev_node_c,
                                    gm_b,
                                    logger_cls,
                                    '_ns_logger_b_inp_',
                                    node_b.name,
                                    name_b,
                                    ref_name,
                                    ref_node_type_b,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=arg_idx,
                                    index_of_arg=0,
                                    fqn=fqn_base_b)
                    else:
                        # logging of inputs which are not lists is not supported yet
                        raise AssertionError(
                            f"type {type(node_b.args[0])} is not handled yet")
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_start_node:

                # cast dtype from the dtype of node_c's input to the dtype of
                # node_a's input (dequant, etc)
                prev_node_c = node_c.args[0]
                if should_log_inputs:
                    # skip the input logger when inserting a dtype cast
                    if isinstance(prev_node_c, Node):
                        prev_node_c = prev_node_c.args[0]
                    elif isinstance(prev_node_c, list):
                        prev_node_c = [arg.args[0] for arg in prev_node_c]
                dtype_cast_node = _insert_dtype_cast_after_node(
                    subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b,
                    graph_c, node_b.name + '_dtype_cast_', logger_cls,
                    node_type_to_io_type_map)
                # note: not inserting to env_c because all nodes which use the dtype
                #   casts are copied from graph_a
                #
                # subgraph so far:
                #
                #           (dtype_cast_node)+
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                # if input logging is enabled, log the input to the subgraph
                if should_log_inputs:
                    # TODO: explain this
                    ref_node_name = ''
                    if isinstance(dtype_cast_node, Node):
                        dtype_cast_node = _insert_logger_after_node(
                            dtype_cast_node,
                            gm_b,
                            logger_cls,
                            '_ns_logger_a_inp_',
                            ref_node_name,
                            name_a,
                            ref_name,
                            ref_node_type_a,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0,
                            index_of_arg=0,
                            fqn=fqn_base_a)
                        input_logger: Union[Node, List[Node]] = dtype_cast_node
                    else:
                        assert isinstance(dtype_cast_node, list)
                        new_loggers = []
                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(
                                dtype_cast_node):
                            dtype_cast_logger = _insert_logger_after_node(
                                dtype_cast_node_inner,
                                gm_b,
                                logger_cls,
                                '_ns_logger_a_inp_',
                                ref_node_name,
                                name_a,
                                ref_name,
                                ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=dtype_cast_idx,
                                index_of_arg=0,
                                fqn=fqn_base_a)
                            new_loggers.append(dtype_cast_logger)
                        dtype_cast_node = new_loggers
                        input_logger = dtype_cast_node
                    # subgraph so far:
                    #
                    #       (dtype_cast_node)+ -> (logger_a_input)?
                    #                  /
                    # prev_node_c -> (logger_c_input)? -> node_start_c

                # hook up the new mod_a copy to be in the graph, receiving the
                # same inputs as mod_b does, with dtype cast to match a
                # Some ops, such as LSTMs, have two non-param inputs. If we have
                # such an op, pass the second param as well. Note: dtype casting
                # for the second param is not implemented yet, it can be added
                # later if there is a use case.
                node_c_second_non_param_arg = None
                num_non_param_args_node_a = get_number_of_non_param_args(
                    subgraph_a.start_node, gm_a)
                if num_non_param_args_node_a == 2:
                    node_c_second_non_param_arg = node_c.args[1]
                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                    dtype_cast_node, node_c_second_non_param_arg, subgraph_a,
                    gm_a, gm_b, node_c.name + '_shadow_copy_')
                env_c[node_a_shadows_c.name] = node_a_shadows_c
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if should_log_inputs:
                    # When we created the input logger, we left the ref_node_name
                    # as an empty string, because the subgraph copy did not exist
                    # yet. Now that the subgraph copy exists, we modify this name
                    # to its true value.
                    # Note: the alternative to this is to create the input logger
                    # after creating the subgraph, which is slightly more
                    # complicated. This is the lesser of two evils.
                    # input_logger = env_c[dtype_cast_node.name]
                    # Find the first node in the subgraph
                    cur_node = node_a_shadows_c
                    while cur_node.args[0] != input_logger:
                        cur_node = cur_node.args[0]  # type: ignore[assignment]
                    if isinstance(input_logger, Node):
                        input_logger_mod = getattr(gm_b, input_logger.name)
                        input_logger_mod.ref_node_name = cur_node.name
                    else:
                        assert isinstance(input_logger, list)
                        for input_logger_inner in input_logger:
                            input_logger_mod = getattr(gm_b,
                                                       input_logger_inner.name)
                            input_logger_mod.ref_node_name = cur_node.name

                # hook up a logger to the mod_a copy
                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                    env_c[node_a_shadows_c.name],
                    gm_b,
                    logger_cls,
                    '_ns_logger_a_',
                    node_a_shadows_c.name,
                    name_a,
                    ref_name,
                    ref_node_type_a,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0,
                    index_of_arg=0,
                    fqn=fqn_base_a)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_end_node:

                # hook up a logger to the mod_b copy
                env_c[node_b.name] = _insert_logger_after_node(
                    env_c[node_b.name],
                    gm_b,
                    logger_cls,
                    '_ns_logger_b_',
                    node_b.name,
                    name_b,
                    ref_name,
                    ref_node_type_b,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0,
                    index_of_arg=0,
                    fqn=fqn_base_b)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                #
                # Note: node_start_c may be the same node as node_end_c, or they
                # may have nodes in between.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(
                modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif ((node in node_to_instrument_inputs_to_ref_node_name)
              or (node in node_to_instrument_outputs_to_ref_node_name)):

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name = node_to_instrument_inputs_to_ref_node_name[node]
                if type(node.args[0]) == Node:
                    # create a single input logger
                    prev_node = env[node.args[0].name]
                    env[node.args[0].name] = _insert_logger_after_node(
                        prev_node,
                        gm,
                        logger_cls,
                        '_ns_logger_',
                        node.name,
                        model_name,
                        ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                elif type(node.args[0]
                          ) == torch.fx.immutable_collections.immutable_list:
                    # create N input loggers, one for each node
                    for arg_idx, arg in enumerate(node.args[0]):
                        prev_node = env[arg.name]
                        env[prev_node.name] = _insert_logger_after_node(
                            prev_node,
                            gm,
                            logger_cls,
                            '_ns_logger_',
                            node.name,
                            model_name,
                            ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx)
                else:
                    raise AssertionError(
                        f"type {type(node.args[0])} is not handled yet")

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name],
                    gm,
                    logger_cls,
                    '_ns_logger_',
                    node.name,
                    model_name,
                    ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
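A side note for readers: the node-by-node rewrite above follows a common torch.fx idiom, copying each node of the source graph into a fresh Graph while remapping its arguments through an env dictionary. A minimal, self-contained sketch of that idiom (an illustration, not part of the original source):

import torch
from torch.fx import Graph, GraphModule, symbolic_trace
from torch.fx.node import map_arg

def copy_graph(gm: GraphModule) -> GraphModule:
    # env maps original node names to the corresponding nodes in the new graph
    new_graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda n: env[n.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue
        # a pass like remove_observers_add_loggers decides here whether to skip
        # the node, copy it verbatim, or copy it and append a logger module
        env[node.name] = new_graph.node_copy(node, load_arg)
    return GraphModule(gm, new_graph)

m = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
print(copy_graph(m)(torch.randn(2, 4)).shape)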
Exemple #23
0
def convert(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        convert_qconfig_dict: Optional[Dict[str, Any]] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """
    We convert an observed model (a module with observer calls) to a reference
    quantized model. The rules are simple:
    1. for each observer module call in the graph, we convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we convert them to reference
       quantized modules; this requires knowing whether the dtype configured for the
       weight is supported in the backend, which is determined in the prepare step and
       stored in observed_node_names, so we can decide whether to swap the
       module based on that set

    standalone_module means the module is a submodule that is not inlined in the
    parent module and will be quantized separately as one unit.

    For standalone modules, returns a quantized standalone module; whether its
    input/output is quantized is specified by prepare_custom_config_dict via
    input_quantized_idxs and output_quantized_idxs, please
    see the docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # TODO: this should be removed now that GPU support for quantization is being added.
    # however in practice, as of 7/22/2021, certain functions that get called by convert expect
    # only cpu arguments.
    # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
    # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
    if not is_reference:
        model.cpu()

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if convert_qconfig_dict:
        prepare_qconfig_dict: Dict[str, Dict[Any, Any]] = model._qconfig_dict  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)
        convert_dict_to_ordered_dict(convert_qconfig_dict)
        if model._is_qat:
            convert_qconfig_dict = update_qconfig_for_qat(convert_qconfig_dict, {})
        convert_qconfig_dict = update_qconfig_for_fusion(model, convert_qconfig_dict)

        compare_prepare_convert_qconfig_dict(prepare_qconfig_dict, convert_qconfig_dict)  # type: ignore[arg-type]
        convert_qconfig_map = generate_qconfig_map(model, modules_copy, model.graph, convert_qconfig_dict, node_name_to_scope)
        # check the convert_qconfig_map generated and ensure that all the values either match what was set in prepare qconfig_map
        # or are set to None in the convert_qconfig_map.
        for k, v in qconfig_map.items():
            assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format(k)
            if convert_qconfig_map[k] is not None:
                assert qconfig_equals(v, convert_qconfig_map[k]), \
                    'Expected k {} to have the same value in prepare and convert qconfig_dicts, ' \
                    'found {} updated to {}.'.format(k, v, convert_qconfig_map[k])
        qconfig_map = convert_qconfig_map

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    # TODO: move this outside of this function
    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module,
            graph: Graph,
            node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
        observer_module = modules[node.target]
        maybe_quantize_node_info = get_quantize_node_info(observer_module)
        # Skip replacing observers with quant/dequant nodes if the qconfigs of all
        # consumers and producers of this observer are None
        skip_replacement = all([
            has_none_qconfig(n, qconfig_map) for n in
            list(node.args) + list(node.users.keys())])
        if skip_replacement or maybe_quantize_node_info is None:
            # didn't find a corresponding quantize op and info for the observer_module
            # so we just remove the observer
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        else:
            # otherwise, we can convert the observer module call to a quantize/dequantize node
            node_type, quantize_op, qparams = maybe_quantize_node_info
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    # TODO: we can add the information of whether a value needs to
                    # be registered as an attribute in qparams dict itself
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    # this is a temporary hack for custom module, we may want to implement
    # this properly after the custom module class design is finalized
    def replace_observer_with_dequantize_node(node: Node, graph: Graph):
        call_custom_module_node = node.args[0]
        assert isinstance(call_custom_module_node, Node), \
            f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
        node.replace_all_uses_with(call_custom_module_node)
        graph.erase_node(node)
        insert_dequantize_node(call_custom_module_node, graph)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    if backend_config_dict is None:
        backend_config_dict = get_native_backend_config_dict()
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config_dict)
    # convert to a tuple so that it can work with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config_dict)
    fused_module_classes = get_fused_module_classes(backend_config_dict)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators take
                # floating point inputs in reference quantized models
                insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if len(output_quantized_idxs) == 0:
                continue
            # Results are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the dequantize operator for the node in the end if any
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    replace_observer_with_dequantize_node(node, model.graph)
                else:
                    replace_observer_with_quantize_dequantize_node(
                        model, model.graph, node, modules, node_name_to_scope,
                        qconfig_map)
            elif is_observed_standalone_module(modules[node.target]):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config_dict)
            elif type(modules[node.target]) in set(
                    root_module_classes).union(qat_module_classes).union(fused_module_classes):
                # extra check for fused module classes: only convert when the first
                # module in the fused module is one of the target root module classes
                if type(modules[node.target]) in fused_module_classes and \
                   type(modules[node.target][0]) not in root_module_classes:
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, qconfig_map, backend_config_dict)
            elif type(modules[node.target]) in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, copy.deepcopy(model.graph), preserved_attributes)

    # remove dead code after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model.recompile()

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        model = duplicate_dequantize_node(model)
        model = duplicate_quantize_dynamic_node(model)
        model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)
        model = remove_quant_dequant_pairs(model)
        model = remove_extra_dequantize(model)
    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    return model
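For context, a hedged end-to-end sketch of how this internal convert step is typically reached through the public FX quantization entry points of this era; MyModel is a placeholder, and depending on the exact PyTorch version prepare_fx may additionally require example_inputs:

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = MyModel().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(model, qconfig_dict)   # inserts observers
prepared(torch.randn(2, 4))                  # calibration pass populates the observers
quantized = convert_fx(prepared)             # dispatches to a convert() like the one above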
Exemple #24
0
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node, Node]]],
    logger_cls: Callable,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    node_b_to_matched_subgraph_a = {}
    for match_name, match in matched_subgraph_pairs.items():
        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a[node_end_b] = (node_start_a, node_end_a)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a:
            node_start_a, node_end_a = node_b_to_matched_subgraph_a[node_b]
            if False:  # debugging only; flip to True to print the matched nodes
                print('b')
                print_node(node_b)
                print('a')
                print_node(node_start_a)
                print_node(node_end_a)

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # prev_node_c -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            env_c[dtype_cast_node.name] = dtype_cast_node
            # subgraph so far:
            #
            #       dtype_cast_node
            #      /
            # prev_node_c -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                env_c[dtype_cast_node.name], node_start_a, node_end_a, gm_a,
                gm_b, node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy(args/kwargs not shown)
            #      /
            # prev_node_c -> node_c

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_', name_b)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy
            #      /
            # prev_node_c -> node_c --> logger_c

            # hook up a logger to the mod_a copy
            # Note: we pass node_b.name to this logger, for easy matching later
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name], gm_b, logger_cls,
                '_ns_logger_a_', name_a, node_b.name)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy --> logger_a
            #      /
            # prev_node_c -> node_c --> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
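Once a shadow model built this way has run on calibration data, each logger pair holds matching activations from A and B for the same logical layer. A small, hedged sketch of one way such logged pairs might be compared; compute_sqnr and the assumed results layout are illustrations, not part of the original source:

import torch

def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio (in dB) between a reference tensor x
    # and its lower-precision counterpart y
    signal = torch.mean(x.pow(2))
    noise = torch.mean((x - y).pow(2))
    return 10 * torch.log10(signal / noise)

def summarize(results):
    # assumed layout: {layer_name: {model_name: [logged activation tensors]}},
    # with exactly two model names per layer (e.g. 'fp32' and 'int8')
    for layer_name, model_to_stats in results.items():
        (_, stats_a), (_, stats_b) = sorted(model_to_stats.items())
        for a, b in zip(stats_a, stats_b):
            print(layer_name, compute_sqnr(a, b).item())

x = torch.randn(8)
summarize({'linear_0': {'fp32': [x], 'int8': [x + 0.01 * torch.randn(8)]}})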
Exemple #25
0
def convert(model: GraphModule,
            is_reference: bool = False,
            convert_custom_config_dict: Optional[Dict[str, Any]] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """ standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # TODO: this should be removed now that GPU support for quantization is being added.
    # however in practice, as of 7/22/2021, certain functions that get called by convert expect
    # only cpu arguments.
    # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
    # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
    if not is_reference:
        model.cpu()

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(model.graph,
                           modules,
                           patterns,
                           qconfig_map,
                           custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    quantized_graph = Graph()
    env: Dict[str, Dict[Optional[torch.dtype], Node]] = defaultdict(
        lambda: defaultdict(Node))  # type: ignore[arg-type]

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def load_non_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        dtype_to_node = env[n.name]
        if torch.float in dtype_to_node:
            return dtype_to_node[torch.float]
        elif None in dtype_to_node:
            return dtype_to_node[None]
        else:
            quantized_node = None
            for dtype in [torch.quint8, torch.qint8, torch.float16]:
                if dtype in dtype_to_node:
                    quantized_node = dtype_to_node[dtype]
                    break
            assert quantized_node is not None, \
                "Did not find a supported quantized dtype: {}".format(dtype_to_node)
            env[n.name][torch.float] = Proxy(quantized_node).dequantize().node
            return env[n.name][torch.float]

    def load_quantized(dtype: torch.dtype):
        def load_quantized_impl(n: Node):
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + \
                n.name + ' in environment:' + str(env)
            dtype_to_node = env[n.name]
            local_dtype: Optional[torch.dtype] = dtype
            if local_dtype == torch.float and local_dtype not in dtype_to_node:
                local_dtype = None
            if local_dtype in [torch.float, None]:
                return load_non_quantized(n)
            assert local_dtype in dtype_to_node, f'Expecting {dtype} in {dtype_to_node}'
            return dtype_to_node[local_dtype]

        return load_quantized_impl

    def load_x(n: Node) -> Node:
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        dtype_to_node = env[n.name]
        dtypes = [
            torch.quint8, torch.qint8, torch.float16, torch.float32, None
        ]
        for dtype in dtypes:
            if dtype in dtype_to_node:
                return dtype_to_node[dtype]
        raise Exception(
            f'none of the dtypes {dtypes} were found in environment {dtype_to_node} for node {n.name}'
        )

    def load_arg(
        quantized: Optional[Union[List[int], Dict[int, torch.dtype],
                                  torch.dtype, Tuple[int, ...]]]
    ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, a torch.dtype, a list, a tuple or a dict
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a dtype, then all args will be
            quantized to the specific dtype
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=torch.float)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized to torch.quint8
          - if quantized is a dict, then the args at the listed indexes will be
            quantized to the corresponding dtypes

        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, dict, torch.dtype)), type(quantized)
        if isinstance(quantized, (tuple, list, dict)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = torch.float

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], torch.dtype,
                                              Dict[int, torch.dtype],
                                              Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
               len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                if 0 in quantized:
                    updated_quantized = torch.quint8

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, torch.dtype):
                return map_arg(arg_or_args, load_quantized(updated_quantized))
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        # Currently it's hardcoded to torch.quint8, we can extend this
                        # in the future to support all quantized
                        # dtypes
                        loaded_args.append(
                            map_arg(a, load_quantized(torch.quint8)))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
            elif isinstance(updated_quantized, dict):
                loaded_args = []
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(
                            map_arg(a, load_quantized(updated_quantized[i])))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)

        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                dtype_to_node = env[node_arg.name]
                return any([
                    x in dtype_to_node
                    for x in [torch.quint8, torch.qint8, torch.float16]
                ])
            else:
                return False
        elif isinstance(node_arg, list):
            quantized = map(node_arg_is_quantized, node_arg)
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler,
                            qconfig: QConfigAny,
                            modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # for some ops the output is quantized only when `is_reference` is True
        # and when `is_reference` is False, it has limited qconfig
        # support, for example `add`
        # ideally this check should not happen here, it should happen either in
        # prepare or during lowering, we don't need this check
        # after the default path is changed to produce reference patterns
        quantized = obj.is_output_quantized(qconfig)

        # Need to get correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(
                qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
           not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node,
                             modules: Dict[str, torch.nn.Module]) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name][torch.float] = quantized_graph.node_copy(
                node, load_non_quantized)
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            prev_dtype_to_node: Dict[Optional[torch.dtype], Node] = env[prev_node.name]
            current_dtype: Optional[torch.dtype] = observer_module.dtype  # type: ignore[assignment]
            if current_dtype in prev_dtype_to_node:
                env[node.name][current_dtype] = prev_dtype_to_node[current_dtype]
            else:
                root_module = modules[""]
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name][observer_dtype] = \
                    quantize_node(
                        load_non_quantized(prev_node),
                        observer_module, node, modules, quantized_graph,
                        node_name_to_scope, is_input=True)
        else:
            # replace activation post process with quantization ops
            root_module = modules[""]
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name][dtype] = \
                quantize_node(
                    load_non_quantized(node.args[0]),
                    observer_module, node, modules,
                    quantized_graph,
                    node_name_to_scope, is_input=True)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Results are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module'
                and is_observed_standalone_module(modules[node.target]))
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist()  # noqa: B950
                    assert len(out_quant_idxs) <= 1, \
                        "Currently standalone modules only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                # Note: load_arg can be overwritten in the convert method when used to
                # create Node in graph
                result = obj.convert(
                    node,
                    qconfig,
                    modules,
                    quantized_graph,
                    node_name_to_scope,
                    load_arg,
                    is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig,
                                                    modules)

            if quantized:
                env[node.name][activation_dtype(qconfig)] = result
            else:
                env[node.name][torch.float] = result
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(node, load_non_quantized)
                env[node.name][torch.float] = result
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                env[node.name][torch.quint8] = quantized_graph.node_copy(
                    node, load_non_quantized)
            else:
                env[node.name][torch.float] = \
                    quantized_graph.node_copy(node, load_non_quantized)
        else:
            # copy quantized or non-quantized node
            # get_tensor_info_node like shape works for both
            # quantized and non-quantized input and output a non-Tensor
            # (we use None for dtype currently for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name][None] = \
                    quantized_graph.node_copy(node, load_x)
            else:
                env[node.name][torch.float] = \
                    quantized_graph.node_copy(node, load_non_quantized)

    # remove activation post process
    act_post_process_removed_graph = Graph()
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_remove))
            continue
        if node.op == 'call_module' and \
           is_activation_post_process(modules[node.target]):
            # remove activation post process node
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, act_post_process_removed_graph,
                                 preserved_attributes)
    if not is_reference:
        model = fold_weight(model, node_name_to_scope)
        model = lower_to_fbgemm(model)
    return model
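The dtype-keyed environment used above maps each original value name to its copies at different dtypes, and a float copy is materialized lazily by dequantizing an existing quantized copy. A simplified, self-contained sketch of that bookkeeping on plain tensors, an assumption for illustration rather than the original helper:

from collections import defaultdict
from typing import Dict, Optional
import torch

env: Dict[str, Dict[Optional[torch.dtype], torch.Tensor]] = defaultdict(dict)

# pretend a quantized copy of the value named 'x' already exists in the environment
x = torch.randn(4)
env["x"][torch.quint8] = torch.quantize_per_tensor(x, 0.1, 0, torch.quint8)

def load_non_quantized(name: str) -> torch.Tensor:
    dtype_to_val = env[name]
    if torch.float in dtype_to_val:
        return dtype_to_val[torch.float]
    # fall back: dequantize an existing quantized copy and cache the result,
    # mirroring how load_non_quantized above inserts a dequantize node
    for dtype in (torch.quint8, torch.qint8):
        if dtype in dtype_to_val:
            dtype_to_val[torch.float] = dtype_to_val[dtype].dequantize()
            return dtype_to_val[torch.float]
    raise KeyError(f"no usable copy of {name} found in env")

print(load_non_quantized("x"))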
Exemple #26
0
def _convert_do_not_use(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """
    We convert an observed model (a module with observer calls) to a reference
    quantized model. The rules are simple:
    1. for each observer module call in the graph, we convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we convert them to reference
       quantized modules; this requires knowing whether the dtype configured for the
       weight is supported in the backend, which is determined in the prepare step and
       stored in observed_node_names, so we can decide whether to swap the
       module based on that set

    standalone_module means the module is a submodule that is not inlined in the
    parent module and will be quantized separately as one unit.

    For standalone modules, returns a quantized standalone module; whether its
    input/output is quantized is specified by prepare_custom_config_dict via
    input_quantized_idxs and output_quantized_idxs, please
    see the docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    assert is_reference, "_convert_do_not_use only supports reference option"

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def replace_observer_with_quantize_dequantize_node(graph: Graph, node: Node, modules: Dict[str, torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        elif observer_module.dtype in [torch.quint8, torch.qint8, torch.float16]:
            node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)


    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    if backend_config_dict is None:
        backend_config_dict = {}
    quantized_reference_module_mapping = get_quantized_reference_module_mapping(backend_config_dict)
    # convert to a tuple so that it can work with isinstance(module, tuple_of_classes)
    weighted_module_classes = tuple(quantized_reference_module_mapping.keys())

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators take
                # floating point inputs in reference quantized models
                insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Results are kept quantized if the user specified the
                # output_quantized_idxs override.
                # Remove the dequantize operator in the end
                maybe_dequantize_node = node.args[0]
                if isinstance(maybe_dequantize_node, Node) and \
                   maybe_dequantize_node.op == "call_method" and \
                   maybe_dequantize_node.target == "dequantize":
                    quantize_node = maybe_dequantize_node.args[0]
                    maybe_dequantize_node.replace_all_uses_with(quantize_node)
                    model.graph.erase_node(maybe_dequantize_node)
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                replace_observer_with_quantize_dequantize_node(model.graph, node, modules)
            elif is_observed_standalone_module(modules[node.target]):
                # TODO: move this to a separate function
                convert_standalone_module(node, modules, model, is_reference, backend_config_dict)

            elif type(modules[node.target]) in set(
                    weighted_module_classes).union(QAT_MODULE_CLASSES).union(FUSED_MODULE_CLASSES):
                convert_weighted_module(node, modules, observed_node_names, quantized_reference_module_mapping)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, model.graph, preserved_attributes)
    return model
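The observer-to-quantize/dequantize rewrite performed by replace_observer_with_quantize_dequantize_node can be reproduced on a toy graph. A minimal, hedged sketch; nn.Identity stands in for the observer module and the fixed qparams (scale=1.0, zero_point=0) are assumptions for illustration, whereas the real pass reads them from the observer instance:

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Identity stands in for an activation_post_process (observer) module
        self.obs = torch.nn.Identity()

    def forward(self, x):
        return self.obs(x) + 1

gm = fx.symbolic_trace(M())
modules = dict(gm.named_modules())
for node in list(gm.graph.nodes):
    if node.op == 'call_module' and isinstance(modules[node.target], torch.nn.Identity):
        with gm.graph.inserting_before(node):
            q = gm.graph.call_function(
                torch.quantize_per_tensor, (node.args[0], 1.0, 0, torch.quint8))
            dq = gm.graph.call_method("dequantize", (q,))
            node.replace_all_uses_with(dq)
            gm.graph.erase_node(node)
gm.recompile()
print(gm.code)                  # shows quantize_per_tensor -> dequantize in place of obs
print(gm(torch.randn(2, 3)))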