def replace_observer_with_dequantize_node(node: Node, graph: Graph):
    """Replace an observer node that follows a custom module call.

    The observer's first argument (``node.args[0]``) is expected to be the
    call custom module node. All uses of the observer are rerouted to that
    node, the observer is erased from ``graph``, and ``insert_dequantize_node``
    is then called on the custom module node (presumably appending a
    dequantize after it -- confirm against that helper).

    Args:
        node: the observer node to remove.
        graph: the graph that owns ``node``.

    Raises:
        AssertionError: if the observer's first argument is not a ``Node``.
    """
    call_custom_module_node = node.args[0]
    assert isinstance(call_custom_module_node, Node), \
        f"Expecting the input of the observer to be the call custom module node, but got {call_custom_module_node}"
    # Reroute consumers first so the observer has no users and can be erased.
    node.replace_all_uses_with(call_custom_module_node)
    graph.erase_node(node)
    insert_dequantize_node(call_custom_module_node, graph)
def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config_dict: Optional[Dict[str, Any]]):
    """ Converts a observed standalone module to a quantized standalone module by calling
    the fx convert api, currently using the same `is_reference` flag as parent, but we may
    changing this behavior in the future (e.g. separating quantization and lowering for
    standalone module as well)

    Steps performed on the parent graph:
      * for every input index listed in the observed module's
        ``_standalone_module_input_quantized_idxs``, a ``dequantize`` input is bypassed
        (the node is fed the dequantize's input instead) and erased if it has no users left;
      * if ``_standalone_module_output_quantized_idxs`` is non-empty, a dequantize node is
        inserted after ``node`` via ``insert_dequantize_node``;
      * the observed submodule is converted and swapped in both the parent module and
        ``modules``.

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model
      - model: original model
      - is_reference: a flag from parent provided by user to decide if we want to
        produce a reference model or a fbgemm/qnnpack model
      - backend_config_dict: backend configuration of the target backend of quantization

    Raises:
      AssertionError: if the quantized output indices do not start with 0.
    """
    # TODO: remove is_reference flag
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = set(
        observed_standalone_module._standalone_module_input_quantized_idxs.tolist())  # type: ignore[operator]

    # remove the dequantize nodes for inputs; iterate over a snapshot because
    # replace_input_with mutates node.args
    for idx, arg in enumerate(list(node.args)):
        if idx not in sm_input_quantized_idxs:
            continue
        # Non-Node args (constants) have no dequantize to strip.
        if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize":
            quantize_node = arg.args[0]
            node.replace_input_with(arg, quantize_node)
            if len(arg.users) == 0:
                model.graph.erase_node(arg)

    # add dequantize node for output
    sm_output_quantized_idxs = \
        observed_standalone_module._standalone_module_output_quantized_idxs.tolist()  # type: ignore[operator]
    if len(sm_output_quantized_idxs) > 0:
        assert sm_output_quantized_idxs[0] == 0, \
            "Currently only quantized output idxs = [0] is supported"

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config to override backend_config_dict
    # for standalone module
    quantized_standalone_module = convert_fn(
        observed_standalone_module,
        backend_config_dict=backend_config_dict)
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
def replace_observer_with_quantize_dequantize_node(
        model: torch.nn.Module,
        graph: Graph,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        qconfig_map: Dict[str, QConfigAny]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    If every producer and consumer of the observer has a ``None`` qconfig
    (according to ``has_none_qconfig``), or ``get_quantize_node_info`` returns
    ``None`` for the observer module, the observer is just removed and its
    input is wired directly to its users.

    Args:
        model: module on which ``_scale_``/``_zero_point_`` qparams are
            registered via ``create_getattr_from_value``.
        graph: graph that owns ``node``.
        node: the observer ``call_module`` node to replace.
        modules: mapping from module name to module; ``node.target`` is looked up here.
        node_name_to_scope: forwarded to ``get_module_path_and_prefix``.
        qconfig_map: mapping from node name to qconfig.
    """
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = get_module_path_and_prefix(
        node, node_name_to_scope, qconfig_map)
    observer_module = modules[node.target]
    maybe_quantize_node_info = get_quantize_node_info(observer_module)
    # Skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        has_none_qconfig(n, qconfig_map)
        for n in list(node.args) + list(node.users.keys()))
    if skip_replacement or maybe_quantize_node_info is None:
        # didn't find corresponding quantize op and info for the observer_module
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the observer module call to quantize/dequantize node
    node_type, quantize_op, qparams = maybe_quantize_node_info
    # replace observer node with quant - dequant node
    with graph.inserting_before(node):
        input_node = node.args[0]
        inputs = [input_node]
        for key, value in qparams.items():
            # TODO: we can add the information of whether a value needs to
            # be registered as an attribute in qparams dict itself
            if key in ['_scale_', '_zero_point_']:
                # For scale and zero_point values we register them as buffers in the root module.
                # TODO: maybe need more complex attr name here
                qparam_node = create_getattr_from_value(
                    model, graph, module_path + prefix + key, value)
                inputs.append(qparam_node)
            else:
                # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                inputs.append(value)

        quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
        dequantized_node = graph.call_method("dequantize", args=(quantized_node, ))
        node.replace_all_uses_with(dequantized_node)
        graph.erase_node(node)
def maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
    node_name_to_target_dtype: Dict[str, Any],
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config_dict: Dict[str, Any],
    node_name_to_scope: Dict[str, Tuple[str, type]],
) -> None:
    """
    Possibly insert observers in front of every positional and keyword input of `node`.

    Each input (args first, then kwargs) is passed through
    ``maybe_insert_input_observer_for_arg_or_kwarg``; whatever it returns
    replaces the original input. `node` is modified in place.

    Example: if cur_node needs an observer after prev_node, the graph goes from

      prev_node -> cur_node

    to

      prev_node -> obs -> cur_node

    Nothing happens when `qconfig` is None (quantization disabled for `node`).
    """
    # Quantization disabled for this node: leave its inputs untouched.
    if qconfig is None:
        return
    assert qconfig is not None

    def _observe(value):
        # Identical helper call for args and kwargs; only the input differs.
        return maybe_insert_input_observer_for_arg_or_kwarg(
            node, value, qconfig, model, modules, graph,
            node_name_to_target_dtype, qhandler,
            prepare_custom_config_dict, node_name_to_scope)

    observed_args = tuple(_observe(value) for value in node.args)
    observed_kwargs = {name: _observe(value) for name, value in node.kwargs.items()}

    # write the (possibly) rewired inputs back onto the node, in place
    node.args = observed_args
    node.kwargs = observed_kwargs
def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the n'th input to that node, normalizing
    args and kwargs to the best of its ability.

    Inputs are numbered positional args first, then kwargs in insertion
    order. Normalization (``Node.normalized_arguments`` with
    ``normalize_to_only_use_kwargs=True``) is tried first; when it returns
    ``None`` or raises ``RuntimeError``, the node's raw ``args``/``kwargs``
    are used instead.

    Args:
        node: node whose input is requested.
        gm: module passed as the root to ``normalized_arguments``.
        idx: zero-based index counting args first, then kwargs.

    Returns:
        The ``idx``-th input (annotated as ``Node``, but any argument value
        may be returned, e.g. a constant kwarg).

    Raises:
        AssertionError: if ``idx`` is out of range.
    """
    try:
        norm_args_and_kwargs = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True)
    except RuntimeError:
        # this RuntimeError happens when node argument normalization
        # requires typehints to proceed, such as for torch.add where
        # either the first, second or both arguments could be tensors
        norm_args_and_kwargs = None

    if norm_args_and_kwargs is not None:
        args, kwargs = norm_args_and_kwargs
    else:
        args, kwargs = node.args, node.kwargs
    return _get_nth_of_args_and_kwargs(args, kwargs, idx)


def _get_nth_of_args_and_kwargs(
        args: Tuple[Any, ...], kwargs: Dict[str, Any], idx: int) -> Any:
    """Return the idx-th value counting ``args`` first, then ``kwargs`` values."""
    assert len(args) + len(kwargs) > idx
    if idx < len(args):
        return args[idx]
    # kwargs are numbered after the positional args, so shift the index down
    # (in Python 3.7+ dicts preserve insertion order)
    return list(kwargs.values())[idx - len(args)]
def replace_observer_with_quantize_dequantize_node(
        graph: Graph, node: Node, modules: Dict[str, torch.nn.Module]) -> None:
    """
    Rewrite an observer ``call_module`` node according to the observer's dtype.

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    * ``torch.float32`` observer: the observer is dropped and its input is
      wired straight to its users.
    * ``torch.quint8`` / ``torch.qint8`` / ``torch.float16`` observer: the
      observer is replaced by a quantize node followed by ``dequantize``;
      ``_scale_``/``_zero_point_`` qparams become getattr nodes on the root
      module (``modules[""]``), other qparams are inlined as literals.
    * any other dtype: the graph is left unchanged.
    """
    assert modules is not None
    assert isinstance(node.target, str)
    observer_module = modules[node.target]
    root_module = modules[""]
    observed_dtype = observer_module.dtype

    if observed_dtype == torch.float32:
        # remove the node for now
        # TODO: support dynamic quant
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    if observed_dtype not in (torch.quint8, torch.qint8, torch.float16):
        # unsupported dtype: leave the observer in place
        return

    node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
    with graph.inserting_before(node):
        quant_inputs = [node.args[0]]
        for qparam_name, qparam_value in qparams.items():
            if qparam_name not in ('_scale_', '_zero_point_'):
                # axis, dtype, ... are stored as literals in the graph
                quant_inputs.append(qparam_value)
                continue
            # scale / zero_point are registered on the root module
            # TODO: maybe need more complex attr name here
            quant_inputs.append(
                create_getattr_from_value(root_module, graph, qparam_name, qparam_value))

        quantized_node = graph.create_node(node_type, quantize_op, tuple(quant_inputs), {})
        dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
        node.replace_all_uses_with(dequantized_node)
        graph.erase_node(node)
