def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config_dict: Optional[Dict[str, Any]]):
    """ Converts an observed standalone module to a quantized standalone module by calling
    the fx convert api, currently using the same `is_reference` flag as parent, but we may
    changing this behavior in the future (e.g. separating quantization and lowering for
    standalone module as well)

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model
      - model: original model
      - is_reference: a flag from parent provided by user to decide if we want to produce
        a reference model or a fbgemm/qnnpack model
      - backend_config_dict: backend configuration of the target backend of quantization
    """
    # TODO: remove is_reference flag
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_input_quantized_idxs \
        .tolist()  # type: ignore[operator]
    # remove the dequantize nodes for inputs
    # iterate over a snapshot of the args since replace_input_with mutates node.args
    for idx, arg in enumerate(list(node.args)):
        if idx not in sm_input_quantized_idxs:
            continue
        if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
            quantize_node = arg.args[0]  # type: ignore[union-attr]
            node.replace_input_with(arg, quantize_node)
            # only erase the dequantize node once nothing else consumes it
            if len(arg.users) == 0:  # type: ignore[union-attr]
                model.graph.erase_node(arg)
    # add dequantize node for output
    sm_output_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_output_quantized_idxs \
        .tolist()  # type: ignore[operator]
    if len(sm_output_quantized_idxs) > 0:
        # fixed: the two message literals previously concatenated without a
        # separating space ("...quantizedoutput...")
        assert sm_output_quantized_idxs[0] == 0, \
            "Currently only quantized output idxs = [0] is supported"
        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config to override backend_config_dict
    # for standalone module
    quantized_standalone_module = convert_fn(
        observed_standalone_module,
        backend_config_dict=backend_config_dict)
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config_dict: Dict[str, Any]):
    """ Converts an observed standalone module to a quantized (reference) standalone
    module via the internal `_convert_do_not_use` entry point.

    NOTE(review): this redefines `convert_standalone_module` from earlier in the file —
    only this later definition is live at import time; confirm the duplication is
    intentional (e.g. a migration in progress) or remove one of the two.
    NOTE(review): `is_reference` is accepted but never read — the convert call below
    always passes `is_reference=True`; confirm whether the flag should be forwarded.

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model
      - model: original model
      - is_reference: unused here (see review note above)
      - backend_config_dict: backend configuration of the target backend of quantization
    """
    convert = torch.ao.quantization._quantize_fx_do_not_use._convert_do_not_use  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_input_quantized_idxs \
        .tolist()  # type: ignore[operator]
    # remove the dequantize nodes for inputs
    args = list(node.args)
    for idx in range(len(args)):
        if idx in sm_input_quantized_idxs:
            arg = args[idx]
            if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
                quantize_node = arg.args[0]  # type: ignore[union-attr]
                node.replace_input_with(arg, quantize_node)
                if len(arg.users) == 0:  # type: ignore[union-attr]
                    model.graph.erase_node(arg)
    # add dequantize node for output
    sm_output_quantized_idxs = \
        observed_standalone_module \
        ._standalone_module_output_quantized_idxs \
        .tolist()  # type: ignore[operator]
    if len(sm_output_quantized_idxs) > 0:
        # fixed: the two message literals previously concatenated without a
        # separating space ("...quantizedoutput...")
        assert sm_output_quantized_idxs[0] == 0, \
            "Currently only quantized output idxs = [0] is supported"
        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config_dict to override backend_config_dict
    # for standalone module
    quantized_standalone_module = convert(
        observed_standalone_module,
        is_reference=True,
        backend_config_dict=backend_config_dict)
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
def maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
    """ If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
    we'll recursively remove the dequantize Node
    """
    is_dequantize_node = (
        isinstance(arg, Node)
        and arg.op == "call_method"
        and arg.target == "dequantize"
    )
    if is_dequantize_node:
        # swap this use over to the dequantize's input; we only replace the
        # specific use since the dequantize could be consumed by other nodes too
        node.replace_input_with(arg, arg.args[0])
        return
    # containers: recurse into each element
    if isinstance(arg, (list, tuple)):
        children = arg
    elif isinstance(arg, dict):
        children = tuple(arg.values())
    else:
        warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}")
        return
    for child in children:
        maybe_recursive_remove_dequantize(child, node, graph)
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[Callable, Callable],
        statically_quantized_custom_module_nodes: Set[Node]):
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`
    For static quantization, we'll also remove the previous `dequantize` node and
    attach the observer node for output to the module, the observer for the node
    will be converted to a dequantize node instead of quantize-dequantize pairs
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed standalone module
      - graph: The graph containing the node
      - modules: named_module of original model
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we'll add the custom module node
        if we find it is statically quantized, this will be used later when converting
        observers to quant/dequant node pairs, if the observed node is a statically
        quantized custom module nodes, we'll convert the observer to a dequantize node,
        this is to keep the interface the same as the default quantized module.
        TODO: maybe we want to redesign this part to align with reference model design
        as well, but there has been some discussions around the interface, so we can
        do it later.
    """
    observed_custom_module = modules[str(node.target)]
    # removed: dead local `maybe_obs = maybe_get_observer_for_node(node, modules)`
    # — it was never read; the observer is fetched below where it is actually used
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        # remove the previous dequant node
        prev_node = node.args[0]
        # expecting the input node for a custom module node to be a Node
        assert isinstance(prev_node, Node), \
            f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
        if prev_node.op == "call_method" and prev_node.target == "dequantize":
            # change the connection for custom module, we'll change the input
            # of custom module node to quantize node:
            # Before: quantize - dequantize - custom - module
            # After: quantize - custom - module
            #              \ - dequantize
            node.replace_input_with(prev_node, prev_node.args[0])

            # Remove the dequantize node if it doesn't have other users
            if len(prev_node.users) == 0:
                graph.erase_node(prev_node)

        # absorb the following observer into the module conversion
        activation_post_process = maybe_get_observer_for_node(node, modules)
        assert activation_post_process is not None
        observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)