def fuse(self, model, inplace=False):
    """Build a fused version of ``model``'s graph and wrap it in a GraphModule.

    Args:
        model: module exposing a ``graph`` attribute; it is deep-copied first
            unless ``inplace`` is true.
        inplace: when true, ``model`` itself is used as the root of the result.

    Returns:
        A ``GraphModule`` built from the root module and the fused graph.
    """
    root = model if inplace else copy.deepcopy(model)
    source_graph = root.graph
    self.modules = dict(root.named_modules())

    # Find every fusion candidate before any node is emitted.
    matches = self._find_matches(
        root, source_graph, get_default_fusion_patterns())

    self.fused_graph = Graph()
    emitted = {}

    def load_arg(a):
        return map_arg(a, lambda n: emitted[n.name])

    for current in source_graph.nodes:
        anchor, handler = matches.get(current.name, (None, None))
        if anchor is None:
            # Not part of any pattern: carried over unchanged.
            emitted[current.name] = self.fused_graph.node_copy(
                current, load_arg)
        elif anchor is current:
            # Root of a matched pattern: the handler emits the fused node.
            emitted[current.name] = handler.fuse(self, load_arg)
        # Any other matched node is not the root and is dropped here.

    return GraphModule(root, self.fused_graph)
def insert_observer(
    node: Node,
    observed_op: Node,
    observer: ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    input_or_output: str,
) -> Node:
    """
    Registers `observer` on `model` under a fresh attribute name and adds a
    `call_module` node, placed right after `node`, that calls the observer on
    the output of `node`. Returns the new observer node.
    """
    # Move the observer to the model's device, when the model has one.
    device = assert_and_get_unique_device(model)
    if device:
        observer.to(device)

    # NOTE: The attribute name is based on the FQN of the module/op being
    # observed, looked up through node_name_to_scope. Please don't change this
    # behavior: it might impact how observer stats are transferred from the
    # train model to the inference model for some models.
    scope_name, _ = node_name_to_scope[observed_op.name]
    if scope_name == '':
        scope_name = node.name

    if is_equalization_observer(observer):
        name_prefix = node.name + '_equalization_process_'
    else:
        name_prefix = scope_name + '_' + input_or_output + '_activation_post_process_'

    observer_name = get_new_attr_name_with_prefix(name_prefix)(model)
    setattr(model, observer_name, observer)
    modules[observer_name] = observer

    with graph.inserting_after(node):
        observer_node = graph.create_node(
            'call_module', observer_name, (node,), {})
    return observer_node
def remove_qconfig_observer_fx(model):
    """Return a GraphModule equal to ``model`` minus its observer calls.

    Every ``call_module`` node whose target is an activation post process is
    skipped, and its consumers are wired to the observer's first argument.
    ``_remove_qconfig`` is applied to ``model`` before the new GraphModule is
    built.
    """
    stripped_graph = Graph()
    remapped = {}  # type: Dict[str, Any]
    named = dict(model.named_modules())

    def load_arg(a):
        return map_arg(a, lambda n: remapped[n.name])

    for current in model.graph.nodes:
        if current.op == "output":
            stripped_graph.output(map_arg(current.args[0], load_arg))
            continue

        is_observer_call = (
            current.op == "call_module"
            and is_activation_post_process(named[current.target])
        )
        if is_observer_call:
            # Drop the observer: alias it to whatever it was observing.
            remapped[current.name] = remapped[current.args[0].name]
        else:
            remapped[current.name] = stripped_graph.node_copy(
                current, load_arg)

    _remove_qconfig(model)
    return GraphModule(model, stripped_graph)
def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    """
    Stores `scale` and `zero_point` as new attributes on `gm_b`, adds a
    `get_attr` node for each to `graph_c`, and returns a new
    `torch.quantize_per_tensor` call node (named `dtype_cast_name`) that
    quantizes `prev_node_c` to `torch.quint8` with those parameters.
    """
    def _attach_qparam(suffix, value):
        # The attribute and the node reading it share one fresh name.
        attr_name = get_new_attr_name_with_prefix(node_a.name + suffix)(gm_b)
        setattr(gm_b, attr_name, value)
        return graph_c.create_node('get_attr', attr_name, (), {}, attr_name)

    scale_node = _attach_qparam('_input_scale_', scale)
    zero_point_node = _attach_qparam('_input_zero_point_', zero_point)

    quantize_args = (prev_node_c, scale_node, zero_point_node, torch.quint8)
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor, quantize_args, {},
        dtype_cast_name)
def insert_observer(
    node: Node,
    observer: torch.quantization.ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Registers `observer` on `model` under a fresh attribute name and adds a
    `call_module` node, placed right after `node`, that calls the observer on
    the output of `node`. Returns the new observer node.
    """
    # Move the observer to the model's device, when the model has one.
    device = assert_and_get_unique_device(model)
    if device:
        observer.to(device)

    # Equalization observers get their own naming scheme.
    suffix = (
        '_equalization_process_'
        if is_equalization_observer(observer)
        else '_activation_post_process_'
    )
    observer_name = get_new_attr_name_with_prefix(node.name + suffix)(model)
    setattr(model, observer_name, observer)
    modules[observer_name] = observer

    with graph.inserting_after(node):
        observer_node = graph.create_node(
            'call_module', observer_name, (node,), {})
    return observer_node
def fuse(self, model, fuse_custom_config_dict=None):
    """Build a fused version of ``model``'s graph and wrap it in a GraphModule.

    Args:
        model: module exposing a ``graph`` attribute; used as the root of the
            result.
        fuse_custom_config_dict: optional configuration. Extra fusion patterns
            are read from its ``"additional_fusion_pattern"`` entry. The
            ``"additional_quant_pattern"`` entry that this method used to read
            is still honoured as a fallback, so existing callers keep working;
            on conflict the ``"additional_fusion_pattern"`` entry wins.

    Returns:
        A ``GraphModule`` built from ``model`` and the fused graph.
    """
    if fuse_custom_config_dict is None:
        fuse_custom_config_dict = {}

    input_root = model
    input_graph = model.graph
    self.modules = dict(input_root.named_modules())

    # Work on a copy so the default pattern table is never mutated.
    fusion_patterns = get_default_fusion_patterns().copy()
    # Legacy key first, so the correctly named key overrides it.
    for config_key in ("additional_quant_pattern", "additional_fusion_pattern"):
        for pattern, handler_cls in fuse_custom_config_dict.get(config_key, {}).items():
            fusion_patterns[pattern] = handler_cls

    # find fusion
    fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
    self.fused_graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in input_graph.nodes:
        root_node, obj = fusion_pairs.get(node.name, (None, None))
        if root_node is node:
            # Root of a matched pattern: the handler emits the fused node.
            env[node.name] = obj.fuse(self, load_arg)
        elif root_node is None:
            # Not part of any pattern: carried over unchanged.
            env[node.name] = self.fused_graph.node_copy(node, load_arg)
        # node matched in patterns and is not root is removed here

    return GraphModule(input_root, self.fused_graph)
def fuse(self, model: GraphModule,
         fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
    """Build a fused version of ``model``'s graph and wrap it in a GraphModule.

    Args:
        model: graph module to fuse; used as the root of the result.
        fuse_custom_config_dict: optional configuration; its
            ``"additional_fusion_pattern"`` entry is combined with the default
            fusion patterns.

    Returns:
        A ``GraphModule`` built from ``model`` and the fused graph.
    """
    config = {} if fuse_custom_config_dict is None else fuse_custom_config_dict

    source_graph = model.graph
    self.modules = dict(model.named_modules())

    extra_patterns = config.get("additional_fusion_pattern", {})
    all_patterns = get_combined_dict(
        get_default_fusion_patterns(), extra_patterns)

    # Find every fusion candidate before any node is emitted.
    matches = self._find_matches(model, source_graph, all_patterns)

    self.fused_graph = Graph()
    emitted: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda n: emitted[n.name])

    for current in source_graph.nodes:
        anchor, handler = matches.get(current.name, (None, None))
        if anchor is None:
            # Not part of any pattern: carried over unchanged.
            emitted[current.name] = self.fused_graph.node_copy(
                current, load_arg)
        elif anchor is current:
            # Root of a matched pattern: the handler emits the fused node.
            assert handler is not None
            emitted[current.name] = handler.fuse(self, load_arg)
        # Any other matched node is not the root and is dropped here.

    return GraphModule(model, self.fused_graph)
def replace_observer_with_dequantize_node(node: Node, graph: Graph):
    """
    Removes the observer `node` from `graph`, pointing its users at the
    observer's first argument, then inserts a dequantize node after that
    argument.
    """
    producer = node.args[0]
    assert isinstance(producer, Node), \
        f"Expecting the for call custom module node to be a Node, but got {producer}"

    # Bypass the observer first, then drop it from the graph.
    node.replace_all_uses_with(producer)
    graph.erase_node(node)

    insert_dequantize_node(producer, graph)
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[Callable, Callable],
        statically_quantized_custom_module_nodes: Set[Node]):
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`.

    For static quantization, we also remove the previous `dequantize` node and
    attach the observer for the output to the module. The observer node itself
    is converted later to a dequantize node instead of a quantize-dequantize
    pair. In the end we have a quantized custom module with the same interface
    as a default quantized module in the nn.quantized namespace, i.e. quantized
    input and quantized output.

    Args:
      - node: The call_module node of the observed custom module
      - graph: The graph containing the node
      - modules: named_module of original model
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: the custom module node is added
        to this set when it is statically quantized. It is used later when
        converting observers to quant/dequant node pairs: if the observed node is
        a statically quantized custom module node, the observer is converted to
        a dequantize node, to keep the interface the same as the default
        quantized module.
        TODO: maybe we want to redesign this part to align with reference model
        design as well, but there has been some discussions around the
        interface, so we can do it later.
    """
    observed_custom_module = modules[str(node.target)]
    qconfig = observed_custom_module.qconfig

    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)

        # remove the previous dequant node
        prev_node = node.args[0]
        # expecting the input node for a custom module node to be a Node
        assert isinstance(prev_node, Node), \
            f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
        if prev_node.op == "call_method" and prev_node.target == "dequantize":
            assert len(prev_node.users) == 1, (
                "dequantize node before custom module is used "
                "multiple times, this is currently not supported yet, but it can be "
                "supported by duplicating the dequantize nodes in these cases")
            prev_node.replace_all_uses_with(prev_node.args[0])
            graph.erase_node(prev_node)

        # absorb the following observer into the module conversion
        # (this is the only place the observer is needed, so it is looked up
        # here rather than unconditionally at the top of the function)
        activation_post_process = maybe_get_observer_for_node(node, modules)
        assert activation_post_process is not None
        observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
def fold_weight(
        quantized: QuantizedGraphModule,
        node_name_to_scope: Dict[str, Tuple[str, type]]) -> QuantizedGraphModule:
    """
    Trace back from each weight prepack node until we hit getattr, rebuild a
    graph module from the traced nodes and run it to pack the weight, then
    replace the original chain of ops with the packed weight.
    """
    packed_weights = {}
    # map from folded node name to the prepack node it is folded into
    folded_nodes = {}

    # Pass 1: run every prepack chain once and remember its result.
    for node in quantized.graph.nodes:
        if node.op != 'call_function' or node.target not in WEIGHT_PREPACK_OPS:
            continue
        nodes_to_fold = collect_producer_nodes(node)
        if nodes_to_fold is None:
            continue
        for folded in nodes_to_fold:
            folded_nodes[folded.name] = node
        prepacking_module = graph_module_from_producer_nodes(
            quantized, nodes_to_fold)
        packed_weights[node.name] = prepacking_module()

    # Pass 2: rebuild the graph without the folded chains.
    folded_graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda n: env[n.name])

    quantized_root = quantized
    for node in quantized.graph.nodes:
        prepack_node = folded_nodes.get(node.name, None)
        if prepack_node is None:
            # Unrelated to weight packing: copied as is.
            env[node.name] = folded_graph.node_copy(node, load_arg)
            continue
        if prepack_node is not node:
            # Interior of a folded chain: dropped.
            continue

        # The prepack node itself becomes a getattr of the packed weight,
        # stored on the root under a name derived from its user's scope.
        packed_weight = packed_weights[node.name]
        op_node = list(prepack_node.users)[0]
        module_path, _ = node_name_to_scope[op_node.name]
        packed_weight_name = get_new_attr_name_with_prefix(
            module_path + '_packed_weight_')(quantized_root)
        setattr(quantized_root, packed_weight_name, packed_weight)
        env[node.name] = folded_graph.create_node(
            'get_attr', packed_weight_name, (), {})

    return QuantizedGraphModule(
        quantized_root, folded_graph, quantized_root.preserved_attr_names)
def insert_dequantize_node(node: Node, graph: Graph):
    """ Adds a `dequantize` method call right after `node` in `graph` and
    reroutes every other user of `node` to read from it
    """
    with graph.inserting_after(node):
        dequantized = graph.call_method("dequantize", (node, ))

    # Snapshot the users: rewiring them changes node.users while we iterate.
    for consumer in list(node.users):
        if consumer is dequantized:
            continue
        consumer.replace_input_with(node, dequantized)
def __init__(self):
    """Start with an empty graph and reset all name-generation state."""
    self.graph = Graph()

    # Private attributes for generating names: the three counters start at
    # zero, and the ASCII letters are kept as tuples.
    self._lowercase = tuple(string.ascii_lowercase)
    self._uppercase = tuple(string.ascii_uppercase)
    self._tensor_name_counter = 0
    self._dim_name_counter = 0
    self._shape_name_counter = 0
def legalize_graph(gm: GraphModule):
    """
    Replace the graph of the given GraphModule with one that contains the same
    nodes as the original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the
    topologically sorted order of its input GraphModule, so that this order is
    restored before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Raises:
        RuntimeError: if some nodes can never be placed because their inputs
            form a cycle or are not part of ``gm.graph``. Without this check
            the sorting loop would never terminate.
    """
    # Adjacency list of node dependencies. It doubles as the set of nodes that
    # still have to be inserted into the new, topologically sorted graph.
    pending = {node: node.all_input_nodes.copy() for node in gm.graph.nodes}

    # The new graph that will hold all nodes in topologically sorted order.
    new_graph = Graph()
    value_remap: Dict[Node, Node] = {}

    def remap(n):
        return value_remap[n]

    # Copy over all nodes with no dependencies first.
    for node, deps in pending.items():
        if not deps:
            value_remap[node] = new_graph.node_copy(node, remap)
    for copied_node in value_remap:
        del pending[copied_node]

    # Then repeatedly copy every node whose dependencies are all in place.
    while pending:
        copied_this_round = []
        for node, deps in pending.items():
            if all(dep in value_remap for dep in deps):
                value_remap[node] = new_graph.node_copy(node, remap)
                copied_this_round.append(node)

        if not copied_this_round:
            # No progress is possible: bail out instead of spinning forever.
            stuck = ", ".join(node.name for node in pending)
            raise RuntimeError(
                "Unable to topologically sort the graph; these nodes have "
                f"cyclic or unresolvable inputs: {stuck}")

        for copied_node in copied_this_round:
            del pending[copied_node]

    # Replace the old graph with the new, topologically sorted one.
    gm.graph = new_graph
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Returns the new cast node, or a list with one cast node per element when
    `prev_node_c` is a list.

    Raises:
        AssertionError: if the cast between the two input types is not
            implemented (only int8 -> fp32 is), or if `prev_node_c` is
            neither a Node nor a list.
    """
    node_input_type_a = get_node_input_type(node_a, gm_a)
    node_input_type_c = get_node_input_type(node_c, gm_b)

    needs_dequant = (
        node_input_type_a == NodeInputType.FP32
        and node_input_type_c == NodeInputType.INT8
    )
    if not needs_dequant:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} to {node_input_type_a} needs to be implemented")
    dtype_cast_op = torch.dequantize

    def _cast(source_node):
        # Every cast node gets its own fresh name.
        cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        return graph_c.create_node(
            'call_function', dtype_cast_op, (source_node, ), {}, cast_name)

    if isinstance(prev_node_c, Node):
        return _cast(prev_node_c)
    if isinstance(prev_node_c, list):
        return [_cast(inner) for inner in prev_node_c]
    # The original message had a stray "f" in front of the type.
    raise AssertionError(f"type {type(prev_node_c)} is not handled")
class Fuser:
    """Rewrites a traced graph so that matched patterns become fused nodes."""

    def fuse(self, model, inplace=False):
        """Build a fused version of ``model``'s graph.

        Args:
            model: module exposing a ``graph`` attribute; deep-copied first
                unless ``inplace`` is true.
            inplace: when true, ``model`` itself is the root of the result.

        Returns:
            A ``GraphModule`` built from the root module and the fused graph.
        """
        root = model if inplace else copy.deepcopy(model)
        source_graph = root.graph
        self.modules = dict(root.named_modules())

        # Find every fusion candidate before any node is emitted.
        matches = self._find_matches(root, source_graph, get_fusion_patterns())

        self.fused_graph = Graph()
        emitted = {}

        def load_arg(a):
            return map_arg(a, lambda n: emitted[n.name])

        for current in source_graph.nodes:
            anchor, handler = matches.get(current.name, (None, None))
            if anchor is None:
                # Not part of any pattern: carried over unchanged.
                emitted[current.name] = self.fused_graph.node_copy(
                    current, load_arg)
            elif anchor is current:
                # Root of a matched pattern: the handler emits the fused node.
                emitted[current.name] = handler.fuse(self, load_arg)
            # Any other matched node is not the root and is dropped here.

        self.fused_graph.output(load_arg(source_graph.result))
        return GraphModule(root, self.fused_graph)

    def _find_matches(self, root, graph, patterns):
        """Map node names to ``(root_node, handler)`` for every pattern hit.

        Nodes are visited in reverse graph order. Every node covered by a
        matching pattern is recorded unless it already has an entry.
        """
        named = dict(root.named_modules())
        found = {}  # node name -> (root_node, match_value?)

        def record(pattern, node, match):
            if not isinstance(pattern, tuple):
                # the first pattern matches will take precedence
                found.setdefault(node.name, match)
                return
            head, *arg_patterns = pattern
            record(head, node, match)
            for arg_pattern, arg in zip(arg_patterns, node.args):
                record(arg_pattern, arg, match)

        for node in reversed(graph.nodes):
            if node.name in found:
                continue
            for pattern, handler_cls in patterns.items():
                if is_match(named, node, pattern):
                    record(pattern, node, (node, handler_cls(self, node)))

        return found
def fuse(
    self,
    model: GraphModule,
    fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    """Build a fused version of ``model``'s graph.

    Args:
        model: graph module to fuse; used as the root of the result.
        fuse_custom_config_dict: optional configuration. Its
            ``"additional_fusion_pattern"`` entry extends the default
            patterns when no backend config is given, and its
            ``"preserved_attributes"`` entry is forwarded to the result.
        backend_config_dict: when given, the fusion patterns and the fuser
            method mapping are taken from it instead of the defaults.

    Returns:
        A ``FusedGraphModule`` built from ``model`` and the fused graph.
    """
    config = {} if fuse_custom_config_dict is None else fuse_custom_config_dict

    source_graph = model.graph
    self.modules = dict(model.named_modules())

    if backend_config_dict is not None:
        pattern_to_handler_cls = get_fusion_pattern_to_fuse_handler_cls(
            backend_config_dict)
        fuser_method_mapping = get_fuser_method_mapping(backend_config_dict)
    else:
        extra_patterns = config.get("additional_fusion_pattern", {})
        pattern_to_handler_cls = get_combined_dict(
            get_default_fusion_patterns(), extra_patterns)
        fuser_method_mapping = None

    # Find every fusion candidate before any node is emitted.
    matches = self._find_matches(model, source_graph, pattern_to_handler_cls)

    self.fused_graph = Graph()
    emitted: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda n: emitted[n.name])

    no_match = (None, None, None, None)
    for current in source_graph.nodes:
        maybe_last_node, pattern, matched_node_pattern, handler = \
            matches.get(current.name, no_match)
        if maybe_last_node is None:
            # Not part of any pattern: carried over unchanged.
            emitted[current.name] = self.fused_graph.node_copy(
                current, load_arg)
        elif maybe_last_node is current:
            assert handler is not None
            # TODO: currently we hard code the root node, which only works for
            # a tuple of two nodes, we want to make this more general to
            # support more complex patterns
            root_node = matched_node_pattern[-1]  # type: ignore[index]
            emitted[current.name] = handler.fuse(
                self, load_arg, root_node, matched_node_pattern,  # type: ignore[arg-type]
                config, fuser_method_mapping)
        # Any other matched node is not the last node and is dropped here.

    preserved_attributes = set(config.get("preserved_attributes", []))
    return FusedGraphModule(model, self.fused_graph, preserved_attributes)
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Copies node_a into graph_c and returns the copy.

    Handles `get_attr` nodes (the attribute value is stored on gm_b under a
    fresh name) and `call_method` nodes targeting `dequantize` or `to` (the
    first input is copied recursively). Anything else raises AssertionError.
    """
    def _fresh_name():
        return get_new_attr_name_with_prefix(
            node_a.name + '_shadow_copy_')(gm_b)

    if node_a.op == 'get_attr':
        copy_name = _fresh_name()
        attr_value = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(attr_value):
            attr_value = attr_value.detach()
        setattr(gm_b, copy_name, attr_value)
        return graph_c.create_node(node_a.op, copy_name, (), {}, copy_name)

    if node_a.op != 'call_method':
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")

    assert node_a.target in ('dequantize', 'to'), \
        f"target {node_a.target} is not implemented"

    # Both supported methods take the copied first input as their receiver.
    arg_copy = _copy_node_from_a_to_c(
        get_normalized_nth_input(node_a, gm_a, 0),
        gm_a, gm_b, graph_c)  # type: ignore[arg-type]
    copy_name = _fresh_name()
    if node_a.target == 'dequantize':
        call_args = (arg_copy, )
    else:
        # `to` also forwards its second input as is.
        call_args = (arg_copy, get_normalized_nth_input(node_a, gm_a, 1))
    return graph_c.create_node(
        node_a.op, node_a.target, call_args, {}, copy_name)
class Fuser:
    """Rewrites a traced graph so that matched patterns become fused nodes."""

    def fuse(self, model: GraphModule,
             fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
        """Build a fused version of ``model``'s graph.

        Args:
            model: graph module to fuse; used as the root of the result.
            fuse_custom_config_dict: optional configuration; its
                ``"additional_fusion_pattern"`` entry is combined with the
                default fusion patterns.

        Returns:
            A ``GraphModule`` built from ``model`` and the fused graph.
        """
        config = {} if fuse_custom_config_dict is None else fuse_custom_config_dict

        source_graph = model.graph
        self.modules = dict(model.named_modules())

        extra_patterns = config.get("additional_fusion_pattern", {})
        all_patterns = get_combined_dict(
            get_default_fusion_patterns(), extra_patterns)

        # Find every fusion candidate before any node is emitted.
        matches = self._find_matches(model, source_graph, all_patterns)

        self.fused_graph = Graph()
        emitted: Dict[Any, Any] = {}

        def load_arg(a):
            return map_arg(a, lambda n: emitted[n.name])

        for current in source_graph.nodes:
            anchor, handler = matches.get(current.name, (None, None))
            if anchor is None:
                # Not part of any pattern: carried over unchanged.
                emitted[current.name] = self.fused_graph.node_copy(
                    current, load_arg)
            elif anchor is current:
                # Root of a matched pattern: the handler emits the fused node.
                assert handler is not None
                emitted[current.name] = handler.fuse(self, load_arg)
            # Any other matched node is not the root and is dropped here.

        return GraphModule(model, self.fused_graph)

    def _find_matches(
            self, root: GraphModule, graph: Graph,
            patterns: Dict[Pattern, Callable]
    ) -> Dict[str, Tuple[Node, FuseHandler]]:
        """Map node names to ``(root_node, handler)`` for every pattern hit.

        Nodes are visited in reverse graph order. Every node covered by a
        matching pattern is recorded unless it already has an entry.
        """
        named = dict(root.named_modules())
        found : Dict[str, Tuple[Node, FuseHandler]] = {}  # node name -> (root_node, match_value)

        def record(pattern, node, match):
            if not isinstance(pattern, tuple):
                # the first pattern matches will take precedence
                found.setdefault(node.name, match)
                return
            head, *arg_patterns = pattern
            record(head, node, match)
            for arg_pattern, arg in zip(arg_patterns, node.args):
                record(arg_pattern, arg, match)

        for node in reversed(graph.nodes):
            if node.name in found:
                continue
            for pattern, handler_cls in patterns.items():
                if is_match(named, node, pattern):
                    record(pattern, node, (node, handler_cls(self, node)))

        return found
def insert_observer(
        node: Node,
        observer: torch.quantization.ObserverBase,
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any],
        observed_graph: Graph,
        load_arg: Callable,
        observed_node_names_set: Set[str]):
    """Attach `observer` to `model` and add a call to it in `observed_graph`.

    The observer is stored on the model under a fresh attribute name and in
    `activation_post_process_map` under `node.name`. `env[node.name]` is set
    to the new `call_module` node, and `node.name` is added to
    `observed_node_names_set`.

    Args:
        node: Node
        observer: observer/fake_quantize module instance
    """
    # respect device affinity when adding observers
    device = assert_and_get_unique_device(model)
    if device:
        observer.to(device)

    # add observer module as attribute
    observer_name = get_new_attr_name_with_prefix(
        node.name + '_activation_post_process_')(model)
    setattr(model, observer_name, observer)

    # remember which observer watches this node
    assert activation_post_process_map is not None
    activation_post_process_map[node.name] = observer

    # insert observer call
    observer_input = (load_arg(node),)
    env[node.name] = observed_graph.create_node(
        'call_module', observer_name, observer_input, {})
    observed_node_names_set.add(node.name)
def replace_observer_with_quantize_dequantize_node(
        model: torch.nn.Module,
        graph: Graph,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        qconfig_map: Dict[str, QConfigAny]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    The observer is simply removed when no quantize info is found for it, or
    when all of its producers and consumers have a None qconfig.
    """
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = get_module_path_and_prefix(
        node, node_name_to_scope, qconfig_map)
    observer_module = modules[node.target]
    quantize_info = get_quantize_node_info(observer_module)

    # Skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    neighbours = list(node.args) + list(node.users.keys())
    skip_replacement = all(
        [has_none_qconfig(n, qconfig_map) for n in neighbours])

    if skip_replacement or quantize_info is None:
        # no corresponding quantize op and info for the observer_module
        # (or nothing around it is quantized), so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, convert the observer module call to a quant - dequant pair
    node_type, quantize_op, qparams = quantize_info
    with graph.inserting_before(node):
        quantize_inputs = [node.args[0]]
        for key, value in qparams.items():
            # TODO: we can add the information of whether a value needs to
            # be registered as an attribute in qparams dict itself
            if key in ['_scale_', '_zero_point_']:
                # scale and zero_point values are registered as buffers in
                # the root module.
                # TODO: maybe need more complex attr name here
                quantize_inputs.append(create_getattr_from_value(
                    model, graph, module_path + prefix + key, value))
            else:
                # qparams that are not scale/zero_point (like axis, dtype)
                # are stored as literals in the graph.
                quantize_inputs.append(value)

        quantized_node = graph.create_node(
            node_type, quantize_op, tuple(quantize_inputs), {})
        dequantized_node = graph.call_method(
            "dequantize", args=(quantized_node, ))
        node.replace_all_uses_with(dequantized_node)
        graph.erase_node(node)
def _fold_weight(self, quantized):
    """Replace every weight prepack chain in ``quantized`` with a packed attribute.

    Each prepack call is run once through a graph module built from its
    producer nodes. The result is stored on ``quantized`` under a fresh
    ``_fx_pass_packed_weight_`` name and read back through a ``get_attr``
    node. Returns a new ``GraphModule`` built from the folded graph.
    """
    packed_weights = {}
    # map from folded node name to the prepack node it is folded into
    folded_nodes = {}

    # Pass 1: run every prepack chain once and remember its result.
    for node in quantized.graph.nodes:
        if node.op != 'call_function' or node.target not in WEIGHT_PREPACK_OPS:
            continue
        nodes_to_fold = collect_producer_nodes(node)
        if nodes_to_fold is None:
            continue
        for folded in nodes_to_fold:
            folded_nodes[folded.name] = node
        prepacking_module = graph_module_from_producer_nodes(
            quantized, nodes_to_fold)
        packed_weights[node.name] = prepacking_module()

    # Pass 2: rebuild the graph without the folded chains.
    folded_graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda n: env[n.name])

    get_new_packed_weight_name = get_new_attr_name_with_prefix(
        '_fx_pass_packed_weight_')
    quantized_root = quantized
    for node in quantized.graph.nodes:
        prepack_node = folded_nodes.get(node.name, None)
        if prepack_node is None:
            # Unrelated to weight packing: copied as is.
            env[node.name] = folded_graph.node_copy(node, load_arg)
            continue
        if prepack_node is not node:
            # Interior of a folded chain: dropped.
            continue

        # The prepack node itself becomes a getattr of the packed weight.
        packed_weight = packed_weights[node.name]
        packed_weight_name = get_new_packed_weight_name(quantized_root)
        setattr(quantized_root, packed_weight_name, packed_weight)
        env[node.name] = folded_graph.create_node(
            'get_attr', packed_weight_name, (), {})

    return GraphModule(quantized_root, folded_graph)
def convert(self,
            node: Node,
            qconfig: QConfigAny,
            modules: Dict[str, torch.nn.Module],
            quantized_graph: Graph,
            node_name_to_scope: Dict[str, Tuple[str, type]],
            load_arg: Callable,
            is_reference: bool = False,
            convert_custom_config_dict: Dict[str, Any] = None) -> Node:
    """Copy `node` into `quantized_graph`, quantizing its output if needed.

    Returns `NotImplemented` when `self.all_node_args_are_tensors` is falsy.
    Otherwise the node is copied with its arguments loaded through
    `load_arg(quantized=torch.float)`. When the activation dtype of `qconfig`
    is `torch.float` the copy is returned directly; for any other dtype the
    copy is passed to `quantize_node` together with the observer returned by
    `self._maybe_get_last_node_only_observer(modules)`.

    Only the reference path is supported: `is_reference` is asserted, and
    `node.op` must be `call_module` or `call_function`.
    """
    if not self.all_node_args_are_tensors:
        return NotImplemented
    assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
        'call_function are handled in DefaultNode'
    assert is_reference
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    # NOTE(review): `additional_static_quant_mapping` and `dtypes` are assigned
    # but never read in this function.
    additional_static_quant_mapping = convert_custom_config_dict.get(
        "static", {})
    dtypes = get_qconfig_dtypes(qconfig)
    # We can produce reference for a dtypes including
    # (torch.quint8, torch.qint8, torch.qint32, torch.float16)
    act_dtype = activation_dtype(qconfig)
    if act_dtype == torch.float:
        # Float activation: a plain copy with float inputs is enough.
        op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
        return op_out
    else:
        activation_post_process = \
            self._maybe_get_last_node_only_observer(modules)
        assert activation_post_process is not None
        # make sure the input is quantized to act_dtype
        load_arg(quantized={0: act_dtype})(node.args)
        # NOTE(review): `args` and `kwargs` are not read afterwards; presumably
        # these calls are kept for load_arg's side effects -- confirm before
        # removing them.
        args = load_arg(quantized=torch.float)(node.args)
        kwargs = load_arg(quantized=torch.float)(node.kwargs)
        op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
        # Quantize the output of the copied op using the observer found above.
        return quantize_node(
            op_out, activation_post_process,
            node, modules, quantized_graph, node_name_to_scope, is_input=False)
def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """
    Builds a node in `quantized_graph` from `create_node_args` (forwarded
    positionally to `create_node`) and gives it `old_node`'s stack trace.
    """
    trace = old_node.stack_trace
    created = quantized_graph.create_node(*create_node_args)
    created.stack_trace = trace
    return created
def create_getattr_from_value(module: GraphModule, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.

    Args:
        module: module the buffer is registered on.
        graph: graph the `get_attr` node is added to.
        prefix: prefix of the fresh attribute name.
        value: value to store. A tensor is copied with `clone().detach()`;
            anything else is converted with `torch.tensor`.

    Returns:
        The new `get_attr` node reading the registered buffer.
    """
    attr_name = get_new_attr_name_with_prefix(prefix)(module)

    # torch.tensor(existing_tensor) is the discouraged copy-construct path and
    # emits a warning; clone().detach() is the recommended way to copy one.
    if isinstance(value, torch.Tensor):
        buffer_value = value.clone().detach()
    else:
        buffer_value = torch.tensor(value)
    module.register_buffer(attr_name, buffer_value)

    # Create get_attr with value
    return graph.create_node("get_attr", attr_name)
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Rebuilds fx_module.graph, giving every node whose op code and target
    equal `old_op`/`old_target` the op code `new_op` and target `new_target`
    (arguments and name are kept). All other nodes are copied unchanged."""
    rebuilt = Graph()
    remap : Dict[Node, Node] = {}

    def lookup(n):
        return remap[n]

    for current in fx_module.graph.nodes:
        if not (current.op == old_op and current.target == old_target):
            remap[current] = rebuilt.node_copy(current, lookup)
            continue

        new_args = map_arg(current.args, lookup)
        new_kwargs = map_arg(current.kwargs, lookup)
        assert isinstance(new_args, tuple)
        assert isinstance(new_kwargs, dict)
        remap[current] = rebuilt.create_node(
            new_op, new_target, new_args, new_kwargs, current.name)

    fx_module.graph = rebuilt
def graph_module_from_producer_nodes(root, producer_nodes):
    r''' Construct a graph module from extracted producer nodes
    from `collect_producer_nodes` function

    Args:
      root: the root module for the original graph
      producer_nodes: a list of nodes we use to construct the graph, ordered
        from the final node back to its producers. The list is not modified.
    Return:
      A graph module constructed from the producer nodes, whose output is the
      copy of the first node in `producer_nodes`
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'

    # The nodes were collected by tracing back from the final node to its
    # getattr, so they are replayed in reverse. A reversed copy is used so the
    # caller's list keeps its order.
    ordered_nodes = list(reversed(producer_nodes))

    graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for producer_node in ordered_nodes:
        env[producer_node.name] = graph.node_copy(producer_node, load_arg)

    # Pass the Node itself: map_arg leaves a plain name string untouched, which
    # would make the graph output the string instead of the computed value.
    graph.output(load_arg(ordered_nodes[-1]))
    return GraphModule(root, graph)
def replace_observer_with_quantize_dequantize_node(
        graph: Graph,
        node: Node,
        modules: Dict[str, torch.nn.Module]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    Observers whose dtype is torch.float32 are simply removed. Observers with
    a dtype other than float32/quint8/qint8/float16 are left untouched.
    """
    assert modules is not None
    assert isinstance(node.target, str)
    observer_module = modules[node.target]
    root_module = modules[""]

    if observer_module.dtype == torch.float32:
        # remove the node for now
        # TODO: support dynamic quant
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    if observer_module.dtype not in [torch.quint8, torch.qint8, torch.float16]:
        return

    node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
    # replace observer node with quant - dequant node
    with graph.inserting_before(node):
        quantize_inputs = [node.args[0]]
        for key, value in qparams.items():
            if key in ['_scale_', '_zero_point_']:
                # scale and zero_point values are registered as buffers in
                # the root module.
                # TODO: maybe need more complex attr name here
                quantize_inputs.append(create_getattr_from_value(
                    root_module, graph, key, value))
            else:
                # qparams that are not scale/zero_point (like axis, dtype)
                # are stored as literals in the graph.
                quantize_inputs.append(value)

        quantized_node = graph.create_node(
            node_type, quantize_op, tuple(quantize_inputs), {})
        dequantized_node = graph.call_method(
            "dequantize", args=(quantized_node, ))
        node.replace_all_uses_with(dequantized_node)
        graph.erase_node(node)
def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
    """ Register `value` as a buffer on `module` and return a get_attr node for it.

    Args:
      module: the module that receives the new buffer
      graph: the graph in which the get_attr node is created
      prefix: prefix handed to `get_new_attr_name_with_prefix` to pick the
        buffer's attribute name
      value: a tensor (stored as a detached clone) or any value accepted by
        `torch.tensor` (created on the device reported by
        `assert_and_get_unique_device(module)`)
    Return:
      The newly created get_attr node targeting the registered buffer.
    """
    attr_name = get_new_attr_name_with_prefix(prefix)(module)
    device = assert_and_get_unique_device(module)
    if isinstance(value, torch.Tensor):
        buffer_value = value.clone().detach()
    else:
        buffer_value = torch.tensor(value, device=device)
    module.register_buffer(attr_name, buffer_value)
    # Create get_attr with value
    return graph.create_node("get_attr", attr_name)
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Rebuild the graph of `gm` without observers and with loggers attached.

    Every call_module node whose module satisfies `is_activation_post_process`
    is dropped (its users read the observer's input instead). Every node
    that is a key of `node_to_instrument_to_ref_node_name` is copied and then
    followed by a logger created through `_insert_logger_after_node`. All
    other nodes are copied unchanged.

    Returns a new GraphModule built from `gm` and the new graph.
    """
    new_graph = Graph()
    # maps original node name -> node in new_graph that stands in for it
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        is_observer_call = (
            node.op == 'call_module'
            and is_activation_post_process(modules[node.target])
        )
        if is_observer_call:
            # remove activation post process node by aliasing it to its input
            env[node.name] = env[node.args[0].name]
            continue

        needs_logger = node in node_to_instrument_to_ref_node_name
        # ensure env is populated with base node
        copied_node = new_graph.node_copy(node, load_arg)
        env[node.name] = copied_node
        if needs_logger:
            ref_node_name = node_to_instrument_to_ref_node_name[node]
            # add the logger after the base node
            env[node.name] = _insert_logger_after_node(
                copied_node, gm, logger_cls, '_ns_logger_', model_name,
                ref_node_name)

    return GraphModule(gm, new_graph)
def convert(self, observed, inplace=False, debug=False):
    """ Build a quantized GraphModule from an observed model.

    Walks the observed graph and emits nodes into `self.quantized_graph`:
    matched patterns are converted via their handler object, observer calls
    (modules whose last name component starts with
    'activation_post_process_') become `torch.quantize_per_tensor` calls,
    and every remaining node is copied with its inputs dequantized.

    Args:
      observed: the observed model; must expose `.root`, `.graph`,
        `.eval()` and `.cpu()`. It is switched to eval mode and moved to
        cpu in place regardless of `inplace`.
      inplace: when False, the root module is deep-copied before qparams
        are attached to it and observers are removed from it
      debug: not used by this method
    Return:
      A GraphModule built from the (possibly copied) root module and the
      quantized graph.
    """
    assert self.activation_post_process_map is not None
    # move to cpu since we only have quantized cpu kernels
    observed.eval().cpu()
    observed_root = observed.root
    observed_graph = observed.graph
    if not inplace:
        observed_root = copy.deepcopy(observed_root)
    self.modules = dict(observed_root.named_modules())

    matches = self._find_matches(observed.graph, self.modules, self.patterns)
    quants = self._find_quants(observed.graph, matches)
    self.quantized_graph = Graph()
    # node name -> float (non-quantized) node in the quantized graph
    env = {}
    # node name -> quantized node in the quantized graph
    quant_env = {}

    def load_non_quantized(n):
        if n.name not in env:
            assert n.name in quant_env, \
                'trying to load float node but did not find node:' + n.name + \
                ' in quantized environment:' + str(quant_env)
            # lazily insert a dequantize for a value only known as quantized
            env[n.name] = Proxy(quant_env[n.name]).dequantize().node
        return env[n.name]

    def load_quantized(n):
        if n.name not in quant_env:
            assert n.name in env, \
                'trying to load quantized node but did not find node:' + n.name + \
                ' in float environment:' + str(env)
            assert n.name in quants, 'did not find quant object for node:' + n.name
            quant = quants[n.name][0]
            quant_env[n.name] = quant.convert(self, env[n.name])
        return quant_env[n.name]

    def load_x(n):
        assert n.name in env or n.name in quant_env, \
            'node ' + n.name + ' does not exist in either of the environment'
        if n.name in quant_env:
            return quant_env[n.name]
        else:
            return env[n.name]

    def load_arg(quantized):
        """
        if quantized is a list, then arg should be a list and the args with corresponding
        indexes will be quantized
        if quantized is a boolean, then all args will be quantized/not quantized
        if quantized is None, then we'll load the node as long as it exists
        """
        assert quantized is None or isinstance(
            quantized, (tuple, list, bool)), type(quantized)

        def load_arg_impl(arg):
            if quantized is None:
                return map_arg(arg, load_x)
            if isinstance(quantized, bool):
                return map_arg(
                    arg, load_quantized if quantized else load_non_quantized)
            elif isinstance(quantized, (tuple, list)):
                assert isinstance(arg, (tuple, list)), arg
                loaded_arg = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg):
                    if i in quantized:
                        loaded_arg.append(map_arg(a, load_quantized))
                    else:
                        loaded_arg.append(map_arg(a, load_non_quantized))
                return type(arg)(loaded_arg)
        return load_arg_impl

    def is_quantized(node):
        # Returns True/False for a Node or a list of Nodes; any other
        # argument type falls through and yields None.
        if isinstance(node, Node):
            assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
            # there might be nodes appearing in both environemnts, but quant_env will take
            # precedence
            if node.name in quant_env:
                return True
            elif node.name in env:
                return False
        elif isinstance(node, list):
            # Materialize the results: a bare `map` iterator is partially
            # consumed by `all`, so the following `any` would only see the
            # leftover items and mixed lists would be misreported as
            # non-quantized instead of raising.
            quantized = list(map(is_quantized, node))
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")

    def get_next_i(module, qparams):
        # smallest suffix i such that no `<qparam name><i>` attribute
        # already exists on `module`
        i = 0
        while any(hasattr(module, name + str(i)) for name in qparams):
            i += 1
        return i

    for node in observed_graph.nodes:
        root_node, matched, obj, qconfig = matches.get(
            node.name, (None, None, None, None))
        if root_node is node:
            result = obj.convert(self, node, load_arg)
            quantized = True
            # Need to get correct quantized/non-quantized state for the output of CopyNode
            if isinstance(obj, CopyNode):
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'
                quantized = is_quantized(node.args[0])

            if self.quant_type == QuantType.DYNAMIC:
                quantized = False

            if quantized:
                quant_env[node.name] = result
            else:
                env[node.name] = result
            continue
        elif root_node is not None:
            # node is part of a matched pattern but is not its root: it is
            # handled when the root node is converted
            continue

        # handle activation post process calls
        if node.op == 'call_module':
            if node.target.split('.')[-1].startswith(
                    'activation_post_process_'):
                observer_module = self.modules[node.target]
                prev_node = node.args[0]
                if prev_node.name in quant_env:
                    # if previous node is already quantized, we'll just remove the activation_post_process
                    quant_env[node.name] = quant_env[prev_node.name]
                    continue
                # replace activation post process with quantization ops
                # qparams are always attached to the root module ('')
                parent_name = ''

                scale, zero_point = observer_module.calculate_qparams()
                # TODO: per channel
                scale = float(scale)
                zero_point = int(zero_point)
                dtype = observer_module.dtype
                qparams = {
                    '_scale_': scale,
                    '_zero_point_': zero_point,
                    '_dtype_': dtype
                }

                parent_module = self.modules[parent_name]
                i = get_next_i(parent_module, qparams)
                inputs = [load_non_quantized(node.args[0])]
                for key, value in qparams.items():
                    setattr(parent_module, key + str(i), value)
                    qparam_full_path = key + str(i)
                    if parent_name:
                        qparam_full_path = parent_name + '.' + qparam_full_path
                    inputs.append(
                        self.quantized_graph.create_node(
                            'get_param', qparam_full_path))
                quant_env[node.name] = self.quantized_graph.create_node(
                    'call_function', torch.quantize_per_tensor, inputs, {})
                continue
        # dequantize inputs for the node that are not quantized
        env[node.name] = self.quantized_graph.node_copy(
            node, load_non_quantized)

    self.quantized_graph.output(load_non_quantized(observed_graph.result))

    # collect names first so modules are not deleted while iterating them
    to_be_removed = [
        name for name, _ in observed_root.named_modules()
        if name.split('.')[-1].startswith('activation_post_process_')
    ]
    for n in to_be_removed:
        delattr(observed_root, n)
    return GraphModule(observed_root, self.quantized_graph)