def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config_dict: Dict[str, Any]):
    """
    Convert an observed standalone module in place and fix up its call site.

    * For every input index listed in the observed module's
      ``_standalone_module_input_quantized_idxs``, a ``dequantize`` input is
      bypassed (the node is fed the dequantize's input) and erased if unused.
    * If ``_standalone_module_output_quantized_idxs`` is non-empty, a
      dequantize node is inserted after ``node`` via ``insert_dequantize_node``.
    * The observed submodule is converted with ``_convert_do_not_use`` and
      swapped in both the parent module and ``modules``.

    NOTE(review): ``is_reference`` is accepted but not consulted here -- the
    conversion always uses ``is_reference=True``. Confirm this is intended.

    Raises:
        AssertionError: if the quantized output indices do not start with 0.
    """
    convert = torch.ao.quantization._quantize_fx_do_not_use._convert_do_not_use  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module : GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = set(
        observed_standalone_module._standalone_module_input_quantized_idxs.tolist())  # type: ignore[operator]

    # remove the dequantize nodes for inputs; iterate over a snapshot because
    # replace_input_with mutates node.args
    for idx, arg in enumerate(list(node.args)):
        if idx not in sm_input_quantized_idxs:
            continue
        # Non-Node args (constants) have no dequantize to strip.
        if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize":
            quantize_node = arg.args[0]
            node.replace_input_with(arg, quantize_node)
            if len(arg.users) == 0:
                model.graph.erase_node(arg)

    # add dequantize node for output
    sm_output_quantized_idxs = \
        observed_standalone_module._standalone_module_output_quantized_idxs.tolist()  # type: ignore[operator]
    if len(sm_output_quantized_idxs) > 0:
        assert sm_output_quantized_idxs[0] == 0, \
            "Currently only quantized output idxs = [0] is supported"

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config_dict to override backend_config_dict
    # for standalone module
    quantized_standalone_module = convert(
        observed_standalone_module,
        is_reference=True,
        backend_config_dict=backend_config_dict)
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
def maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
    """ If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
    we'll recursively remove the dequantize Node from the inputs of `node`.

    Each dequantize found is replaced, within `node` only, by its own input
    (``arg.args[0]``); the dequantize node itself is not erased, since it may
    still have other users. Nodes that are not dequantize are left untouched.
    Any other non-container value triggers a warning. `graph` is only
    forwarded through the recursion.
    """
    if isinstance(arg, Node):
        if arg.op == "call_method" and arg.target == "dequantize":
            quantize_node = arg.args[0]
            # we only replace the specific use since dequantize could be used by other nodes
            # as well
            node.replace_input_with(arg, quantize_node)
        # Any other Node is a regular input with nothing to remove; it is a
        # supported type, so no warning is emitted for it.
    elif isinstance(arg, (list, tuple)):
        for arg_element in arg:
            maybe_recursive_remove_dequantize(arg_element, node, graph)
    elif isinstance(arg, dict):
        for arg_element in arg.values():
            maybe_recursive_remove_dequantize(arg_element, node, graph)
    else:
        warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}")
