Example #1
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[Tuple[Node, Node], Tuple[Node, Node]]],
    logger_cls: Callable,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of graph C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    node_b_to_matched_subgraph_a = {}
    for match_name, match in matched_subgraph_pairs.items():
        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
        assert node_start_b is node_end_b, \
            "Shadowing subgraphs of B with multiple nodes is not yet handled."
        node_b_to_matched_subgraph_a[node_end_b] = (node_start_a, node_end_a)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        if node_b.op == 'call_module' and is_activation_post_process(
                modules[node_b.target]):
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]  # type: ignore

        elif node_b in node_b_to_matched_subgraph_a:
            node_start_a, node_end_a = node_b_to_matched_subgraph_a[node_b]

            # ensure env_c is populated with base node
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
            node_c = env_c[node_b.name]

            # after this point,
            #
            # node_a is the original node from graph_a, with parent module gm_a
            # node_b is the original node from graph_b, with parent module gm_b
            # node_c is the copy of node_b in graph_c
            #
            # subgraph so far:
            #
            # prev_node_c -> node_c

            # cast dtype from the dtype of node_c's input to the dtype of
            # node_a's input (dequant, etc)
            dtype_cast_node = _insert_dtype_cast_after_node(
                node_start_a, node_c, node_c.args[0], gm_a, gm_b, graph_c,
                node_b.name + '_dtype_cast_')
            env_c[dtype_cast_node.name] = dtype_cast_node
            # subgraph so far:
            #
            #       dtype_cast_node
            #      /
            # prev_node_c -> node_c

            # hook up the new mod_a copy to be in the graph, receiving the
            # same inputs as mod_b does, with dtype cast to match a
            node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                env_c[dtype_cast_node.name], node_start_a, node_end_a, gm_a,
                gm_b, node_c.name + '_shadow_copy_')
            env_c[node_a_shadows_c.name] = node_a_shadows_c
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy(args/kwargs not shown)
            #      /
            # prev_node_c -> node_c

            # hook up a logger to the mod_b copy
            env_c[node_b.name] = _insert_logger_after_node(
                env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_', name_b)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy
            #      /
            # prev_node_c -> node_c --> logger_c

            # hook up a logger to the mod_a copy
            # Note: we pass node_b.name to this logger, for easy matching later
            env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                env_c[node_a_shadows_c.name], gm_b, logger_cls,
                '_ns_logger_a_', name_a, node_b.name)
            # subgraph so far:
            #
            #       dtype_cast_node --> subgraph_a_copy --> logger_a
            #      /
            # prev_node_c -> node_c --> logger_c

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
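# A minimal usage sketch (not part of the source above). It assumes `gm_fp32`
# and `gm_int8` are two already-prepared GraphModules of the same model, that
# `matched_pairs` follows the {name: ((start_a, end_a), (start_b, end_b))}
# format documented in the docstring, and that `logger_cls` is a logger class
# compatible with _insert_logger_after_node.
def example_create_shadow_model(gm_fp32: GraphModule, gm_int8: GraphModule,
                                matched_pairs, logger_cls) -> GraphModule:
    # Graph B is copied; each matched subgraph from A is attached to the same
    # inputs through a dtype cast, and loggers are added to both outputs.
    return create_a_shadows_b(
        name_a='fp32_model', gm_a=gm_fp32,
        name_b='int8_model', gm_b=gm_int8,
        matched_subgraph_pairs=matched_pairs,
        logger_cls=logger_cls,
    )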
Example #2
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, and adds loggers to the
    inputs and/or outputs of the nodes specified in the instrumentation maps.
    Returns a GraphModule with the new graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif (
            (node in node_to_instrument_inputs_to_ref_node_name) or
            (node in node_to_instrument_outputs_to_ref_node_name)
        ):

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = node.args[node_arg_idx]
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=node_arg_idx)
                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
                                model_name, ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=node_arg_idx)
                    else:
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name, NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
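# A self-contained sketch (not part of the source above) of the env/load_arg
# copy idiom that remove_observers_add_loggers and create_a_shadows_b rely on:
# every node of the traced graph is copied into a new Graph, with Node
# arguments remapped through an `env` dict keyed by the old node names.
import torch
import torch.nn as nn
from torch.fx import Graph, GraphModule, symbolic_trace
from torch.fx.node import map_arg


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


def copy_graph_with_env(gm: GraphModule) -> GraphModule:
    new_graph = Graph()
    env = {}  # old node name -> corresponding node in the new graph
    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], lambda n: env[n.name]))
            continue
        # node_copy remaps every Node argument through env, exactly like the
        # load_arg(...) closures in the snippets above
        env[node.name] = new_graph.node_copy(node, lambda n: env[n.name])
    return GraphModule(gm, new_graph)


traced = symbolic_trace(TinyModel())
rebuilt = copy_graph_with_env(traced)
x = torch.randn(2, 4)
assert torch.equal(traced(x), rebuilt(x))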
Example #3
def _prepare(model: GraphModule, qconfig_dict: Any,
             node_name_to_scope: Dict[str, Tuple[str, type]],
             prepare_custom_config_dict: Optional[Dict[str, Any]],
             is_standalone_module: bool) -> ObservedGraphModule:
    """ standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module.
    Args:
        node_name_to_scope: mapping from node name to the scope of the module which contains the node.
        The scope is a tuple of the fully qualified path of the module and the type of the module.
    Returns:
        model(GraphModule): prepared standalone module
        attributes:
            _standalone_module_input_quantized_idxs(List[Int]): a list of
                indexes for the graph inputs that are expected to be quantized,
                same as the input_quantized_idxs configuration provided
                for the standalone module
            _standalone_module_output_quantized_idxs(List[Int]): a list of
                indexes for the graph outputs that are quantized,
                same as the output_quantized_idxs configuration provided
                for the standalone module
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    additional_quant_patterns = \
        prepare_custom_config_dict.get("additional_quant_pattern", {})
    # mapping from a tuple of nodes in reverse order to uninitialized
    #   QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.quantization.fx.quantize.Add'>),
    # }
    patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
        get_default_quant_patterns(), additional_quant_patterns)

    convert_dict_to_ordered_dict(qconfig_dict)
    flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict)
    if model.training:
        additional_qat_module_mapping = prepare_custom_config_dict.get(
            "additional_qat_module_mapping", {})
        qat_swap_modules(model, additional_qat_module_mapping)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    modules = dict(model.named_modules())

    # fill qconfig_map, a map from node name to qconfig, used in find_matches
    qconfig_map = generate_qconfig_map(model, modules, model.graph,
                                       qconfig_dict, node_name_to_scope)

    # match the patterns that will get quantized
    standalone_module_name_configs = prepare_custom_config_dict.get(
        "standalone_module_name", [])
    standalone_module_class_configs = prepare_custom_config_dict.get(
        "standalone_module_class", [])

    standalone_module_names = [
        config[0] for config in standalone_module_name_configs
    ]
    standalone_module_classes = [
        config[0] for config in standalone_module_class_configs
    ]
    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config_dict, "float_to_observed_custom_module_class")
    matches = find_matches(model.graph, modules, patterns, qconfig_map,
                           standalone_module_names, standalone_module_classes,
                           custom_module_classes)

    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    run_prepare_fx_on_standalone_modules(model, modules, matches,
                                         prepare_custom_config_dict)

    result_node = insert_observers_for_model(model, modules, matches,
                                             qconfig_map, model.graph,
                                             prepare_custom_config_dict,
                                             input_quantized_idxs,
                                             output_quantized_idxs)

    save_state(model, qconfig_map, node_name_to_scope, patterns,
               prepare_custom_config_dict)
    preserved_attributes = set(
        prepare_custom_config_dict.get("preserved_attributes", []))
    model = ObservedGraphModule(model, model.graph, preserved_attributes)
    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            "standalone module only supports returning simple value currently"\
            "(not tuple, dict etc.)"
        # these inputs are observed in parent
        # converting List[int] to Tensor since module attribute is
        # Union[Tensor, Module]
        model._standalone_module_input_quantized_idxs = \
            torch.tensor(input_quantized_idxs)
        model._standalone_module_output_quantized_idxs = torch.tensor(
            output_quantized_idxs)
    return model
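# Illustrative sketch (not part of the source above) of the
# prepare_custom_config_dict shape that _prepare reads. The keys are the ones
# accessed in the function; the values shown here are placeholders.
example_prepare_custom_config_dict = {
    # extra quantization patterns merged into the defaults
    "additional_quant_pattern": {},
    # extra float module -> QAT module swaps, applied when model.training is True
    "additional_qat_module_mapping": {},
    # submodules to quantize separately as standalone units; each config is a
    # list/tuple whose first element is the module name or class
    "standalone_module_name": [],
    "standalone_module_class": [],
    # float custom module class -> observed custom module class mapping
    "float_to_observed_custom_module_class": {},
    # graph inputs/outputs that the caller guarantees are already quantized
    "input_quantized_idxs": [0],
    "output_quantized_idxs": [0],
    # attribute names copied onto the resulting ObservedGraphModule
    "preserved_attributes": [],
}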
Example #4
    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant,
                 is_standalone_module):
        """ standalone_module means it a submodule that is not inlined in parent module,
        and will be quantized separately as one unit.

        When we are preparing a standalone module:
        the input of the module is observed in the parent module, and the output
        of the module is observed in the standalone module.
        Returns:
            model(GraphModule): prepared standalone module with the following attributes:
                _standalone_module_observed_input_idxs(List[Int]): a list of indexes for the graph inputs that
                                         need to be observed in the parent module
                _output_is_observed(Bool): a boolean variable indicating whether the output of the
                                   custom module is observed or not
        """
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            self._qat_swap_modules(model)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = qconfig_dict.get('standalone_module_name',
                                                   None)
        matches = self._find_matches(model.graph, self.modules, self.patterns,
                                     standalone_module_names)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that need to be observed
        standalone_module_observed_input_idxs = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')

        for node in model.graph.nodes:
            if node.name in observed_node_names_set:
                continue

            prefix = node.name + '_activation_post_process_'
            root_node, _, obj, qconfig = matches.get(node.name,
                                                     (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                if qconfig is None:
                    continue

                def insert_observer(node, observer, device):
                    get_new_observer_name = get_new_attr_name_with_prefix(
                        prefix)
                    observer_name = get_new_observer_name(model)
                    setattr(model, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)
                    if device:
                        getattr(model, observer_name).to(device)

                if isinstance(obj, CustomModuleQuantizeHandler):
                    custom_module = self.modules[node.target]
                    observed_custom_module_class = \
                        get_observed_custom_module_class(type(custom_module))
                    observed_custom_module = \
                        observed_custom_module_class.from_float(custom_module)
                    mark_observed_custom_module(observed_custom_module,
                                                type(custom_module))
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_custom_module)

                # indexes for inputs of the standalone module that need to be observed in the parent
                standalone_module_input_idxs = None
                if isinstance(obj, StandaloneModuleQuantizeHandler):
                    # observe standalone module
                    standalone_module = self.modules[node.target]
                    traced_standalone_module = symbolic_trace(
                        standalone_module)
                    if self.is_dynamic_quant:
                        prepare = torch.quantization.quantize_fx._prepare_dynamic_standalone_module_fx
                    else:
                        prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                    observed_standalone_module = prepare(
                        traced_standalone_module, {'': qconfig})
                    observed_standalone_module.qconfig = qconfig
                    standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                    observed_standalone_module = mark_observed_standalone_module(
                        observed_standalone_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_standalone_module)
                    self.modules[node.target] = observed_standalone_module

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                # inserting observers for output of observed module, or mark the output
                # as observed
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add)
                      or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed_node_names_set:
                        observed_node_names_set.add(node.name)
                elif isinstance(obj, StandaloneModuleQuantizeHandler):
                    assert node.op == 'call_module'
                    output_is_observed = self.modules[
                        node.target]._output_is_observed
                    if output_is_observed:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)

                # insert observer for input of standalone module
                if standalone_module_input_idxs is not None:
                    for idx in standalone_module_input_idxs:
                        if node.args[idx].name not in observed_node_names_set:
                            new_observer = qconfig.activation()
                            device = assert_and_get_unique_device(model)
                            insert_observer(node.args[idx], new_observer,
                                            device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed_node_names_set and node.name in quants:
                if is_standalone_module and node.name in graph_inputs:
                    # we'll insert observer for input of standalone module
                    # in parent graph
                    standalone_module_observed_input_idxs.append(
                        graph_inputs.index(node.name))
                    continue
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    # TODO: use insert_observer
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(model, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)

        observed_graph.output(load_arg(model.graph.result))
        model = GraphModule(model, observed_graph)
        self.save_state(model)
        if is_standalone_module:
            assert isinstance(model.graph.result, Node), \
                'standalone module returning dict is not yet supported'
            # indicator for whether output is observed or not.
            # This used for correctly quantize standalone modules
            output_is_observed = model.graph.result.name in observed_node_names_set
            model._standalone_module_observed_input_idxs = standalone_module_observed_input_idxs
            model._output_is_observed = output_is_observed
        return model
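# A self-contained sketch (not part of the source above) of the observer
# insertion pattern used by insert_observer: register the observer as a
# submodule of the root and add a 'call_module' node that feeds the observed
# value through it. MinMaxObserver is used here purely as an example observer.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.quantization import MinMaxObserver


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


gm = symbolic_trace(TinyNet())
for node in list(gm.graph.nodes):
    if node.op == 'call_module' and node.target == 'linear':
        # register the observer on the root module, mirroring
        # setattr(model, observer_name, observer) above
        obs_name = 'linear_activation_post_process_0'
        setattr(gm, obs_name, MinMaxObserver())
        # insert the call_module node right after the linear node and reroute
        # linear's users through the observer output
        with gm.graph.inserting_after(node):
            obs_node = gm.graph.create_node('call_module', obs_name, (node,), {})
        node.replace_all_uses_with(obs_node)
        obs_node.args = (node,)  # undo the self-reference created above

gm.recompile()
gm(torch.randn(2, 4))  # the observer now records min/max of the linear output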
Example #5
def scale_weight_functional(
    op_node: Node,
    model: GraphModule,
    modules: Dict[str, nn.Module],
    equalization_scale: torch.Tensor,
    next_equalization_scale: Optional[torch.Tensor],
) -> None:
    """ Scales the weight value for functional layers
    """

    # From the given op_node, the path looks like:
    #   get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
    # So we want to trace back from the op_node to get the equalization observer
    # node, then the quantization observer node, and then finally the weight
    # node which contains the weight values.

    # Get the equalization observer node
    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
    if weight_eq_obs_node is None:
        return

    # Get the quantization observer node
    weight_quant_obs_node = weight_eq_obs_node.args[0]
    if weight_quant_obs_node is None:
        return
    assert(isinstance(weight_quant_obs_node, Node) and
           isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))

    # Get the get_attr(weight) node
    weight_node = weight_quant_obs_node.args[0]
    if weight_node is None:
        return
    assert(isinstance(weight_node, Node) and weight_node.op == 'get_attr')

    weight_parent_name, weight_name = _parent_name(weight_node.target)
    weight = getattr(modules[weight_parent_name], weight_name)

    # Scale the weights for input-weight equalization
    # If the following layer needs to be equalized then we will multiply its scale
    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale))

    if next_equalization_scale is None:
        setattr(modules[weight_parent_name], weight_name, scaled_weight)
        return

    # Multiply the weights row wise by the next equalization scale
    new_shape = [1] * weight.ndim
    new_shape[0] = weight.size(0)
    scaled_weight = torch.mul(scaled_weight, next_equalization_scale.view(new_shape))

    setattr(modules[weight_parent_name], weight_name, scaled_weight)
    assert(torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight))

    # Multiply the bias element wise by the next equalization scale
    bias_node = None
    for node, _ in op_node.users.items():
        # Find the node containing the bias values
        if node.op == 'get_attr' and 'bias' in node.name:
            bias_node = node
            break
    if bias_node is None:
        return

    bias_parent_name, bias_name = _parent_name(bias_node.target)
    bias = getattr(modules[bias_parent_name], bias_name)

    scaled_bias = torch.mul(bias, next_equalization_scale)
    setattr(modules[bias_parent_name], bias_name, scaled_bias)
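# A tiny numeric sketch (not part of the source above) of the scaling math:
# the weight is divided column-wise by equalization_scale and multiplied
# row-wise by next_equalization_scale; the bias is scaled element-wise.
# The 3x2 weight shape is an arbitrary illustrative choice.
import torch

weight = torch.ones(3, 2)   # out_features=3, in_features=2
bias = torch.ones(3)
equalization_scale = torch.tensor([2.0, 4.0])                # per input channel
next_equalization_scale = torch.tensor([10.0, 20.0, 30.0])   # per output channel

scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale))
new_shape = [1] * weight.ndim
new_shape[0] = weight.size(0)   # -> [3, 1], broadcasts across rows
scaled_weight = torch.mul(scaled_weight, next_equalization_scale.view(new_shape))
scaled_bias = torch.mul(bias, next_equalization_scale)

# scaled_weight[i][j] == next_equalization_scale[i] / equalization_scale[j]
assert torch.allclose(scaled_weight[0], torch.tensor([5.0, 2.5]))
assert torch.allclose(scaled_bias, next_equalization_scale)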
Example #6
def _prepare_fx(
    model: torch.nn.Module,
    qconfig_dict: Any,
    prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
    equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
    is_standalone_module: bool = False,
    is_qat: bool = False,
) -> ObservedGraphModule:
    r""" Internal helper function for prepare_fx
    Args:
      `model`, `qconfig_dict`, `prepare_custom_config_dict`, `equalization_qconfig_dict`:
      see docs for :func:`~torch.ao.quantization.prepare_fx`
      `is_standalone_module`: a boolean flag that indicates whether we are
      quantizing a standalone module. A standalone module
      is a submodule of the parent module that is not inlined in the
      forward graph of the parent module;
      the way we quantize a standalone module is described in:
      :func:`~torch.ao.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    if equalization_qconfig_dict is None:
        equalization_qconfig_dict = {}

    check_is_valid_qconfig_dict(qconfig_dict)
    check_is_valid_prepare_custom_config_dict(prepare_custom_config_dict)
    check_is_valid_qconfig_dict(equalization_qconfig_dict)

    skipped_module_names = prepare_custom_config_dict.get(
        "non_traceable_module_name", [])
    skipped_module_classes = prepare_custom_config_dict.get(
        "non_traceable_module_class", [])

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    # symbolically trace the model
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        standalone_module_name_configs = prepare_custom_config_dict.get(
            "standalone_module_name", [])
        skipped_module_names += [
            config[0] for config in standalone_module_name_configs
        ]

        standalone_module_class_configs = prepare_custom_config_dict.get(
            "standalone_module_class", [])
        skipped_module_classes += [
            config[0] for config in standalone_module_class_configs
        ]
        float_custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict,
            "float_to_observed_custom_module_class")
        skipped_module_classes += float_custom_module_classes

    preserved_attributes = prepare_custom_config_dict.get(
        "preserved_attributes", [])
    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
    graph_module = GraphModule(model, tracer.trace(model))
    for attr_name in preserved_attributes:
        setattr(graph_module, attr_name, getattr(model, attr_name))
    graph_module = _fuse_fx(graph_module, prepare_custom_config_dict,
                            backend_config_dict)
    prepared = prepare(
        graph_module,
        qconfig_dict,
        tracer.node_name_to_scope,
        prepare_custom_config_dict=prepare_custom_config_dict,
        equalization_qconfig_dict=equalization_qconfig_dict,
        backend_config_dict=backend_config_dict,
        is_standalone_module=is_standalone_module,
        is_qat=is_qat,
    )

    for attr_name in preserved_attributes:
        setattr(prepared, attr_name, getattr(model, attr_name))
    return prepared
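# A hedged usage sketch (not part of the source above) of how the public
# prepare_fx wrapper around _prepare_fx was typically invoked in this API
# generation; the exact signature has changed across PyTorch releases, so
# treat the call below as illustrative rather than definitive.
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

float_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepare_custom_config_dict = {"preserved_attributes": []}

prepared = prepare_fx(float_model, qconfig_dict,
                      prepare_custom_config_dict=prepare_custom_config_dict)
prepared(torch.randn(2, 4))  # calibration pass; observers record statistics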
def _convert_do_not_use(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """
    We will convert an observed model (a module with observer calls) to a reference
    quantized model. The rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized modules; this requires us to know whether the dtype configured for the
       weight is supported in the backend. This is done in the prepare step and the result
       is stored in observed_node_names, so we can decide whether we need to swap the
       module based on this set.

    standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module; whether input/output is quantized is
    specified by prepare_custom_config_dict with
    input_quantized_idxs and output_quantized_idxs. Please
    see the docs for prepare_fx for details.
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    assert is_reference, "_convert_do_not_use only supports reference option"

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(
        model.graph, modules, patterns,
        qconfig_map,
        custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def replace_observer_with_quantize_dequantize_node(graph: Graph, node: Node, modules: Dict[str, torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        elif observer_module.dtype in [torch.quint8, torch.qint8, torch.float16]:
            node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)


    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # Note: we don't need to do anything for this, it affects prepare
                # step in terms of whether to insert observer for input or not
                continue
        elif node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Results are kept quantized if the user specified the
                # output_quantized_idxs override.
                # Remove the dequantize operator in the end
                maybe_dequantize_node = node.args[0]
                if isinstance(maybe_dequantize_node, Node) and \
                   maybe_dequantize_node.op == "call_method" and \
                   maybe_dequantize_node.target == "dequantize":
                    quantized_node = maybe_dequantize_node.args[0]
                    maybe_dequantize_node.replace_all_uses_with(quantized_node)
                    model.graph.erase_node(maybe_dequantize_node)
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                replace_observer_with_quantize_dequantize_node(model.graph, node, modules)
            elif type(modules[node.target]) in set(
                    WEIGHTED_MODULE_CLASSES).union(QAT_MODULE_CLASSES).union(FUSED_MODULE_CLASSES):
                # TODO: refactor this part to a function
                original_module = modules[node.target]
                qconfig = original_module.qconfig

                is_observed = node.name in observed_node_names
                is_weight_quantized = weight_is_statically_quantized(qconfig)
                # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
                if qconfig is None or not is_observed or not is_weight_quantized:
                    continue

                float_module = original_module
                fused_module = None
                if isinstance(
                        original_module,
                        QAT_MODULE_CLASSES):
                    # case 1. converting a qat module to
                    # a float module, we need to attach
                    # the weight fake_quant to the module;
                    # the weight fake_quant is assumed to have been run during
                    # QAT, so we don't need to run it again here
                    float_module = original_module.to_float()  # type: ignore[operator]
                    # change qat conv to conv
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, float_module)
                    if isinstance(float_module, torch.nn.intrinsic._FusedModule):
                        fused_module = float_module
                        float_module = fused_module[0]
                    weight_post_process = original_module.weight_fake_quant
                else:
                    # case 2. converting a float module/fused float module
                    # to a float module, we need to attach a
                    # weight observer to the conv module and run it
                    # with the conv weight
                    if isinstance(original_module, torch.nn.intrinsic._FusedModule):
                        fused_module = original_module
                        float_module = fused_module[0]  # type: ignore[index]
                    assert qconfig is not None
                    weight_post_process = qconfig.weight()
                    # run weight observer
                    weight_post_process(float_module.weight)  # type: ignore[operator]
                weight_qparams = get_qparam_dict(weight_post_process)
                ref_qmodule_cls = get_static_quant_module_class(type(float_module), is_reference=True)
                ref_qmodule = ref_qmodule_cls.from_float(float_module, weight_qparams)
                if fused_module is not None:
                    fused_module[0] = ref_qmodule
                else:
                    parent_name, name = _parent_name(node.target)
                    setattr(modules[parent_name], name, ref_qmodule)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, model.graph, preserved_attributes)
    return model
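# A self-contained sketch (not part of the source above) of the observer ->
# quantize/dequantize rewrite performed by
# replace_observer_with_quantize_dequantize_node. nn.Identity stands in for the
# activation_post_process module, and scale/zero_point are hard-coded literals
# here, whereas the real pass reads them from the observer and registers them
# as buffers via create_getattr_from_value.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class WithStub(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.post = nn.Identity()  # stand-in for an observer module

    def forward(self, x):
        return self.post(self.linear(x))


gm = symbolic_trace(WithStub())
modules = dict(gm.named_modules())
for node in list(gm.graph.nodes):
    if node.op == 'call_module' and isinstance(modules[node.target], nn.Identity):
        with gm.graph.inserting_before(node):
            input_node = node.args[0]
            quant = gm.graph.create_node(
                'call_function', torch.quantize_per_tensor,
                (input_node, 0.1, 0, torch.quint8), {})
            dequant = gm.graph.call_method('dequantize', args=(quant,))
        node.replace_all_uses_with(dequant)
        gm.graph.erase_node(node)

gm.recompile()
out = gm(torch.randn(2, 4))  # float output, rounded through the quint8 domain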
Example #8
    def _convert(self,
                 model: GraphModule,
                 debug: bool = False,
                 convert_custom_config_dict: Optional[Dict[str, Any]] = None,
                 is_standalone_module: bool = False) -> GraphModule:
        """ standalone_module means it a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        Returns a quantized standalone module which accepts float input
        and produces float output.
        """
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        self.restore_state(model)
        # always run weight observers in the top level forward method
        # for dynamic quant ops or weight only quant ops
        self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        custom_module_classes = get_custom_module_class_keys(
            convert_custom_config_dict,
            "observed_to_quantized_custom_module_class")
        assert self.patterns is not None
        matches = self._find_matches(
            model.graph,
            self.modules,
            self.patterns,
            custom_module_classes=custom_module_classes)

        quants = self._find_quants(model.graph, matches)

        self.quantized_graph = Graph()
        env: Dict[Any, Any] = {}
        quant_env: Dict[Any, Any] = {}

        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find ' + \
                    'node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + \
                    str(env) + ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + \
                    n.name + ' in float environment:' + str(env)
                assert n.name in quants, \
                    'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and
                the args with corresponding indexes will be quantized
              - if quantized is a boolean, then all args will be
                quantized/not quantized
              - if quantized is None, then we'll load the node as long as it
                exists

            Output: fn which takes arg_or_args, and loads them from the
                corresponding environment depending on the value of quantized.
            """
            assert quantized is None or \
                isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)

            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, \
                    'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but
                # quant_env will take precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                quantized = map(is_quantized, node)
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")

        def is_output_quantized(node) -> bool:
            """ Check if output node is quantized or not """
            assert self.modules is not None
            # by default the output is expected to be quantized
            quantized = True

            # Need to get correct quantized/non-quantized state for the output
            # of CopyNode
            if type(obj) in [CopyNode, FixedQParamsOpQuantizeHandler]:
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'
                quantized = is_quantized(node.args[0])

            if not activation_is_statically_quantized(qconfig) or \
               not input_output_observed(obj):
                quantized = False

            return quantized

        def insert_quantize_node(node):
            """ Given a activation_post_process module call node, insert a
            quantize node"""
            assert self.modules is not None
            observer_module = self.modules[node.target]
            prev_node = node.args[0]
            if observer_module.dtype == torch.float16:
                # activations are not quantized for
                # fp16 dynamic quantization
                # copy the activation_post_process node here
                # since we may need it when we insert the prepack
                # op for the weight of linear; this will be removed
                # later in a separate pass
                env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)
            elif prev_node.name in quant_env:
                # if previous node is already quantized, we'll just remove the
                # activation_post_process
                quant_env[node.name] = quant_env[prev_node.name]
            else:
                # replace activation post process with quantization ops
                root_module = self.modules[""]
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)

        # additional state to override inputs to be quantized, if specified
        # by the user
        placeholder_node_seen_cnt = 0
        output_node_seen_cnt = 0
        input_quantized_idxs: List[int] = convert_custom_config_dict.get(
            "input_quantized_idxs", [])
        output_quantized_idxs: List[int] = convert_custom_config_dict.get(
            "output_quantized_idxs", [])

        for node in model.graph.nodes:
            if node.op == 'output':
                cur_output_node_idx = output_node_seen_cnt
                output_node_seen_cnt += 1
                if cur_output_node_idx in output_quantized_idxs:
                    # Results are kept quantized if the user specified the
                    # output_quantized_idxs override.
                    graph_output = map_arg(node.args[0], load_x)
                else:
                    graph_output = map_arg(node.args[0], load_non_quantized)
                self.quantized_graph.output(graph_output)
                continue
            root_node, matched, matched_pattern, obj, qconfig = \
                matches.get(node.name, (None, None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(
                        node, load_non_quantized)
                    quantized = False
                else:
                    assert obj is not None
                    is_standalone_module_node = is_observed_standalone_module_node(
                        node, self.modules)
                    result = obj.convert(
                        self,
                        node,
                        load_arg,
                        debug=debug,
                        convert_custom_config_dict=convert_custom_config_dict)
                    if is_standalone_module_node:
                        quantized = False
                    else:
                        quantized = is_output_quantized(node)

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module' and \
                    is_activation_post_process(self.modules[node.target]):
                insert_quantize_node(node)
            elif node.op == 'placeholder':
                cur_placeholder_node_idx = placeholder_node_seen_cnt
                placeholder_node_seen_cnt += 1
                if cur_placeholder_node_idx in input_quantized_idxs:
                    quant_env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
                else:
                    env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
            else:
                # copy quantized or non-quantized node
                env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)

        # remove activation post process
        act_post_process_removed_graph = Graph()
        env = {}

        def load_arg(a):  # type: ignore
            return map_arg(a, lambda node: env[node.name])

        for node in self.quantized_graph.nodes:
            if node.op == 'output':
                act_post_process_removed_graph.output(
                    map_arg(node.args[0], load_arg))
                continue
            if node.op == 'call_module' and \
               is_activation_post_process(self.modules[node.target]):
                # remove activation post process node
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = act_post_process_removed_graph.node_copy(
                    node, load_arg)

        # removes qconfig and activation_post_process modules
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model
Example #9
        tanh_1 = torch.tanh(cat_1);  cat_1 = None
        neg_1 = torch.neg(tanh_1);  tanh_1 = None
        return neg_1

"""

# Create a graph independently of symbolic tracing
graph = Graph()

# Create raw Nodes
raw1 = graph.placeholder("x")
raw2 = graph.placeholder("y")

# Initialize Proxies using the raw Nodes
y = Proxy(raw1)
z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)

# Create a new output Node and add it to the Graph. By doing this, the
# Graph will contain all the Nodes we just created (since they're all
# linked to the output Node)
graph.output(c.node)

# Wrap our created Graph in a GraphModule to get a final, runnable
# `nn.Module` instance
mod = GraphModule(torch.nn.Module(), graph)
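# Follow-up usage note (not part of the original snippet): the hand-built
# GraphModule runs like any nn.Module; the placeholders "x" and "y" become its
# positional arguments.
out = mod(torch.randn(3), torch.randn(3))
print(mod.graph)   # placeholder, call_function (cat/tanh/neg) and output nodes
print(out.shape)   # torch.Size([6]) for two length-3 inputs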
Example #10
    def _prepare(self, model: GraphModule, qconfig_dict: Any,
                 prepare_custom_config_dict: Optional[Dict[str, Any]],
                 is_standalone_module: bool) -> GraphModule:
        """ standalone_module means it a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        When we are preparing a standalone module:
        both the input and the output are observed in the prepared standalone module.
        Returns:
            model(GraphModule): prepared standalone module
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}

        additional_quant_patterns = \
            prepare_custom_config_dict.get("additional_quant_pattern", {})
        self.patterns = get_combined_dict(get_default_quant_patterns(),
                                          additional_quant_patterns)

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additional_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get(
            "standalone_module_name", None)
        standalone_module_classes = prepare_custom_config_dict.get(
            "standalone_module_class", None)
        custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict,
            "float_to_observed_custom_module_class")
        assert self.patterns is not None
        matches = self._find_matches(model.graph, self.modules, self.patterns,
                                     standalone_module_names,
                                     standalone_module_classes,
                                     custom_module_classes)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuantizeHandler object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env: Dict[Any, Any] = {}
        observed_graph = Graph()
        observed_node_names_set: Set[str] = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that need to be observed
        standalone_module_observed_input_idxs: List[int] = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')
        model_device = assert_and_get_unique_device(model)

        result_node: Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue
            if node.name in observed_node_names_set:
                continue

            root_node, matched_nodes, pattern, obj, qconfig = matches.get(
                node.name, (None, None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                # index for input of custom module that needs to be observed in
                # parent
                if qconfig is not None:
                    assert obj is not None
                    insert_observer_for_special_module(
                        obj, self.modules, prepare_custom_config_dict, qconfig,
                        node)
                    insert_observer_for_output_of_the_node(
                        node, obj, qconfig, self.modules, model, pattern,
                        model_device, self.activation_post_process_map, env,
                        observed_graph, load_arg, observed_node_names_set,
                        matched_nodes)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            insert_observer_for_input_arg_of_observed_node(
                node, observed_node_names_set, quants, model_device, model,
                self.activation_post_process_map, env, observed_graph,
                load_arg)

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        model = mark_observed_module(model)
        return model
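A hedged sketch of how this observer-insertion pass is normally reached from user code, through the public FX quantization entry point of the same vintage as this snippet (the model, qconfig, and imports below are illustrative assumptions, not taken from the example):

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

float_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
# prepare_fx symbolically traces the model and runs a _prepare pass like the one
# above, inserting activation_post_process observers after quantizable nodes
prepared = prepare_fx(float_model, qconfig_dict)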
Example #11
0
 def save_state(self, observed: GraphModule) -> None:
     observed._activation_post_process_map = \
         self.activation_post_process_map  # type: ignore
     observed._patterns = self.patterns  # type: ignore
     observed._qconfig_map = self.qconfig_map  # type: ignore
Example #12
0
    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph

        Example usage:

        from torch.fx.experimental.proxy_tensor import make_fx
        def f(a):
            b = a * a
            c = a * a
            return b+c

        p = CSEPass()
        traced_graph = make_fx(f)(torch.tensor(1))
        print(traced_graph)
        result = p(traced_graph)
        print(result.graph_module)
        """
        def get_aten_target(node):
            if hasattr(node.target, 'overloadpacket'):
                return node.target.overloadpacket
            return node.target

        modified = False
        new_graph = Graph()
        env: Dict[Node, Node] = {
        }  # map from node in the old graph to node in the new graph
        hash_env: Dict[Tuple[torch._ops.OpOverload, int],
                       Node] = {}  # map from hash to a node in the new graph
        token_map: Dict[Tuple[torch._ops.OpOverload, int],
                        Dict[str, Any]] = {}  # map from hash to token
        for n in graph_module.graph.nodes:
            # The placeholder, output, and get_attr nodes are copied to the new graph without change.
            # Also, do not CSE away banned ops (e.g. random operations)
            if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(
                    n) in self.banned_ops:
                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
            else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
                # substitute args and kwargs members with their mapping in env if one exists
                # specs can be used to reconstruct nested list/dictionaries
                def substitute(arg_list):
                    arg_list, spec = tree_flatten(arg_list)
                    for i in range(len(arg_list)):
                        v = arg_list[i]
                        if isinstance(v, Node) and v in env:
                            arg_list[i] = env[v]
                    return tuple(arg_list), spec

                args, args_spec = substitute(n.args)
                kwargs, kwargs_spec = substitute(n.kwargs)

                # each token corresponds to a unique node
                # nodes with the same token can be substituted
                token = {
                    "target": n.target,
                    "args": args,
                    "args_spec": args_spec,
                    "kwargs": kwargs,
                    "kwargs_spec": kwargs_spec
                }

                # hash substituted args to a number, do not hash specs because specs are not hashable
                hash_arg = hash((args, kwargs))
                hash_val = (n.target, hash_arg)

                # check if a node has a substitute and can be eliminated
                hash_val_in_hash_env = hash_val in hash_env
                if hash_val_in_hash_env and token_map[hash_val] == token:
                    modified = True  # a substitution happens and the graph is modified
                    env[n] = hash_env[hash_val]
                    continue

                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
                if not hash_val_in_hash_env:
                    hash_env[hash_val] = new_node
                    token_map[hash_val] = token

        csed_gm = GraphModule(graph_module, new_graph)
        return PassResult(csed_gm, modified)
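Continuing the docstring example above (an illustrative sketch that reuses `traced_graph` and `result` from that example), the duplicate `a * a` computation should be eliminated:

ops_before = sum(1 for n in traced_graph.graph.nodes if n.op == 'call_function')
ops_after = sum(1 for n in result.graph_module.graph.nodes if n.op == 'call_function')
# one of the two identical mul nodes was CSE'd away
assert result.modified and ops_after < ops_before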
Example #13
0
    def _prepare(self, model, qconfig_dict, prepare_custom_config_dict, is_standalone_module):
        """ standalone_module means it a submodule that is not inlined in parent module,
        and will be quantized separately as one unit.

        When we are preparing a standalone module:
        input of the module is observed in parent module, output of the module
        is observed in the standalone module.
        Returns:
            model(GraphModule): prepared standalone module with following attributes:
                _standalone_module_observed_input_idxs(List[Int]): a list of indexs for the graph inputs that
                                         needs to be observed in parent module
                _output_is_observed(Bool): a boolean variable indicate whether the output of the
                                   custom module is observed or not
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}

        additional_quant_patterns = prepare_custom_config_dict.get("additional_quant_pattern", {})
        self.patterns = get_combined_dict(get_default_quant_patterns(), additional_quant_patterns)

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get("additioanl_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get("standalone_module_name", None)
        standalone_module_classes = prepare_custom_config_dict.get("standalone_module_class", None)
        custom_module_classes = get_custom_module_class_keys(prepare_custom_config_dict, "float_to_observed_custom_module_class")
        matches = self._find_matches(
            model.graph, self.modules, self.patterns, standalone_module_names, standalone_module_classes, custom_module_classes)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuantizeHandler object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indices of the inputs that need to be observed
        standalone_module_observed_input_idxs = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_')
        model_device = assert_and_get_unique_device(model)

        def insert_observer(node, observer):
            """Insert observer for node by modifying the observed_graph and
               attach observer module to the model
               Args:
                 node: Node
                 observer: observer/fake_quantize module instance
            """
            # respect device affinity when adding observers
            if model_device:
                observer.to(model_device)
            # add observer module as attribute
            prefix = node.name + '_activation_post_process_'
            get_new_observer_name = get_new_attr_name_with_prefix(prefix)
            observer_name = get_new_observer_name(model)
            setattr(model, observer_name, observer)
            # put the observer instance in the activation_post_process map
            self.activation_post_process_map[node.name] = observer
            # insert observer call
            env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
            observed_node_names_set.add(node.name)

        result_node : Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue
            if node.name in observed_node_names_set:
                continue

            root_node, matched_nodes, pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                # index for input of custom module that needs to be observed in parent
                standalone_module_input_idxs = None
                if qconfig is not None:
                    if isinstance(obj, CustomModuleQuantizeHandler):
                        custom_module = self.modules[node.target]
                        custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
                        observed_custom_module_class = \
                            get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig)
                        observed_custom_module = \
                            observed_custom_module_class.from_float(custom_module)
                        parent_name, name = _parent_name(node.target)
                        setattr(self.modules[parent_name], name, observed_custom_module)

                    elif isinstance(obj, StandaloneModuleQuantizeHandler):
                        # observe standalone module
                        standalone_module = self.modules[node.target]
                        prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                        observed_standalone_module = prepare(standalone_module, {'': qconfig})
                        observed_standalone_module.qconfig = qconfig
                        standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                        observed_standalone_module = mark_observed_standalone_module(observed_standalone_module)
                        parent_name, name = _parent_name(node.target)
                        setattr(self.modules[parent_name], name, observed_standalone_module)
                        self.modules[node.target] = observed_standalone_module


                    # don't need to insert observer for output if activation does not
                    # need to be statically quantized
                    if activation_is_statically_quantized(qconfig):
                        if isinstance(obj, FixedQParamsOpQuantizeHandler) and model.training:
                            # we only insert fake quantize module in qat
                            activation_post_process_ctr = \
                                get_default_output_activation_post_process_map().get(pattern, None)
                            assert activation_post_process_ctr is not None, \
                                "activation_post_process constructor not provided for " + \
                                "pattern:" + str(pattern)
                            insert_observer(node, activation_post_process_ctr())
                        elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and
                              not model.training) or isinstance(obj, CopyNode):
                            # inserting observers for output of observed module, or mark the output
                            # as observed
                            assert node.op in [
                                'call_module',
                                'call_function',
                                'call_method'], \
                                'CopyNode of type ' + node.op + ' is not handled'

                            def is_observed(input_arg):
                                if isinstance(input_arg, Node):
                                    return input_arg.name in observed_node_names_set
                                elif isinstance(input_arg, list):
                                    return all(map(is_observed, input_arg))
                            # propagate observed property from input
                            if is_observed(node.args[0]):
                                observed_node_names_set.add(node.name)
                        elif (isinstance(obj, Add) or isinstance(obj, Mul)) and obj.num_node_args == 1:
                            input_node = matched_nodes[-1]  # first node in the sequence

                            def input_is_observed(arg):
                                return isinstance(arg, Node) and arg.name in observed_node_names_set
                            # This is checking if one of the argument of add/mul
                            # is an observed node
                            # If both of the inputs are number,
                            # we will not consider the output to be observed
                            if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
                                observed_node_names_set.add(node.name)
                        elif isinstance(obj, StandaloneModuleQuantizeHandler):
                            assert node.op == 'call_module'
                            output_is_observed = self.modules[node.target]._output_is_observed
                            if output_is_observed:
                                observed_node_names_set.add(node.name)
                        elif obj.all_node_args:
                            # observer for outputs
                            new_observer = qconfig.activation()
                            insert_observer(node, new_observer)

                    # insert observer for input of standalone module
                    if standalone_module_input_idxs is not None:
                        for idx in standalone_module_input_idxs:
                            if node.args[idx].name not in observed_node_names_set:
                                new_observer = qconfig.activation()
                                insert_observer(node.args[idx], new_observer)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            # insert observer for output of the node
            if node.name not in observed_node_names_set and node.name in quants:
                if is_standalone_module and node.name in graph_inputs:
                    # we'll insert observer for input of standalone module
                    # in parent graph
                    standalone_module_observed_input_idxs.append(graph_inputs.index(node.name))
                    continue
                _, activation_post_process_ctr = quants[node.name]
                if activation_post_process_ctr is not None:
                    insert_observer(node, activation_post_process_ctr())

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        model = mark_observed_module(model)
        if is_standalone_module:
            assert result_node is not None
            assert isinstance(result_node.args[0], Node), \
                'standalone module returning dict is not yet supported'
            # indicator for whether the output is observed or not.
            # This is used to correctly quantize standalone modules
            output_is_observed = result_node.args[0].name in observed_node_names_set
            model._standalone_module_observed_input_idxs = standalone_module_observed_input_idxs
            model._output_is_observed = output_is_observed
        return model
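The custom-config keys this pass reads can be sketched as follows (a hypothetical dict; the values are placeholders, only the key names come from the code above):

prepare_custom_config_dict = {
    # extra quantization patterns merged into the default pattern table
    "additional_quant_pattern": {},
    # float module class -> QAT module class swaps applied when model.training
    "additional_qat_module_mapping": {},
    # submodules that are prepared and quantized separately as single units
    "standalone_module_name": ["submodule"],
    "standalone_module_class": [],
    # float custom module class -> observed custom module class
    "float_to_observed_custom_module_class": {},
}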
Example #14
0
def convert(model: GraphModule,
            is_reference: bool = False,
            convert_custom_config_dict: Dict[str, Any] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """ standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(
        model)
    qconfig_map: Dict[
        str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]
    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(model.graph,
                           modules,
                           patterns,
                           qconfig_map,
                           custom_module_classes=custom_module_classes)

    quantized_graph = Graph()
    env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def load_non_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        quantized_node, dtype = env[n.name]
        if dtype and dtype != torch.float:
            env[n.name] = Proxy(quantized_node).dequantize().node, torch.float
        return env[n.name][0]

    def load_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load quantized node but did not find node:' + \
            n.name + ' in environment:' + str(env)
        quantized_node, dtype = env[n.name]
        assert dtype in [torch.quint8, torch.qint8, torch.float16], \
            f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
        return quantized_node

    def load_x(n: Node) -> Node:
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        return env[n.name][0]

    def load_arg(
        quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
    ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized


        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)
        if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = False

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], bool,
                                              Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
               len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                updated_quantized = 0 in quantized

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, bool):
                return map_arg(
                    arg_or_args, load_quantized
                    if updated_quantized else load_non_quantized)
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)

        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                _, dtype = env[node_arg.name]
                return dtype != torch.float
            else:
                return False
        elif isinstance(node_arg, list):
            quantized = map(node_arg_is_quantized, node_arg)
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler,
                            qconfig: QConfigAny,
                            modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # by default the output for a quantizable node is expected to be quantized
        quantized = True

        # Need to get the correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(
                qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
           not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node,
                             modules: Dict[str, torch.nn.Module]) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name] = quantized_graph.node_copy(
                node, load_non_quantized), torch.float
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            _, prev_dtype = env[prev_node.name]
            current_dtype = observer_module.dtype
            if prev_dtype == current_dtype:
                env[node.name] = env[prev_node.name]
            else:
                root_module = modules[""]
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name] = (quantize_node(load_non_quantized(prev_node),
                                                observer_module,
                                                node,
                                                modules,
                                                quantized_graph,
                                                node_name_to_scope,
                                                is_input=True), observer_dtype)
        else:
            # replace activation post process with quantization ops
            root_module = modules[""]
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name] = (quantize_node(load_non_quantized(node.args[0]),
                                            observer_module,
                                            node,
                                            modules,
                                            quantized_graph,
                                            node_name_to_scope,
                                            is_input=True), dtype)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Results are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module'
                and is_observed_standalone_module(modules[node.target]))
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = modules[
                        node.
                        target]._standalone_module_output_quantized_idxs.tolist(
                        )  # type: ignore[operator] # noqa: B950
                    assert len(
                        out_quant_idxs
                    ) <= 1, "Currently standalone only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                result = obj.convert(
                    node,
                    qconfig,
                    modules,
                    quantized_graph,
                    node_name_to_scope,
                    load_arg,
                    is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig,
                                                    modules)

            if quantized:
                env[node.name] = result, activation_dtype(qconfig)
            else:
                env[node.name] = result, torch.float
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(node, load_non_quantized)
                env[node.name] = result, torch.float
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                env[node.name] = \
                    quantized_graph.node_copy(
                        node, load_non_quantized), torch.quint8
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float
        else:
            # copy the quantized or non-quantized node;
            # tensor-info nodes like shape (see is_get_tensor_info_node) work on both
            # quantized and non-quantized inputs and output a non-Tensor
            # (we currently use None as the dtype for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name] = \
                    quantized_graph.node_copy(node, load_x), None
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float

    # remove activation post process
    act_post_process_removed_graph = Graph()
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_remove))
            continue
        if node.op == 'call_module' and \
           is_activation_post_process(modules[node.target]):
            # remove activation post process node
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, act_post_process_removed_graph,
                                 preserved_attributes)
    if not is_reference:
        model = fold_weight(model, node_name_to_scope)
    return model
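The placeholder/output overrides consumed above come from the custom config dicts; a hedged sketch of their shape (key names taken from the code, values illustrative):

prepare_custom_config_dict = {
    "input_quantized_idxs": [0],   # graph input 0 arrives already quantized
    "output_quantized_idxs": [0],  # graph output 0 is returned quantized
}
convert_custom_config_dict = {
    # attributes copied onto the resulting QuantizedGraphModule
    "preserved_attributes": ["some_attr"],
}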
Example #15
0
    def _convert(self, observed, inplace=False, debug=False, is_dynamic_quant=False):
        assert not inplace, 'inplace convert is not supported yet'
        self.restore_state(observed)
        self.is_dynamic_quant = is_dynamic_quant
        # run weight observers before inserting quant dequant nodes
        # for dynamic quantization
        if self.is_dynamic_quant:
            self._run_weight_observers(observed)

        # move to cpu since we only have quantized cpu kernels
        observed.eval().cpu()
        observed_root = observed.root
        observed_graph = observed.graph
        if not inplace:
            observed_root = copy.deepcopy(observed_root)

        self.modules = dict(observed_root.named_modules())

        matches = self._find_matches(observed.graph, self.modules, self.patterns)
        quants = self._find_quants(observed.graph, matches)
        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized environment:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            if quantized is a list, then arg should be a list and the args with corresponding
            indexes will be quantized
            if quantized is a boolean, then all args will be quantized/not quantized
            if quantized is None, then we'll load the node as long as it exists
            """
            assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg):
                if quantized is None:
                    return map_arg(arg, load_x)
                if isinstance(quantized, bool):
                    return map_arg(arg, load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg, (tuple, list)), arg
                    loaded_arg = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg):
                        if i in quantized:
                            loaded_arg.append(map_arg(a, load_quantized))
                        else:
                            loaded_arg.append(map_arg(a, load_non_quantized))
                    return type(arg)(loaded_arg)
            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                quantized = map(is_quantized, node)
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception("partially quantized inputs in list not handled yet")

        for node in observed_graph.nodes:
            root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is node:
                result = obj.convert(self, node, load_arg)
                quantized = True
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                # output of dynamic quantization is not quantized
                if self.is_dynamic_quant:
                    quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith('activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    parent_name = ''

                    scale, zero_point = observer_module.calculate_qparams()
                    dtype = observer_module.dtype

                    def is_per_channel(qscheme):
                        return qscheme == torch.per_channel_affine or \
                            qscheme == torch.per_channel_symmetric

                    if is_per_channel(observer_module.qscheme):
                        ch_axis = int(observer_module.ch_axis)
                        qparams = {'_scale_': scale, '_zero_point_': zero_point, '_axis': ch_axis, '_dtype_': dtype}
                        quantize_op = torch.quantize_per_channel
                    else:
                        scale = float(scale)
                        zero_point = int(zero_point)
                        qparams = {'_scale_': scale, '_zero_point_': zero_point, '_dtype_': dtype}
                        quantize_op = torch.quantize_per_tensor
                    i = 0

                    def noattr(module, qparams, i):
                        for name in qparams.keys():
                            if hasattr(module, name + str(i)):
                                return False
                        return True

                    def get_next_i(module, qparams):
                        i = 0
                        while not noattr(module, qparams, i):
                            i += 1
                        return i

                    parent_module = self.modules[parent_name]
                    i = get_next_i(parent_module, qparams)
                    inputs = [load_non_quantized(node.args[0])]
                    for key, value in qparams.items():
                        setattr(parent_module, key + str(i), value)
                        qparam_full_path = key + str(i)
                        if parent_name:
                            qparam_full_path = parent_name + '.' + qparam_full_path
                        inputs.append(self.quantized_graph.create_node('get_param', qparam_full_path))
                    quant_env[node.name] = self.quantized_graph.create_node('call_function', quantize_op, tuple(inputs), {})
                    continue
            # dequantize inputs for the node that are not quantized
            env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)

        self.quantized_graph.output(load_non_quantized(observed_graph.result))

        to_be_removed = []
        for name, _ in observed_root.named_modules():
            if name.split('.')[-1].startswith('activation_post_process_'):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(observed_root, n)
        return GraphModule(observed_root, self.quantized_graph)
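For reference, the per-tensor branch above ultimately emits a node equivalent to the following eager call (a minimal sketch; the tensor and the qparams are made up rather than taken from an observer):

import torch

x = torch.randn(2, 2)
scale, zero_point = 0.1, 0  # would normally come from observer_module.calculate_qparams()
xq = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)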
Example #16
0
def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(
                map_arg(_get_normalized_nth_input(node, gm, 0), load_arg))
            continue

        if ((node in node_to_instrument_inputs_to_ref_node_name)
                or (node in node_to_instrument_outputs_to_ref_node_name)):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[
                    node]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if only one argument is a tensor it can be either the
                # first or the second one (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = _get_normalized_nth_input(
                        node, gm, node_arg_idx)
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node,
                            gm,
                            logger_cls,
                            '_ns_logger_',
                            node.name,
                            model_name,
                            ref_name,
                            ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0,
                            index_of_arg=node_arg_idx,
                            fqn=fqn)
                    elif type(
                            node_arg
                    ) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(
                                node_arg
                        ):  # type: ignore[var-annotated, arg-type]
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node,
                                gm,
                                logger_cls,
                                '_ns_logger_',
                                node.name,
                                model_name,
                                ref_name,
                                ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx,
                                index_of_arg=node_arg_idx,
                                fqn=fqn)
                    else:
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[
                    node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name],
                    gm,
                    logger_cls,
                    '_ns_logger_',
                    node.name,
                    model_name,
                    ref_name,
                    ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0,
                    index_of_arg=0,
                    fqn=fqn)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
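`logger_cls` only needs to be a callable that returns an `nn.Module` which passes its input through while recording it; a minimal hedged sketch of such a logger (a stand-in, not the actual Numeric Suite logger class):

import torch

class MinimalLogger(torch.nn.Module):
    """Records every value it sees and returns it unchanged."""
    def __init__(self, *args, **kwargs):
        # accept the metadata args passed by the insertion helper
        super().__init__()
        self.stats = []

    def forward(self, x):
        self.stats.append(x)
        return x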
Example #17
0
def convert(model: GraphModule,
            is_reference: bool = False,
            convert_custom_config_dict: Dict[str, Any] = None,
            is_standalone_module: bool = False,
            _remove_qconfig_flag: bool = True,
            convert_qconfig_dict: Dict[str, Any] = None) -> torch.nn.Module:
    """ standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict, _ = restore_state(
        model)
    qconfig_map: Dict[
        str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # TODO this should be removed now that gpu support for quantization is being added.
    # however in practice, as of 7/22/2021, certain functions that get called by convert expect
    # only cpu arguments.
    # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
    # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
    if not is_reference:
        model.cpu()

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if convert_qconfig_dict:
        prepare_qconfig_dict: Dict[str, Dict[
            Any, Any]] = model._qconfig_dict  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)
        convert_dict_to_ordered_dict(convert_qconfig_dict)
        if model._is_qat:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additional_qat_module_mapping", {})
            convert_qconfig_dict = update_qconfig_for_qat(
                convert_qconfig_dict, additional_qat_module_mapping)
        convert_qconfig_dict = update_qconfig_for_fusion(
            model, convert_qconfig_dict)

        compare_prepare_convert_qconfig_dict(
            prepare_qconfig_dict,
            convert_qconfig_dict)  # type: ignore[arg-type]
        convert_qconfig_map = generate_qconfig_map(model, modules_copy,
                                                   model.graph,
                                                   convert_qconfig_dict,
                                                   node_name_to_scope)
        # check the convert_qconfig_map generated and ensure that all the values either match what was set in prepare qconfig_map
        # or are set to None in the convert_qconfig_map.
        for k, v in qconfig_map.items():
            assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format(
                k)
            if convert_qconfig_map[k] is not None:
                assert qconfig_equals(
                    v, convert_qconfig_map[k]
                ), 'Expected k {} to have the same value in prepare qconfig_dict \
                and convert qconfig_dict, found {} updated to {}.'.format(
                    k, v, convert_qconfig_map[k])
        qconfig_map = convert_qconfig_map

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(model.graph,
                           modules,
                           patterns,
                           qconfig_map,
                           custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    quantized_graph = Graph()
    env: Dict[str, Dict[Optional[torch.dtype], Node]] = defaultdict(
        lambda: defaultdict(Node))  # type: ignore[arg-type]

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def load_non_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        dtype_to_node = env[n.name]
        if torch.float in dtype_to_node:
            return dtype_to_node[torch.float]
        elif None in dtype_to_node:
            return dtype_to_node[None]
        else:
            quantized_node = None
            for dtype in [torch.quint8, torch.qint8, torch.float16]:
                if dtype in dtype_to_node:
                    quantized_node = dtype_to_node[dtype]
                    break
            assert quantized_node is not None, "Did not find a supported quantized dtype:{}".format(
                dtype_to_node)
            env[n.name][torch.float] = Proxy(quantized_node).dequantize().node
            return env[n.name][torch.float]

    def load_quantized(dtype: torch.dtype):
        def load_quantized_impl(n: Node):
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + \
                n.name + ' in environment:' + str(env)
            dtype_to_node = env[n.name]
            local_dtype: Optional[torch.dtype] = dtype
            if local_dtype == torch.float and local_dtype not in dtype_to_node:
                local_dtype = None
            if local_dtype in [torch.float, None]:
                return load_non_quantized(n)
            assert local_dtype in dtype_to_node, f'Expecting {dtype} in {dtype_to_node}'
            return dtype_to_node[local_dtype]

        return load_quantized_impl

    def load_x(n: Node) -> Node:
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        dtype_to_node = env[n.name]
        dtypes = [
            torch.quint8, torch.qint8, torch.float16, torch.float32, None
        ]
        for dtype in dtypes:
            if dtype in dtype_to_node:
                return dtype_to_node[dtype]
        raise Exception(
            f'dtype {dtype} not found in environment: {dtype_to_node} for node {n.name}'
        )

    def load_arg(
        quantized: Optional[Union[List[int], Dict[int, torch.dtype],
                                  torch.dtype, Tuple[int, ...]]]
    ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, torch.dtype, list or tuple
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a dtype, then all args will be
            quantized to the specific dtype
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=torch.float)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized to torch.quint8


        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, dict, torch.dtype)), type(quantized)
        if isinstance(quantized, (tuple, list, dict)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = torch.float

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], torch.dtype,
                                              Dict[int, torch.dtype],
                                              Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
               len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                if 0 in quantized:
                    updated_quantized = torch.quint8

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, torch.dtype):
                return map_arg(arg_or_args, load_quantized(updated_quantized))
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        # Currently it's hardcoded to torch.quint8, we can extend this
                        # in the future to support all quantized
                        # dtypes
                        loaded_args.append(
                            map_arg(a, load_quantized(torch.quint8)))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
            elif isinstance(updated_quantized, dict):
                loaded_args = []
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(
                            map_arg(a, load_quantized(updated_quantized[i])))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)

        return load_arg_impl
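
    # Illustrative (hypothetical) call patterns for load_arg, assuming n0 and
    # n1 are Nodes whose converted copies are already stored in `env`:
    #   load_arg(None)((n0, n1))               -> whichever copies exist (load_x)
    #   load_arg(torch.quint8)((n0, n1))       -> quint8 copies of both args
    #   load_arg([0])((n0, n1))                -> n0 as quint8, n1 as float
    #   load_arg({1: torch.float16})((n0, n1)) -> n1 as fp16, n0 as float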

    def node_arg_is_quantized(node_arg: Any) -> bool:
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                dtype_to_node = env[node_arg.name]
                return any([
                    x in dtype_to_node
                    for x in [torch.quint8, torch.qint8, torch.float16]
                ])
            else:
                return False
        elif isinstance(node_arg, list):
            # materialize into a list so that both all() and any() below see
            # every element (a bare map iterator would be exhausted by all())
            quantized = list(map(node_arg_is_quantized, node_arg))
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler,
                            qconfig: QConfigAny,
                            modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # For some ops the output is quantized only when `is_reference` is
        # True; when `is_reference` is False there is limited qconfig support,
        # for example for `add`.
        # Ideally this check should not happen here; it should happen either in
        # prepare or during lowering. We won't need this check after the
        # default path is changed to produce reference patterns.
        quantized = obj.is_output_quantized(qconfig)

        # Need to get the correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(
                qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
           not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node,
                             modules: Dict[str, torch.nn.Module]) -> None:
        """ Given a activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name][torch.float] = quantized_graph.node_copy(
                node, load_non_quantized)
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            prev_dtype_to_node: Dict[Optional[torch.dtype], Node] = env[prev_node.name]
            current_dtype: Optional[torch.dtype] = observer_module.dtype  # type: ignore[assignment]
            if current_dtype in prev_dtype_to_node:
                env[node.name][current_dtype] = prev_dtype_to_node[current_dtype]
            else:
                root_module = modules[""]
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name][observer_dtype] = \
                    quantize_node(
                        load_non_quantized(prev_node),
                        observer_module, node, modules, quantized_graph,
                        node_name_to_scope, is_input=True)
        else:
            # replace activation post process with quantization ops
            root_module = modules[""]
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name][dtype] = \
                quantize_node(
                    load_non_quantized(node.args[0]),
                    observer_module, node, modules,
                    quantized_graph,
                    node_name_to_scope, is_input=True)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Results are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module'
                and is_observed_standalone_module(modules[node.target]))
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(node, load_non_quantized)
                quantized = False
                # If there are QAT-swapped modules in the graph that we don't
                # want to quantize, revert them back to FP32 ones.
                if node.op == 'call_module' and type(modules[
                        node.target]) in DEFAULT_QAT_MODULE_MAPPINGS.values():
                    float_mod = modules[node.target].to_float()
                    setattr(model, node.name, float_mod)
                    with model.graph.inserting_before(node):
                        new_float_node = model.graph.create_node(
                            'call_module', node.name, node.args, node.kwargs)
            else:
                assert obj is not None
                # We determine whether the output is quantized before convert
                # for standalone modules and after convert for non-standalone
                # modules, since _standalone_module_output_quantized_idxs is
                # only available on an observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = (
                        modules[node.target]
                        ._standalone_module_output_quantized_idxs.tolist())  # noqa: B950
                    assert len(out_quant_idxs) <= 1, \
                        "Currently standalone modules only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                # Note: load_arg can be overwritten in the convert method when used to
                # create Node in graph
                result = obj.convert(
                    node,
                    qconfig,
                    modules,
                    quantized_graph,
                    node_name_to_scope,
                    load_arg,
                    is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig,
                                                    modules)

            if quantized:
                env[node.name][activation_dtype(qconfig)] = result
            else:
                env[node.name][torch.float] = result
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(node, load_non_quantized)
                env[node.name][torch.float] = result
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                env[node.name][torch.quint8] = quantized_graph.node_copy(
                    node, load_non_quantized)
            else:
                env[node.name][torch.float] = \
                    quantized_graph.node_copy(node, load_non_quantized)
        else:
            # copy quantized or non-quantized node
            # get_tensor_info_node ops such as `shape` work for both quantized
            # and non-quantized inputs, and output a non-Tensor
            # (we currently use None as the dtype for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name][None] = \
                    quantized_graph.node_copy(node, load_x)
            else:
                env[node.name][torch.float] = \
                    quantized_graph.node_copy(node, load_non_quantized)

    # remove activation post process
    act_post_process_removed_graph = Graph()
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_remove))
            continue
        if node.op == 'call_module' and \
           is_activation_post_process(modules[node.target]):
            # remove activation post process node
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(
        convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, act_post_process_removed_graph,
                                 preserved_attributes)
    if not is_reference:
        model = duplicate_dequantize_node(model)
        model = fold_weight(model, node_name_to_scope)
        model = lower_to_fbgemm(model)
        model = remove_quant_dequant_pairs(model)
        model = remove_extra_dequantize(model)
    return model
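
A minimal standalone sketch (not part of the original file) of the env
bookkeeping used by the convert pass above: each converted value is stored per
node name and per dtype, and loading prefers a quantized copy when one exists,
mirroring `load_x`. Node objects are replaced by strings here and
`load_preferring_quantized` is a made-up name:

from collections import defaultdict
import torch

env = defaultdict(dict)          # node name -> {dtype: converted node}
env["x"][torch.float] = "x_fp32"
env["x"][torch.quint8] = "x_q"   # quantized copy of the same value

def load_preferring_quantized(name):
    # same preference order as load_x above
    for dtype in (torch.quint8, torch.qint8, torch.float16, torch.float, None):
        if dtype in env[name]:
            return env[name][dtype]
    raise KeyError(name)

assert load_preferring_quantized("x") == "x_q"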
Example #18
0
    def test_type_check_conv2D_2_fully_static(self):
        annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                           (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)]
        input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                      (10, 15, 13, 14), (1, 2, 2, 3)]
        intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6),
                              (10, 15, Dyn, 5), (10, 15, 7, 7),
                              (1, Dyn, Dyn, Dyn)]
        in_planes_list = [2, 5, 15, 15, 2]
        stride_list = [1, 2, 3, 2, 2]
        out_planes_list = [2, 5, 15, 15, 2]
        groups_list = [1, 5, 5, 5, 2]
        dilation_list = [1, 2, 3, 3, 3]
        padding_list = [1, 2, 3, 3, 3]
        kernel_size_list = [1, 2, 3, 3, 3]
        output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5),
                        (10, 15, 7, 7), (1, 2, Dyn, Dyn)]

        for i in range(5):
            annotation = annotation_list[i]
            input = input_list[i]
            in_planes = in_planes_list[i]
            stride = stride_list[i]
            out_planes = out_planes_list[i]
            groups = groups_list[i]
            dilation = dilation_list[i]
            padding = padding_list[i]
            kernel_size = kernel_size_list[i]
            intermediate_type = intermediate_types[i]

            class BasicBlock(torch.nn.Module):
                def __init__(self, in_planes, out_planes, kernel_size, stride,
                             padding, groups, dilation):
                    super(BasicBlock, self).__init__()
                    self.conv1 = torch.nn.Conv2d(in_channels=in_planes,
                                                 out_channels=out_planes,
                                                 kernel_size=kernel_size,
                                                 stride=stride,
                                                 padding=padding,
                                                 groups=groups,
                                                 bias=False,
                                                 dilation=dilation)

                def forward(self, x):
                    out = self.conv1(x)
                    return out

            B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding,
                           groups, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our argument
            for n in graph.nodes:
                if n.op == 'placeholder':
                    n.type = TensorType(annotation)

            b = B.forward(torch.rand(input))
            tc = GraphTypeChecker({}, traced)
            tc.type_check()

            for n in graph.nodes:
                if n.op == 'output':
                    assert is_consistent(n.type, TensorType(b.size()))

            # test with intermediate annotations
            class BasicBlock(torch.nn.Module):
                def __init__(self, in_planes, out_planes, kernel_size, stride,
                             padding, groups, dilation):
                    super(BasicBlock, self).__init__()
                    self.conv1 = torch.nn.Conv2d(in_channels=in_planes,
                                                 out_channels=out_planes,
                                                 kernel_size=kernel_size,
                                                 stride=stride,
                                                 padding=padding,
                                                 groups=groups,
                                                 bias=False,
                                                 dilation=dilation)

                def forward(self, x):
                    out = self.conv1(x)
                    return out

            B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding,
                           groups, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # populate our intermediate nodes
            for n in traced.graph.nodes:
                if n.op == 'call_module':
                    n.type = TensorType(intermediate_type)

            tc = GraphTypeChecker({}, traced)
            tc.type_check()

            for n in traced.graph.nodes:
                if n.op == 'output':
                    assert n.type == TensorType(output_types[i])
                    assert is_consistent(n.type, TensorType(b.size()))
Example #19
0
def min_cut_rematerialization_partition(
        joint_module: fx.GraphModule,
        _joint_inputs,
        compiler="nvfuser") -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the joint graph such that the backward recomputes the forward.
    Recomputing helps in trading off memory bandwidth with computation.

    To create the fwd and bwd graph, we copy the joint graph, manually set the
    outputs to just the original forward or backward outputs, and then run the
    resulting graphs through dead code elimination.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing.

    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    try:
        import networkx as nx
    except ImportError:
        raise RuntimeError(
            "Need networkx installed to perform smart recomputation heuristics"
        )

    joint_module.graph.eliminate_dead_code()
    joint_module.recompile()
    fx_g = joint_module.graph

    #  add the CSE pass
    cse_graph = fx_graph_cse(fx_g)
    joint_module.graph = cse_graph
    full_bw_graph = joint_module.graph

    name_to_node = {}
    for node in joint_module.graph.nodes:
        name_to_node[node.name] = node

    def classify_nodes(joint_module):
        required_bw_nodes = set()
        for node in joint_module.graph.nodes:
            if node.op == 'placeholder' and "tangents" in node.target:
                required_bw_nodes.add(node)
            if node in required_bw_nodes:
                for user in node.users:
                    required_bw_nodes.add(user)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module)
        forward_only_graph = _extract_graph_with_inputs_outputs(
            joint_module.graph, primal_inputs, fwd_outputs)
        required_fw_nodes = {
            name_to_node[node.name]
            for node in forward_only_graph.nodes if node.op != 'output'
        }
        unclaimed_nodes = {
            node
            for node in joint_module.graph.nodes
            if node not in required_fw_nodes and node not in required_bw_nodes
        }
        return required_fw_nodes, required_bw_nodes, unclaimed_nodes

    required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(
        joint_module)
    for node in reversed(joint_module.graph.nodes):
        if node not in required_fw_nodes:
            node.dist_from_bw = 0
        else:
            node.dist_from_bw = int(1e9)
            for user in node.users:
                node.dist_from_bw = min(node.dist_from_bw,
                                        user.dist_from_bw + 1)

    aten = torch.ops.aten
    prims = torch.ops.prims

    pointwise_ops = [
        aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min,
        aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__,
        aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne,
        aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not,
        aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round,
        aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2,
        aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos,
        aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan,
        aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt,
        aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold,
        aten.threshold_backward, aten.clamp, aten.where, aten.lerp,
        aten.addcmul, aten.gelu, aten.gelu_backward
    ]  # noqa: E501
    if compiler == "inductor":
        pointwise_ops += [
            prims.div, prims.convert_element_type, aten.sign, aten.clone
        ]  # noqa: E501
    misc_ops = [aten.to, aten.type_as, operator.getitem]

    reduction_ops = [
        aten.softmax, aten._softmax, aten._softmax_backward_data, aten.sum,
        aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax
    ]  # noqa: E501
    if compiler == "inductor":
        reduction_ops += [prims.var, prims.sum, aten.var]

    # not recomputed by default since these are fairly expensive and hard to fuse
    # norm_ops = [aten.instance_norm, aten._batch_norm_impl_index, aten.native_batch_norm, aten.batch_norm, aten._batch_norm_impl_index_backward, aten.native_layer_norm, aten.layer_norm, aten.native_layer_norm_backward]  # noqa: E501

    # Not used by default since NVFuser can't fuse view ops
    # view_ops = [aten.expand, aten.clone, aten.transpose, aten.t, aten.view, aten._unsafe_view, aten.permute, aten.transpose, aten.t, aten._reshape_alias, aten.squeeze, aten.unsqueeze, aten.reshape, aten.cat, aten.slice, aten.split, aten.select, aten.repeat]  # noqa: E501

    # These are the view ops that NVFuser can fuse
    view_ops = [aten.squeeze, aten.unsqueeze]
    if compiler == "inductor":
        view_ops += [
            prims.broadcast_in_dim, aten.select, aten.permute,
            aten._unsafe_view, aten.view, aten.expand, aten.slice,
            aten.reshape, aten.broadcast_tensors
        ]  # noqa: E501
    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
    compute_intensive_ops = [
        aten.mm, aten.convolution, aten.convolution_backward, aten.bmm,
        aten.addmm, aten.upsample_bilinear2d
    ]  # noqa: E501

    unrecomputable_ops = random_ops + compute_intensive_ops

    recomputable_ops = set(pointwise_ops + misc_ops + reduction_ops + view_ops)
    fusible_ops = recomputable_ops | set(random_ops)
    if AOT_PARTITIONER_DEBUG:
        joint_module_ops = set(
            str(node.target._overloadpacket)
            for node in joint_module.graph.nodes if node.op == "call_function"
            and hasattr(node.target, "_overloadpacket"))
        ops_ignored = joint_module_ops - set(
            [str(i) for i in recomputable_ops])
        print("Ops banned from rematerialization: ", ops_ignored)
        print()

    AGGRESSIVE_RECOMPUTATION = False

    def _maybe_size_of(node):
        if 'tensor_meta' in node.meta:
            return _size_of(node.meta['tensor_meta'])
        return 0

    def ban_recomputation(node):
        if AGGRESSIVE_RECOMPUTATION:
            return (node.op == 'call_function'
                    and get_aten_target(node) in unrecomputable_ops)
        else:
            if node.op != 'call_function':
                return False
            if get_aten_target(node) not in recomputable_ops:
                return True
            if node.target == operator.getitem:
                return False
            if compiler == "inductor" and node.dist_from_bw > 4:
                return True
            # If the output of an op is 4x smaller (arbitrary choice),
            # then we don't allow recomputation.
            if 'tensor_meta' not in node.meta:
                return False
            input_tensors_size = sum(
                _maybe_size_of(i) for i in node.args if isinstance(i, fx.Node))
            output_size = _size_of(node.meta['tensor_meta'])
            return (output_size * 4 < input_tensors_size)

    def is_fusible(a, b):
        return get_aten_target(a) in fusible_ops and get_aten_target(
            b) in fusible_ops

    def is_materialized(node):
        if node.op == 'placeholder':
            return True

        return not all(is_fusible(node, user) for user in node.users)

    def get_node_weight(node):
        mem_sz = _size_of(node.meta['tensor_meta'])

        # Heuristic to bias towards nodes closer to the backwards pass
        # Complete guess about current value
        mem_sz = int(mem_sz * (1.1**max(min(node.dist_from_bw, 100), 1)))
        # mem_sz = int(mem_sz + node.dist_from_bw)

        if is_materialized(node):
            return mem_sz
        else:
            return mem_sz * 2
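
    # Worked example of the weight heuristic above (illustrative numbers only):
    # a 1 MB tensor with dist_from_bw=10 gets weight 1 MB * 1.1**10 ~= 2.6 MB,
    # and the weight is doubled when every user can fuse with the node (i.e. it
    # is not materialized), so such tensors become more expensive to save and
    # the cut prefers recomputing them.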

    nx_graph = nx.DiGraph()
    for node in full_bw_graph.nodes:
        if node.op == 'output':
            continue

        if node in required_bw_nodes:
            nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
            continue

        if node.op == 'placeholder' and "primals" in node.target:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        # If a node can't be recomputed (too expensive or involves randomness),
        # we prevent it from being recomputed by adding an inf edge to the source
        # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
        if ban_recomputation(node) and node in required_fw_nodes:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        if 'tensor_meta' not in node.meta:
            weight = math.inf
        else:
            weight = get_node_weight(node)

        # Creates the weights on the "node" edge
        nx_graph.add_edge(node.name + "_in",
                          node.name + "_out",
                          capacity=weight)
        for user in node.users:
            nx_graph.add_edge(node.name + "_out",
                              user.name + "_in",
                              capacity=math.inf)

    cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
    reachable, non_reachable = partition
    cutset = set()
    for u, nbrs in ((n, nx_graph[n]) for n in reachable):
        cutset.update((u, v) for v in nbrs if v in non_reachable)

    cut_nodes = set()
    for node_in, node_out in cutset:
        assert node_in[:-3] == node_out[:-4]
        node_name = node_in[:-3]
        cut_nodes.add(node_name)

    # To make this stuff deterministic
    node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
    saved_values = sorted((name_to_node[node] for node in cut_nodes),
                          key=lambda x: node_idx[x])
    fw_module, bw_module = _extract_fwd_bwd_modules(joint_module, saved_values)
    if AOT_PARTITIONER_DEBUG:
        print(
            "Theoretical Activations Stored: ",
            sum([_size_of(i.meta['tensor_meta']) for i in saved_values]) / 1e9)
        fw_module_nodes = set([
            node.name for node in fw_module.graph.nodes
            if node.op == 'call_function'
        ])
        bw_module_nodes = set([
            node.name for node in bw_module.graph.nodes
            if node.op == 'call_function'
        ])
        remat_nodes = fw_module_nodes & bw_module_nodes

        counts = defaultdict(int)
        for node in fw_module.graph.nodes:
            if node.name in remat_nodes and hasattr(node.target,
                                                    '_overloadpacket'):
                counts[str(node.target._overloadpacket)] += 1
        print("# nodes rematerialized: ", len(remat_nodes))
        print("Count of Ops Rematerialized: ",
              sorted(counts.items(), key=lambda x: x[1], reverse=True))
    return fw_module, bw_module
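
A toy sketch (made-up names, not part of the original file) of the min-cut
formulation used above: every candidate tensor becomes a `<name>_in ->
<name>_out` edge whose capacity is its (weighted) size, structural edges get
infinite capacity, and the minimum cut then picks the cheapest set of tensors
to save for the backward pass instead of recomputing them:

import math
import networkx as nx

g = nx.DiGraph()
g.add_edge("source", "a_in", capacity=math.inf)  # forward input, cannot be cut
g.add_edge("a_in", "a_out", capacity=4.0)        # saving tensor a costs 4
g.add_edge("a_out", "b_in", capacity=math.inf)   # b is computed from a
g.add_edge("b_in", "b_out", capacity=1.0)        # saving tensor b costs 1
g.add_edge("b_out", "sink", capacity=math.inf)   # b feeds the backward pass

cut_value, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink")
print(cut_value)  # 1.0: the cut crosses b_in -> b_out, so tensor b is the one saved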
Example #20
0
    def test_typecheck_basicblock(self):
        class BasicBlock(torch.nn.Module):
            expansion = 1

            def __init__(self,
                         inplanes,
                         planes,
                         stride=1,
                         downsample=None,
                         groups=1,
                         base_width=64,
                         dilation=1):
                super(BasicBlock, self).__init__()
                norm_layer = torch.nn.BatchNorm2d
                if groups != 1 or base_width != 64:
                    raise ValueError(
                        'BasicBlock only supports groups=1 and base_width=64')
                if dilation > 1:
                    raise NotImplementedError(
                        "Dilation > 1 not supported in BasicBlock")
                # Both self.conv1 and self.downsample layers downsample the input when stride != 1
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = norm_layer(planes)
                self.relu = torch.nn.ReLU(inplace=True)
                self.conv2 = conv3x3(planes, planes)
                self.bn2 = norm_layer(planes)
                self.downsample = downsample
                self.stride = stride

            def forward(self, x: TensorType((2, 2, 4, 5))):
                identity = x

                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)

                out = self.conv2(out)
                out = self.bn2(out)

                if self.downsample is not None:
                    identity = self.downsample(x)

                out += identity
                out = self.relu(out)

                return out

        B = BasicBlock(2, 2)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        tc = GraphTypeChecker({}, traced)
        tc.type_check()

        for n in traced.graph.nodes:
            if n.target == 'output':
                assert isinstance(n.type, TensorType)
                assert torch.Size(n.type.__args__) == B.forward(
                    torch.rand(2, 2, 4, 5)).size()
Example #21
0
    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(model, qconfig_dict)
        if model.training:
            self._qat_swap_modules(model)

        self.modules = dict(model.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph)

        # match the patterns that will get quantized
        matches = self._find_matches(model.graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each of them
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()

        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in model.graph.nodes:
            if node.name in observed_node_names_set:
                continue

            prefix = node.name + '_activation_post_process_'
            root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                if qconfig is None:
                    continue

                def insert_observer(node, observer, device):
                    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                    observer_name = get_new_observer_name(model)
                    setattr(model, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed_node_names_set.add(node.name)
                    if device:
                        getattr(model, observer_name).to(device)

                if isinstance(obj, CustomModuleQuantizeHandler):
                    custom_module = self.modules[node.target]
                    observed_custom_module_class = \
                        get_observed_custom_module_class(type(custom_module))
                    observed_custom_module = \
                        observed_custom_module_class.from_float(custom_module)
                    mark_observed_custom_module(observed_custom_module, type(custom_module))
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name, observed_custom_module)

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                # inserting observers for output of observed module, or mark the output
                # as observed
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed_node_names_set:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed_node_names_set and node.name in quants:
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(model, observer_name, self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed_node_names_set.add(node.name)
        observed_graph.output(load_arg(model.graph.result))

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        return model
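
The observer-insertion step above boils down to: register an observer module on
the model, create a `call_module` node right after the node being observed, and
reroute that node's users through the observer. A self-contained sketch of the
same pattern on a plain torch.fx graph (module and attribute names here are
made up, and MinMaxObserver stands in for whatever qconfig.activation() would
return), not taken from the file above:

import torch
import torch.fx as fx
from torch.ao.quantization import MinMaxObserver

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == 'call_module' and node.target == 'linear':
        obs_name = 'linear_activation_post_process_0'
        setattr(gm, obs_name, MinMaxObserver())
        with gm.graph.inserting_after(node):
            obs_node = gm.graph.call_module(obs_name, (node,))
        # route every other user of the linear output through the observer,
        # then restore the observer's own input
        node.replace_all_uses_with(obs_node)
        obs_node.args = (node,)
gm.recompile()
out = gm(torch.randn(2, 4))  # observer records stats and passes the value through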
Example #22
0
    def test_type_maxpool2d_fully_static(self):
        annotation_list = [(Dyn, Dyn, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                           (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 10)]
        input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                      (10, 15, 13, 14), (2, 2, 10, 10)]
        intermediate_types = [(1, 2, Dyn, Dyn), (2, Dyn, 2, 4),
                              (10, 15, Dyn, 2), (10, 15, 2, 3),
                              (2, Dyn, Dyn, Dyn)]
        stride_list = [1, 2, 3, 2, 1]
        dilation_list = [1, 2, 3, 3, 2]
        padding_list = [1, 2, 3, 3, 1]
        kernel_size_list = [2, 4, 6, 6, 3]
        output_types = [(1, 2, 4, 6), (2, 5, 2, 4), (10, 15, 2, 2),
                        (10, 15, 2, 3), (2, Dyn, Dyn, 8)]

        for i in range(5):
            annotation = annotation_list[i]
            input = input_list[i]
            stride = stride_list[i]
            dilation = dilation_list[i]
            padding = padding_list[i]
            kernel_size = kernel_size_list[i]
            intermediate_type = intermediate_types[i]

            class BasicBlock(torch.nn.Module):
                def __init__(self, kernel_size, stride, padding, dilation):
                    super(BasicBlock, self).__init__()
                    self.pool = torch.nn.MaxPool2d(kernel_size,
                                                   stride=stride,
                                                   padding=padding,
                                                   dilation=dilation,
                                                   return_indices=False,
                                                   ceil_mode=False)

                def forward(self, x):
                    out = self.pool(x)
                    return out

            B = BasicBlock(kernel_size, stride, padding, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our argument
            for n in graph.nodes:
                if n.op == 'placeholder':
                    n.type = TensorType(annotation)

            b = B.forward(torch.rand(input))
            tc = GraphTypeChecker({}, traced)
            tc.type_check()

            for n in graph.nodes:
                if n.op == 'output':
                    assert is_consistent(n.type, TensorType(b.size()))

            # test with intermediate annotations
            class BasicBlock(torch.nn.Module):
                def __init__(self, kernel_size, stride, padding, dilation):
                    super(BasicBlock, self).__init__()
                    self.pool = torch.nn.MaxPool2d(kernel_size,
                                                   stride=stride,
                                                   padding=padding,
                                                   dilation=dilation,
                                                   return_indices=False,
                                                   ceil_mode=False)

                def forward(self, x):
                    out = self.pool(x)
                    return out

            B = BasicBlock(kernel_size, stride, padding, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our argument
            for n in graph.nodes:
                if n.op == 'placeholder':
                    n.type = TensorType(annotation)

            # populate our intermediate nodes
            for n in traced.graph.nodes:
                if n.op == 'call_module':
                    n.type = TensorType(intermediate_type)

            tc = GraphTypeChecker({}, traced)
            tc.type_check()

            for n in traced.graph.nodes:
                if n.op == 'output':
                    assert n.type == TensorType(output_types[i])
                    assert is_consistent(n.type, TensorType(b.size()))
Example #23
0
    def _convert(self,
                 model,
                 inplace=False,
                 debug=False,
                 is_dynamic_quant=False,
                 is_standalone_module=False):
        """ standalone_module means it a submodule that is not inlined in parent module,
        and will be quantized separately as one unit.
        For standalone module: the inputs will be quantized by parent module,
        checks `_standalone_module_observed_input_idxs` of
        input observed model and will treat these inputs as quantized
        also will not dequantize the final output.
        Returns a quantized standalone module which accepts quantized input(if needed)
        and produces quantized output (if needed).
        """
        self.restore_state(model)
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        # run weight observers before inserting quant dequant nodes
        # for dynamic quantization
        if self.is_dynamic_quant:
            self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        matches = self._find_matches(model.graph, self.modules, self.patterns)

        quants = self._find_quants(model.graph, matches)

        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + str(env) + \
                    ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and the args with corresponding
                indexes will be quantized
              - if quantized is a boolean, then all args will be quantized/not quantized
              - if quantized is None, then we'll load the node as long as it exists

            Output: fn which takes arg_or_args, and loads them from the corresponding
              environment depending on the value of quantized.
            """
            assert quantized is None or isinstance(
                quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)

            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but
                # quant_env will take precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                # materialize so that both all() and any() below see every element
                quantized = list(map(is_quantized, node))
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")

        for node in model.graph.nodes:
            root_node, matched, obj, qconfig = matches.get(
                node.name, (None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(
                        node, load_non_quantized)
                    quantized = False
                else:
                    result = obj.convert(self, node, load_arg)
                    if node.op == 'call_module' and is_observed_standalone_module(
                            self.modules[node.target]):
                        quantized = self.modules[
                            node.target]._output_is_observed
                    else:
                        quantized = True

                    # Need to get correct quantized/non-quantized state for the output of CopyNode
                    if isinstance(obj, CopyNode):
                        assert node.op in [
                            'call_module',
                            'call_function',
                            'call_method'], \
                            'CopyNode of type ' + node.op + ' is not handled'
                        quantized = is_quantized(node.args[0])

                    # output of dynamic quantization is not quantized
                    if self.is_dynamic_quant:
                        quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if is_activation_post_process(self.modules[node.target]):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if observer_module.dtype == torch.float16:
                        # activations are not quantized for
                        # fp16 dynamic quantization
                        # copy the activation_post_process node here
                        # since we may need it when we insert prepack
                        # op for weight of linear, this will be removed
                        # later in a separate pass
                        env[node.name] = self.quantized_graph.node_copy(
                            node, load_non_quantized)
                        continue
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    root_module = self.modules['']
                    quant_env[node.name] = quantize_node(
                        root_module, self.quantized_graph,
                        load_non_quantized(node.args[0]), observer_module)
                    continue

            if is_standalone_module and node.op == 'placeholder' and \
               graph_inputs.index(node.name) in model._standalone_module_observed_input_idxs:
                # the node is quantized in parent module
                quant_env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)
            else:
                # dequantize inputs for the node that are not quantized
                env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)

        if is_standalone_module:
            # results are kept quantized in the quantized standalone module
            graph_output = map_arg(model.graph.result, load_x)
        else:
            graph_output = map_arg(model.graph.result, load_non_quantized)
        self.quantized_graph.output(graph_output)

        # remove activation post process
        act_post_process_removed_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.quantized_graph.nodes:
            if node.op == 'call_module' and \
               is_activation_post_process(self.modules[node.target]):
                # remove activation post process node
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = act_post_process_removed_graph.node_copy(
                    node, load_arg)
        act_post_process_removed_graph.output(
            map_arg(self.quantized_graph.result, load_arg))

        module_dict = dict(model.named_modules())
        to_be_removed = []
        for name, module in model.named_modules():
            if is_activation_post_process(
                    module) and not is_submodule_of_fake_quant(
                        name, module, module_dict):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(model, n)
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model
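
A minimal sketch of the `Proxy(...).dequantize().node` idiom used in
`load_non_quantized` above: wrapping an existing Node in a torch.fx Proxy lets
ordinary method calls append new nodes to the same graph. The graph below is a
made-up example, not taken from the file above:

import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder('x')
deq = fx.Proxy(x).dequantize().node  # appends call_method 'dequantize' on x
graph.output(deq)
print(graph)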
Example #24
0
 def transform(traced):
     new_graph = copy.deepcopy(traced.graph)
     relu_out = new_graph.create_node(
         op='call_method', target='neg', args=(new_graph.result,), kwargs={})
     new_graph.output(relu_out)
     return GraphModule(traced, new_graph)
Example #25
0
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_observer = \
            node_b.op == 'call_module' and is_activation_post_process(modules[node_b.target])
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if node_b_is_observer:
            # remove activation post process node
            env_c[node_b.name] = env_c[node_b.args[0].name]

        elif (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
                node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    if isinstance(node_b.args[0], Node):
                        prev_node_c = env_c[node_b.args[0].name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0)
                    elif isinstance(node_b.args[0], list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [env_c[arg.name] for arg in node_b.args[0]]

                        for arg_idx, arg in enumerate(node_b.args[0]):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[prev_node_c.name] = _insert_logger_after_node(
                                prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                                node_b.name, name_b, ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=0)
                    else:
                        # logging of inputs of other types is not supported yet
                        raise AssertionError(f"type {type(node_b.args[0])} is not handled yet")
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_start_node:

                # cast dtype from the dtype of node_c's input to the dtype of
                # node_a's input (dequant, etc)
                prev_node_c = node_c.args[0]
                if should_log_inputs:
                    # skip the input logger when inserting a dtype cast
                    if isinstance(prev_node_c, Node):
                        prev_node_c = prev_node_c.args[0]
                    elif isinstance(prev_node_c, list):
                        prev_node_c = [arg.args[0] for arg in prev_node_c]
                dtype_cast_node = _insert_dtype_cast_after_node(
                    subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
                    node_b.name + '_dtype_cast_', logger_cls,
                    node_type_to_io_type_map)
                # note: not inserting to env_c because all nodes which use the dtype
                #   casts are copied from graph_a
                #
                # subgraph so far:
                #
                #           (dtype_cast_node)+
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                # if input logging is enabled, log the input to the subgraph
                if should_log_inputs:
                    # ref_node_name is not known yet, because the subgraph copy
                    # of A does not exist at this point; it is filled in further
                    # below, after the subgraph copy has been created.
                    ref_node_name = ''
                    if isinstance(dtype_cast_node, Node):
                        dtype_cast_node = _insert_logger_after_node(
                            dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                            ref_node_name, name_a, ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0)
                        input_logger: Union[Node, List[Node]] = dtype_cast_node
                    else:
                        assert isinstance(dtype_cast_node, list)
                        new_loggers = []
                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
                            dtype_cast_logger = _insert_logger_after_node(
                                dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
                                ref_node_name, name_a, ref_name,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=dtype_cast_idx,
                                index_of_arg=0)
                            new_loggers.append(dtype_cast_logger)
                        dtype_cast_node = new_loggers
                        input_logger = dtype_cast_node
                    # subgraph so far:
                    #
                    #       (dtype_cast_node)+ -> (logger_a_input)?
                    #                  /
                    # prev_node_c -> (logger_c_input)? -> node_start_c

                # hook up the new mod_a copy to be in the graph, receiving the
                # same inputs as mod_b does, with dtype cast to match a
                # Some ops, such as LSTMs, have two non-param inputs. If we have
                # such an op, pass the second input as well. Note: dtype casting
                # for the second input is not implemented yet; it can be added
                # later if there is a use case.
                node_c_second_non_param_arg = None
                num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
                if num_non_param_args_node_a == 2:
                    node_c_second_non_param_arg = node_c.args[1]
                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                    dtype_cast_node, node_c_second_non_param_arg,
                    subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
                env_c[node_a_shadows_c.name] = node_a_shadows_c
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if should_log_inputs:
                    # When we created the input logger, we left the ref_node_name
                    # as an empty string, because the subgraph copy did not exist
                    # yet. Now that the subgraph copy exists, we modify this name
                    # to its true value.
                    # Note: the alternative to this is to create the input logger
                    # after creating the subgraph, which is slightly more
                    # complicated. This is the lesser of two evils.
                    # input_logger = env_c[dtype_cast_node.name]
                    # Find the first node in the subgraph
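                    # (walk backwards through each node's first arg until we
                    # reach the input logger; cur_node ends up as the first node
                    # of the subgraph copy)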
                    cur_node = node_a_shadows_c
                    while cur_node.args[0] != input_logger:
                        cur_node = cur_node.args[0]  # type: ignore[assignment]
                    if isinstance(input_logger, Node):
                        input_logger_mod = getattr(gm_b, input_logger.name)
                        input_logger_mod.ref_node_name = cur_node.name
                    else:
                        assert isinstance(input_logger, list)
                        for input_logger_inner in input_logger:
                            input_logger_mod = getattr(gm_b, input_logger_inner.name)
                            input_logger_mod.ref_node_name = cur_node.name

                # hook up a logger to the mod_a copy
                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                    env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                    node_a_shadows_c.name, name_a, ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_end_node:

                # hook up a logger to the mod_b copy
                env_c[node_b.name] = _insert_logger_after_node(
                    env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                    node_b.name, name_b, ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                #
                # Note: node_start_c may be the same node as node_end_c, or they
                # may have nodes in between.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
Example #26
0
    def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False):
        self.restore_state(model)
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        # run weight observers before inserting quant dequant nodes
        # for dynamic quantization
        if self.is_dynamic_quant:
            self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        matches = self._find_matches(model.graph, self.modules, self.patterns)
        quants = self._find_quants(model.graph, matches)
        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized environment:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and the args with corresponding
                indexes will be quantized
              - if quantized is a boolean, then all args will be quantized/not quantized
              - if quantized is None, then we'll load the node as long as it exists

            Output: fn which takes arg_or_args, and loads them from the corresponding
              environment depending on the value of quantized.
            """
            assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(arg_or_args, load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                # materialize the map, since it is consumed twice below
                quantized = list(map(is_quantized, node))
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception("partially quantized inputs in list not handled yet")

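        # convert the graph: matched patterns are converted by their quantize
        # handlers, activation_post_process modules are replaced with quantize
        # ops (or dropped if their input is already quantized), and all
        # remaining nodes are copied over with float (dequantized) inputs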
        for node in model.graph.nodes:
            root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is node:
                result = obj.convert(self, node, load_arg)
                quantized = True
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                # output of dynamic quantization is not quantized
                if self.is_dynamic_quant:
                    quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
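            # a non-root node inside a matched pattern: its output is produced
            # when the pattern's root node is converted, so there is nothing to
            # do here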
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith('activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    root_module = self.modules['']
                    quant_env[node.name] = quantize_node(
                        root_module, self.quantized_graph,
                        load_non_quantized(node.args[0]), observer_module)
                    continue
            # for nodes that are not being quantized, copy the node over and
            # dequantize any of its inputs which are quantized
            env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)

        self.quantized_graph.output(map_arg(model.graph.result, load_non_quantized))

        to_be_removed = []
        for name, _ in model.named_modules():
            if name.split('.')[-1].startswith('activation_post_process_'):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(model, n)
        model = GraphModule(model, self.quantized_graph)
        return model
Example #27
0
def insert_observers_for_model(
    model: GraphModule,
    modules: Dict[str, torch.nn.Module],
    matches: Dict[str, MatchResult],
    qconfig_map: Dict[str, QConfigAny],
    graph: Graph,
    prepare_custom_config_dict: Dict[str, Any],
    input_quantized_idxs: List[int],
    output_quantized_idxs: List[int],
) -> Optional[Node]:
    """
    Inserts observers, using the following high level algorithm:

    For each node in the graph:
      1. determine the target dtype of this node in the quantized graph, and save
           it for future steps
      2. determine the target dtype of all args and kwargs of this node
      3. if any arg or kwarg's target dtype does not match the current node's
           dtype, insert an observer
      4. if the current node needs an output observer, insert it

    For example:

    - starting graph:
        x0 -> linear -> x1

    - observed graph after processing x0:
        x0(fp32)

    - observed graph after processing linear:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)

    - observed graph after processing x1:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1

    After a node is processed, the naive observer placement is guaranteed to be
    complete for that node and all of its predecessors. There can be future
    passes which optimize the graph by deduplicating observers, etc.
    """

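    # node_name_to_target_dtype maps each node name to the dtype we want its
    # activation to have in the quantized graph (None if the node does not
    # output a Tensor); cache_for_no_tensor_check memoizes the "node involves
    # no Tensors" check used when computing those dtypes.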
    node_name_to_target_dtype: Dict[str, Any] = {}
    cache_for_no_tensor_check: Dict[Node, bool] = dict()

    inputs_seen_counter = 0
    outputs_seen_counter = 0
    results_node = None

    # first, populate the dtype map based only on qconfig and qhandler
    # this assumes:
    # graph inputs are fp32 by default, and int8 where overridden
    # the output dtype of other nodes is specified by the qconfig
    modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        node_name_to_target_dtype[
            node.name] = get_target_activation_dtype_for_node(
                node, qconfig, inputs_seen_counter, outputs_seen_counter,
                input_quantized_idxs, output_quantized_idxs, qhandler, modules,
                cache_for_no_tensor_check)

    # Second, for nodes with known input dtypes, propagate them throughout the
    # graph. For example, if there is a call such as
    #   x1 = x0.masked_fill(mask, 1)
    # we propagate the type of mask to be torch.bool
    propagate_dtypes_for_known_nodes(model.graph, node_name_to_target_dtype,
                                     matches)

    # After this point, every node and all of its arguments have a target dtype
    # assigned. Next, we insert observers for the inputs of each node (if needed
    # for that node), and for the output of each node (if needed for that node).

    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    for node in nodes_before_observation:

        # check for matches
        root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
            node.name, (None, None, None, None, None))

        if node.op == 'placeholder':
            # if a graph input is in fp32, it does not need observation
            # if a graph input is in int8, we assume the observation happens
            #   outside of the graph, and no additional observation is needed
            pass

        elif node.op in ('call_module', 'call_method', 'call_function',
                         'output'):
            modules = dict(model.named_modules(remove_duplicate=False))
            this_node_dtype = node_name_to_target_dtype[node.name]
            output_not_a_tensor = this_node_dtype is None
            # TODO(future PR): consider stopping matching getitem
            is_getitem = node.op == 'call_function' and \
                node.target == operator.getitem

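            # skip observer insertion for nodes without a qconfig, nodes whose
            # output is not a Tensor, and getitem calls; graph outputs are never
            # skipped, so that output_quantized_idxs can be handled below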
            skip_inserting_observers = (
                (qconfig is None) or output_not_a_tensor
                or is_getitem) and (not node.op == 'output')

            if not skip_inserting_observers:
                if node.op != 'output':
                    # this modifies node inplace
                    maybe_insert_input_observers_for_node(
                        node, qconfig, model, modules, graph,
                        node_name_to_target_dtype, qhandler,
                        prepare_custom_config_dict)

                    is_last_node_of_pattern = root_node is node
                    is_like_copy_node = \
                        (qhandler is not None and (
                            isinstance(qhandler, CopyNodeQuantizeHandler)
                        ))
                    if is_last_node_of_pattern and (not is_like_copy_node):
                        # this returns the new observer node if it was needed
                        maybe_output_obs_node = maybe_insert_output_observer_for_node(
                            node, model, modules, graph, matches,
                            node_name_to_target_dtype, pattern, qhandler)
                        if maybe_output_obs_node is not None:
                            # Update users of original node to use the output observer
                            # instead. For example, change
                            #
                            #           next_node
                            #          /
                            #   cur_node -> obs
                            #
                            # to
                            #
                            #                 next_node
                            #                 /
                            #   cur_node -> obs
                            #
                            # We need to save orig users before updating uses because
                            # the list of users will change as we update uses
                            orig_users = list(node.users.keys())
                            for user_node in orig_users:
                                if user_node is maybe_output_obs_node:
                                    continue
                                user_node.replace_input_with(
                                    node, maybe_output_obs_node)

                            # for quantized cat nodes only, we modify the graph
                            # to make all inputs and outputs use the first input's
                            # observer
                            if isinstance(qhandler, CatQuantizeHandler):
                                adjust_observers_for_cat(node, model, modules)

                            if isinstance(qhandler,
                                          CustomModuleQuantizeHandler):
                                swap_custom_module_to_observed(
                                    node, qconfig, modules,
                                    prepare_custom_config_dict)

                else:  # output
                    maybe_insert_observers_before_graph_output(
                        node, output_quantized_idxs, node_name_to_target_dtype,
                        qconfig_map, model, modules, graph)

        #
        # After this point, the current node has input and output observers
        # that it needs for itself inserted.
        #

        # increment the counters, so future inputs and outputs are assigned
        # correct dtypes
        if node.op == 'placeholder':
            inputs_seen_counter += 1
        elif node.op == 'output':
            outputs_seen_counter += 1
            results_node = node

    return results_node
Example #28
0
    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
        assert not inplace, 'inplace prepare is not supported yet'
        input_root = model.root
        if not inplace:
            input_root = copy.deepcopy(input_root)

        input_graph = model.graph
        self.is_dynamic_quant = is_dynamic_quant
        # TODO: allow user specified patterns
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(input_root, qconfig_dict)
        if input_root.training:
            self._qat_swap_modules(input_root)

        self.modules = dict(input_root.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(input_root, input_graph)

        # match the patterns that will get quantized
        matches = self._find_matches(input_graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not yet quantized; these have
        # to be quantized, which requires measuring stats, so we initialize a
        # DefaultQuant object for each of them
        quants = self._find_quants(input_graph, matches)

        self.activation_post_process_map = dict()

        env = {}
        observed_graph = Graph()
        observed = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            if node.name in observed:
                continue

            get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_')
            root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
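            # nodes outside of any matched pattern are copied unchanged; for the
            # root node of a match, we may additionally insert an output observer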
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                def insert_observer(node, observer):
                    observer_name = get_new_observer_name(input_root)
                    setattr(input_root, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed.add(node.name)

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed.add(node.name)
                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed:
                        observed.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    insert_observer(node, qconfig.activation())
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

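            # this node feeds a matched pattern (it appears in quants) but is
            # not observed yet: insert a weight or activation observer so its
            # statistics can be collected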
            if node.name not in observed and node.name in quants:
                observer_name = get_new_observer_name(input_root)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    self.activation_post_process_map[node.name] = qconfig.weight() if is_weight else qconfig.activation()
                    setattr(input_root, observer_name, self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed.add(node.name)
        observed_graph.output(load_arg(input_graph.result))

        observed = GraphModule(input_root, observed_graph)
        self.save_state(observed)
        return observed
Example #29
0
        def lower_to_elementwise_interpreter(
                orig_mod: torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

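            # map from Python operator to the corresponding instruction name in
            # the C++ elementwise interpreter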
            target_to_name = {operator.add: "add", operator.mul: "mul"}

            output_node: Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    assert target in target_to_name, "Unsupported call target " + target
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.Tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append(
                        (target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            interpreter = \
                torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert output_node is not None
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node(
                    'placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(op='call_method',
                                            target='__call__',
                                            args=(interpreter_node,
                                                  placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint(wrapper)

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)
Example #30
0
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, str],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, str],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(
                modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif ((node in node_to_instrument_inputs_to_ref_node_name)
              or (node in node_to_instrument_outputs_to_ref_node_name)):

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name = node_to_instrument_inputs_to_ref_node_name[node]
                if type(node.args[0]) == Node:
                    # create a single input logger
                    prev_node = env[node.args[0].name]
                    env[node.args[0].name] = _insert_logger_after_node(
                        prev_node,
                        gm,
                        logger_cls,
                        '_ns_logger_',
                        node.name,
                        model_name,
                        ref_name,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0)
                elif type(node.args[0]) == \
                        torch.fx.immutable_collections.immutable_list:
                    # create N input loggers, one for each node
                    for arg_idx, arg in enumerate(node.args[0]):
                        prev_node = env[arg.name]
                        env[prev_node.name] = _insert_logger_after_node(
                            prev_node,
                            gm,
                            logger_cls,
                            '_ns_logger_',
                            node.name,
                            model_name,
                            ref_name,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx)
                else:
                    raise AssertionError(
                        f"type {type(node.args[0])} is not handled yet")

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name],
                    gm,
                    logger_cls,
                    '_ns_logger_',
                    node.name,
                    model_name,
                    ref_name,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm