def replace_observer_with_dequantize_node(node: Node, graph: Graph):
    """ Swap the observer that follows a custom module call for a dequantize node

    The observer's first argument is taken to be the custom module call node.
    Every user of the observer is redirected to that call node, the observer
    is erased from the graph, and `insert_dequantize_node` is then invoked on
    the call node.

    Args:
      - node: The observer node to remove; `node.args[0]` must be a Node
      - graph: The graph that owns `node`

    Raises:
      AssertionError: if `node.args[0]` is not a Node
    """
    call_custom_module_node = node.args[0]
    assert isinstance(call_custom_module_node, Node), (
        f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
    )

    # Detach the observer first: its users must point at the module call
    # before the observer itself can be erased.
    node.replace_all_uses_with(call_custom_module_node)
    graph.erase_node(node)

    # The helper is defined outside this block; presumably it places a
    # `dequantize` call after the custom module call node.
    insert_dequantize_node(call_custom_module_node, graph)
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[Callable, Callable],
        statically_quantized_custom_module_nodes: Set[Node]):
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`

    For static quantization, we'll also remove the previous `dequantize` node and
    attach the observer node for output to the module, the observer for the node
    will be converted to a dequantize node instead of quantize-dequantize pairs
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed custom module
      - graph: The graph containing the node
      - modules: named_module of original model
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we'll add the custom module node
        if we find it is statically quantized, this will be used later when converting
        observers to quant/dequant node pairs, if the observed node is a statically
        quantized custom module nodes, we'll convert the observer to a dequantize node,
        this is to keep the interface the same as the default quantized module.

    Raises:
      AssertionError: for a statically quantized module, if `node.args[0]` is
        not a Node, if the preceding `dequantize` node has more than one user,
        or if no observer is found for `node`.

    TODO: maybe we want to redesign this part to align with reference model design
    as well, but there has been some discussions around the interface, so we can do
    it later.
    """
    observed_custom_module = modules[str(node.target)]
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        # remove the previous dequant node
        prev_node = node.args[0]
        # expecting the input node for a custom module node to be a Node
        assert isinstance(prev_node, Node), \
            f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
        if prev_node.op == "call_method" and prev_node.target == "dequantize":
            # The dequantize node is bypassed and erased below, which is only
            # valid when the custom module is its sole consumer.
            assert len(prev_node.users) == 1, \
                "dequantize node before custom module is used " \
                "multiple times, this is currently not supported yet, but it can be " \
                "supported by duplicating the dequantize nodes in these cases"
            prev_node.replace_all_uses_with(prev_node.args[0])
            graph.erase_node(prev_node)

        # absorb the following observer into the module conversion
        # (this is the only observer lookup needed; an earlier lookup whose
        # result was never read has been dropped)
        activation_post_process = maybe_get_observer_for_node(node, modules)
        assert activation_post_process is not None
        observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
def replace_observer_with_quantize_dequantize_node(
        model: torch.nn.Module,
        graph: Graph,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        qconfig_map: Dict[str, QConfigAny]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    The observer is simply removed (its users are rewired to its input) when
    either no quantize op info is available for the observer module, or every
    argument and user of the observer node has a None qconfig.

    Args:
      - model: module passed to `create_getattr_from_value` when registering
        the scale / zero_point values
      - graph: The graph containing the node
      - node: The observer call node to replace; `node.target` must be a str
        key into `modules`
      - modules: mapping from module name to module, used to look up the
        observer module
      - node_name_to_scope: forwarded to `get_module_path_and_prefix`
      - qconfig_map: forwarded to `get_module_path_and_prefix` and used to
        check the qconfig of neighboring nodes

    NOTE(review): another function with this exact name but a different
    signature also appears later in the visible source; if both live in the
    same module the later definition shadows this one - confirm intent.
    """
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = get_module_path_and_prefix(
        node, node_name_to_scope, qconfig_map)
    observer_module = modules[node.target]
    maybe_quantize_node_info = get_quantize_node_info(observer_module)

    # Skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        has_none_qconfig(n, qconfig_map)
        for n in (*node.args, *node.users)
    )

    if skip_replacement or maybe_quantize_node_info is None:
        # didn't find corresponding quantize op and info for the observer_module
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the observer module call to quantize/dequantize node
    node_type, quantize_op, qparams = maybe_quantize_node_info
    # replace observer node with quant - dequant node
    with graph.inserting_before(node):
        input_node = node.args[0]
        inputs = [input_node]
        for key, value in qparams.items():
            # TODO: we can add the information of whether a value needs to
            # be registered as an attribute in qparams dict itself
            if key in ['_scale_', '_zero_point_']:
                # For scale and zero_point values we register them as buffers in the root module.
                # TODO: maybe need more complex attr name here
                qparam_node = create_getattr_from_value(
                    model, graph, module_path + prefix + key, value)
                inputs.append(qparam_node)
            else:
                # for qparams that are not scale/zero_point (like axis, dtype) we
                # store them as literals in the graph.
                inputs.append(value)

        quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
        dequantized_node = graph.call_method("dequantize", args=(quantized_node, ))
        node.replace_all_uses_with(dequantized_node)
        graph.erase_node(node)
def replace_observer_with_quantize_dequantize_node(
        graph: Graph,
        node: Node,
        modules: Dict[str, torch.nn.Module]) -> None:
    """ Turn an activation_post_process (observer) call node into a
    quantize -> dequantize pair

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    Behavior by observer dtype:
      - torch.float32: the observer is removed and its users are rewired to
        its input
      - torch.quint8 / torch.qint8 / torch.float16: the observer is replaced
        by a quantize node followed by a `dequantize` method call
      - anything else: the graph is left untouched

    Args:
      - graph: The graph containing the node
      - node: The observer call node; `node.target` must be a str key into
        `modules`
      - modules: mapping from module name to module; the entry under the
        empty-string key is used as the root module

    NOTE(review): a function with this same name but a different signature
    also appears earlier in the visible source; if both live in one module
    this later definition is the one that takes effect - confirm intent.
    """
    assert modules is not None
    assert isinstance(node.target, str)
    observer_module = modules[node.target]
    # Looked up unconditionally, before any dtype dispatch.
    root_module = modules[""]
    observed_dtype = observer_module.dtype

    if observed_dtype == torch.float32:
        # remove the node for now
        # TODO: support dynamic quant
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    if observed_dtype not in [torch.quint8, torch.qint8, torch.float16]:
        # No handling for this dtype: the observer stays in the graph.
        return

    node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
    # replace observer node with quant - dequant node
    with graph.inserting_before(node):
        quantize_args = [node.args[0]]
        for qparam_name, qparam_value in qparams.items():
            if qparam_name in ['_scale_', '_zero_point_']:
                # scale and zero_point are registered on the root module and
                # referenced from the graph through a getattr node.
                # TODO: maybe need more complex attr name here
                quantize_args.append(
                    create_getattr_from_value(
                        root_module, graph, qparam_name, qparam_value))
            else:
                # other qparams (like axis, dtype) are stored as literals
                # in the graph.
                quantize_args.append(qparam_value)

        quantized_node = graph.create_node(
            node_type, quantize_op, tuple(quantize_args), {})
        dequantized_node = graph.call_method(
            "dequantize", args=(quantized_node, ))
        node.replace_all_uses_with(dequantized_node)
        graph.erase_node(node)