def maybe_insert_input_equalization_observers_for_node(
    node: Node,
    equalization_qconfig: Any,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
    node_name_to_target_dtype: Dict[str, Any],
    is_branch: bool,
    node_name_to_scope: Dict[str, Tuple[str, type]],
) -> None:
    """
    Insert equalization observers in front of the positional inputs of `node`.

    Does nothing when `equalization_qconfig` is None or
    ``node_supports_equalization`` rejects `node`. When `is_branch` is True
    a warning is emitted and nothing is inserted.

    For every positional arg that is a Node and not the bias, an observer is
    built from ``equalization_qconfig.weight`` (for the weight arg) or
    ``equalization_qconfig.input_activation`` (otherwise), inserted with
    ``insert_observer``, and recorded in `node_name_to_target_dtype` with the
    same target dtype as the arg it observes. Only ``node.args`` is rewritten;
    ``node.kwargs`` is left untouched.
    """
    if equalization_qconfig is None or not node_supports_equalization(node, modules):
        return

    if is_branch:
        warnings.warn(
            f"Cannot equalize {node} because it is part of a branch."
        )
        return

    updated_args = []
    for original_arg in node.args:
        # constants and the bias are passed through without an observer
        if not isinstance(original_arg, Node) or node_arg_is_bias(node, original_arg):
            updated_args.append(original_arg)
            continue

        if node_arg_is_weight(node, original_arg):
            observer_ctor = equalization_qconfig.weight
        else:
            observer_ctor = equalization_qconfig.input_activation
        eq_observer_node = insert_observer(
            original_arg, node, observer_ctor(), model, modules, graph,
            node_name_to_scope, "input")

        # the observer inherits the target dtype of what it observes, so the
        # next node can read it
        node_name_to_target_dtype[eq_observer_node.name] = \
            node_name_to_target_dtype[original_arg.name]
        updated_args.append(eq_observer_node)

    # only positional args are rewired, in place
    node.args = tuple(updated_args)
def maybe_insert_observers_before_graph_output(
    graph_output_node: Node,
    output_quantized_idxs: List[int],
    node_name_to_target_dtype: Dict[str, torch.dtype],
    qconfig_map: Dict[str, QConfigAny],
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> None:
    """
    If the output needs to be quantized and there are any nodes
    in the output which are not already observed, inserts observers
    for those nodes.

    Only ``output_quantized_idxs == [0]`` triggers work; ``[]`` returns
    immediately and any other value fails the assertion. Nodes whose entry in
    `node_name_to_target_dtype` differs from ``torch.quint8`` get an observer
    built from ``qconfig_map[name].activation``. Non-Node leaves (constants,
    None, ...) are kept unchanged.
    """

    # TODO(future PR): update the output_quantized_idxs API to match
    # arbitrary data structures. There is always a single output, and
    # that output can have arbitrary nesting of values. List[int] is
    # not the right data type for this.
    assert output_quantized_idxs == [0] or output_quantized_idxs == [], \
        'unrecognized format of output_quantized_idxs'

    # Currently dequants are inserted in the convert step. So, we only
    # have to do anything if the output is hardcoded to be quantized
    if output_quantized_idxs == []:
        return
    # TODO(future PR): support more dtypes in model outputs, if necessary
    output_target_dtype = torch.quint8

    def _recursive_maybe_replace_node_with_obs(
        maybe_node: Argument,
        target_dtype: torch.dtype,
        node_name_to_target_dtype: Dict[str, torch.dtype],
        qconfig_map: Dict[str, QConfigAny],
        model: torch.nn.Module,
        modules: Dict[str, torch.nn.Module],
        graph: Graph,
    ) -> Argument:
        """
        Navigate an arbitrary data structure of lists, tuples, dicts.
        For each container type, recurse on all inputs. Once any Node
        is found, insert an observer if needed and do not recurse further.

        For example, given a structure of

          {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}

        we recurse down to bar1 and bar3, observe them if necessary,
        and if we inserted an observer then replace the original node
        with its observer.

        Returns the data structure with all nodes needing observation being
        replaced by their observers. Non-Node leaves are returned as-is.
        """
        if isinstance(maybe_node, Node):
            # check dtype of this node
            this_node_dtype = node_name_to_target_dtype[maybe_node.name]
            if this_node_dtype == target_dtype:
                return maybe_node
            # insert observer
            qconfig = qconfig_map.get(maybe_node.name)
            # TODO(future PR): see if we need to allow specifying qconfig
            #   on output nodes, to remove the restriction below.
            assert qconfig is not None, \
                'Quantizing the output node without a qconfig is not supported'
            observer_mod = qconfig.activation()
            observer_node = insert_observer(
                maybe_node, observer_mod, model, modules, graph)
            return observer_node
        elif isinstance(maybe_node, (list, tuple)):
            results = [
                _recursive_maybe_replace_node_with_obs(
                    inner_node, target_dtype, node_name_to_target_dtype,
                    qconfig_map, model, modules, graph)
                for inner_node in maybe_node
            ]
            return results if isinstance(maybe_node, list) else tuple(results)
        elif isinstance(maybe_node, dict):
            return {
                k: _recursive_maybe_replace_node_with_obs(
                    inner_v, target_dtype, node_name_to_target_dtype,
                    qconfig_map, model, modules, graph)
                for k, inner_v in maybe_node.items()
            }
        else:
            # constants / None / other leaves need no observer; previously
            # this branch referenced an unbound local and crashed
            return maybe_node

    new_args = [
        _recursive_maybe_replace_node_with_obs(
            old_arg, output_target_dtype, node_name_to_target_dtype,
            qconfig_map, model, modules, graph)
        for old_arg in graph_output_node.args
    ]
    # Node.args is a tuple; keep it that way
    graph_output_node.args = tuple(new_args)  # type: ignore[assignment]
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[Callable, Callable],
        statically_quantized_custom_module_nodes: Set[Node]):
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`
    For static quantization, we'll also remove the previous `dequantize` node and
    attach the observer node for output to the module, the observer for the node
    will be converted to a dequantize node instead of quantize-dequantize pairs
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed custom module
      - graph: The graph containing the node
      - modules: named_module of original model
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we'll add the custom module node
        if we find it is statically quantized, this will be used later when converting
        observers to quant/dequant node pairs, if the observed node is a statically
        quantized custom module nodes, we'll convert the observer to a dequantize node,
        this is to keep the interface the same as the default quantized module.
        TODO: maybe we want to redesign this part to align with reference model design
        as well, but there has been some discussions around the interface, so we can do
        it later.

    Raises:
      AssertionError: for a statically quantized module whose first argument is not a
        Node, or when no output observer is found for it.
    """
    observed_custom_module = modules[str(node.target)]
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        # remove the previous dequant node
        prev_node = node.args[0]
        # expecting the input node for a custom module node to be a Node
        assert isinstance(prev_node, Node), \
            f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
        if prev_node.op == "call_method" and prev_node.target == "dequantize":
            # change the connection for custom module, we'll change the input
            # of custom module node to quantize node:
            # Before: quantize - dequantize - custom - module
            # After: quantize - custom - module
            #              \ - dequantize
            node.replace_input_with(prev_node, prev_node.args[0])

            # Remove the dequantize node if it doesn't have other users
            if len(prev_node.users) == 0:
                graph.erase_node(prev_node)

        # absorb the following observer into the module conversion
        activation_post_process = maybe_get_observer_for_node(node, modules)
        assert activation_post_process is not None
        observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, (input2)?, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it equals to None, we assume that the op
    has a single non-param input.  If it is specified, we assume that the op
    has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                         / / /
    (input_2)?----------/ / /
                         / /
    weight -> weight_obs  /
                         /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                     / / /
    (input_node_c_2)? / /
                     / /
    weight_copy ----/ /
                     /
    bias_copy ------/

    Raises:
        AssertionError: if a non-input argument has a type that cannot be
            copied, or a list/tuple argument contains a Node.
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    norm_args_kwargs = node_a.normalized_arguments(
        gm_a, normalize_to_only_use_kwargs=True)
    if norm_args_kwargs is not None:
        norm_args, norm_kwargs = norm_args_kwargs
    else:
        norm_args, norm_kwargs = node_a.args, node_a.kwargs

    new_args = []
    new_kwargs = {}

    def _copy_arg(arg):
        # copy the other inputs from the other graph
        if isinstance(arg, Node):
            arg = return_first_non_observer_node(arg, gm_a)
            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
            return arg
        elif isinstance(arg, (int, float, torch.dtype)):
            return arg
        elif isinstance(arg, (list, tuple)):
            # inspect the argument itself -- previously this read the
            # enclosing kwargs-loop variable, which is unbound (NameError)
            # when called for positional args
            for el in arg:
                assert not isinstance(el, Node), \
                    "handling of Node inside list is not implemented"
            return arg
        else:
            raise AssertionError(
                f"handling for kwarg of type {type(arg)} is not implemented"
            )

    # positional inputs: slot 0 (and slot 1 when input_node_c_2 is given)
    # are stitched to the inputs from graph C; the rest are copied
    for cur_idx, norm_arg in enumerate(norm_args):
        if cur_idx == 0:
            new_arg = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_arg = input_node_c_2
        else:
            new_arg = _copy_arg(norm_arg)
        new_args.append(new_arg)

    # kwargs continue the same numbering after the positional args
    for cur_idx, (kwarg_name, kwarg_val) in enumerate(
            norm_kwargs.items(), start=len(norm_args)):
        # stitch the inputs from base graph
        if cur_idx == 0:
            new_kwargs[kwarg_name] = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_kwargs[kwarg_name] = input_node_c_2
        else:
            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)

    new_args = tuple(new_args)  # type: ignore[assignment]

    node_a_shadows_c_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == 'call_module':
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
    else:
        assert node_a.op in ('call_function', 'call_method')
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c